From c10177bc5ea59fb9a72b1a045ee262f84c061fa0 Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 14:45:02 +0000 Subject: [PATCH 01/13] Create docker-tests.yml --- .github/workflows/docker-tests.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .github/workflows/docker-tests.yml diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml new file mode 100644 index 00000000..8189a14b --- /dev/null +++ b/.github/workflows/docker-tests.yml @@ -0,0 +1,18 @@ +name: Docker Image CI + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Build the Docker image + run: make build From c95e1a8d3a1c14e7dcf4c1b43dfa5bc58bd3b0fa Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 14:54:44 +0000 Subject: [PATCH 02/13] add calling of tests --- .github/workflows/docker-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml index 8189a14b..547bbc68 100644 --- a/.github/workflows/docker-tests.yml +++ b/.github/workflows/docker-tests.yml @@ -16,3 +16,5 @@ jobs: - uses: actions/checkout@v3 - name: Build the Docker image run: make build + - name: Run tests + run: make test From c3fd12f5c5eb167f6498e20d7ba3f84724c48339 Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 15:04:23 +0000 Subject: [PATCH 03/13] Update Makefile for docker workflow --- Makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Makefile b/Makefile index 3ebd7fcb..c54b1f98 100644 --- a/Makefile +++ b/Makefile @@ -26,3 +26,6 @@ run: test: $(DOCKER_RUN) /bin/bash -c "pytest ./tests/" +workflow-test: + # without -it flag + docker run --rm -v ${PWD}:/home/workdir --shm-size 20G $(IMAGE) /bin/bash -c "pytest ./tests/" From afe2cfb6a90de0ffae78f96ea0f4ce1683e1458e Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 15:04:53 +0000 Subject: [PATCH 04/13] Update docker-tests.yml --- .github/workflows/docker-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml index 547bbc68..c6ee9450 100644 --- a/.github/workflows/docker-tests.yml +++ b/.github/workflows/docker-tests.yml @@ -17,4 +17,4 @@ jobs: - name: Build the Docker image run: make build - name: Run tests - run: make test + run: make workflow-test From 7f3eb840a94736637e404eaf5360b78185533162 Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 15:13:31 +0000 Subject: [PATCH 05/13] Update docker-tests.yml --- .github/workflows/docker-tests.yml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml index c6ee9450..7cd69f04 100644 --- a/.github/workflows/docker-tests.yml +++ b/.github/workflows/docker-tests.yml @@ -1,10 +1,5 @@ -name: Docker Image CI - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] +name: Run tests on Docker Image +on: [push, pull_request] jobs: From e47fad1ac657cd8e989b1655f397e80dc00a00cf Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 15:38:00 +0000 Subject: [PATCH 06/13] remove explicit jax req for docker file --- requirements/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e6009f3b..6d3b455e 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ # requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image -jax==0.4.17 -jaxlib==0.4.17 +jax +jaxlib flax==0.7.4 chex==0.1.84 optax==0.1.7 From 5b215f477cbbdc383f67b085671f31650e8e1787 Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 15:39:19 +0000 Subject: [PATCH 07/13] Delete .github/workflows/tests.yml --- .github/workflows/tests.yml | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 .github/workflows/tests.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 0fc150c9..00000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Tests -on: [push, pull_request] - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: true - max-parallel: 15 - matrix: - # os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge] - # For Apple Silicon: https://github.com/actions/runner-images/issues/8439 - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.10'] - defaults: - run: - shell: bash - steps: - - name: Check out repository - uses: actions/checkout@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e . - - - name: Run pytest - run: pytest tests From 62c1b42354fed0423aa666ad1539985d8489234f Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 16:15:02 +0000 Subject: [PATCH 08/13] put explicit back in :) --- requirements/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 6d3b455e..19427146 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ # requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image -jax -jaxlib +jax==0.4.17.* +jaxlib==0.4.17.* flax==0.7.4 chex==0.1.84 optax==0.1.7 From 7811f12c9a6f21ee2e04be710a42ac08aebcf5be Mon Sep 17 00:00:00 2001 From: Alex Rutherford <66562395+amacrutherford@users.noreply.github.com> Date: Fri, 29 Mar 2024 16:27:38 +0000 Subject: [PATCH 09/13] Update jupyter --- Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 3a3eab14..4109f5e5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,10 +13,10 @@ RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false RUN export XLA_PYTHON_CLIENT_MEM_FRACTION=0.25 RUN export TF_FORCE_GPU_ALLOW_GROWTH=true -# if you want jupyter -RUN pip install pip install jupyterlab +# Uncomment below if you want jupyter +# RUN pip install jupyterlab #for secrets and debug ENV WANDB_API_KEY="" ENV WANDB_ENTITY="" -RUN git config --global --add safe.directory /home/workdir \ No newline at end of file +RUN git config --global --add safe.directory /home/workdir From d79d9ef6980f7896bf2678b13acf706e479b7634 Mon Sep 17 00:00:00 2001 From: collinfeng Date: Sun, 31 Mar 2024 18:55:44 +0000 Subject: [PATCH 10/13] feature_level_sampling --- jaxmarl/environments/hanabi/hint_guess.py | 72 ++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/jaxmarl/environments/hanabi/hint_guess.py b/jaxmarl/environments/hanabi/hint_guess.py index 49c26e36..6645fbc3 100644 --- a/jaxmarl/environments/hanabi/hint_guess.py +++ b/jaxmarl/environments/hanabi/hint_guess.py @@ -7,6 +7,7 @@ from flax import struct from jaxmarl.environments.multi_agent_env import MultiAgentEnv from gymnax.environments.spaces import Discrete +import copy @struct.dataclass @@ -46,6 +47,7 @@ def __init__( self.num_classes_per_feature = num_classes_per_feature self.num_cards = np.prod(self.num_classes_per_feature) self.matrix_obs = matrix_obs + self.feature_tree = [np.arange(n_c) for n_c in num_classes_per_feature] # generate the deck of one-hot encoded cards if card_encoding == "onehot": @@ -76,7 +78,7 @@ def reset(self, rng): self.hand_size, ), ) - + # every agent sees the hands in different order _rngs = jax.random.split(rng_hands, self.num_agents) permuted_hands = jax.vmap( @@ -91,6 +93,67 @@ def reset(self, rng): ) return jax.lax.stop_gradient(self.get_obs(state)), state + @partial(jax.jit, static_argnums=[0]) + def reset_for_eval(self, rng): + + target_rng, hint_rng, hinter_hand_rng, guesser_hand_rng = jax.random.split(rng, 4) + + def feature_level_sample(rng, subarray): + return jax.random.choice(rng, subarray, shape=(1,)) + + def card_level_sample(rng, feature_tree): + rng = jax.random.split(rng, self.num_features) + return jnp.array([feature_level_sample(rng, subarray) for rng, subarray in zip(rng, feature_tree)]) + + def remove_feature(feature, subarray): + return jnp.delete(subarray, feature) + + def exact_match(rngs): + _, hinter_hand_rngs, guesser_hand_rngs = rngs + set_of_rest_of_hinter_and_guesser_hand = [remove_feature(feature, subarray) for feature, subarray in zip(target, self.feature_tree)] + hint = target + rest_of_hinter_hand = hand_level_sample(hinter_hand_rngs, set_of_rest_of_hinter_and_guesser_hand) + rest_of_guesser_hand = hand_level_sample(guesser_hand_rngs, set_of_rest_of_hinter_and_guesser_hand) + return hint, rest_of_hinter_hand, rest_of_guesser_hand + + def similar_match(rngs): + hint_rng, hinter_hand_rngs, guesser_hand_rngs = rngs + feature_at_interest = jax.random.choice(rng, self.num_features) + set_of_hints = copy.deepcopy(self.feature_tree) + set_of_hints[feature_at_interest] = [target[feature_at_interest]] + set_of_rest_of_hinter_hand = copy.deepcopy(self.feature_tree) + set_of_rest_of_hinter_hand = [remove_feature(feature, subarray) for feature, subarray in zip(target, self.feature_tree)] + set_of_rest_of_guesser_hand = set_of_rest_of_hinter_hand + + hint = card_level_sample(hint_rng, set_of_hints) + rest_of_hinter_hand = hand_level_sample(hinter_hand_rngs, set_of_rest_of_hinter_hand) + rest_of_guesser_hand = hand_level_sample(guesser_hand_rngs, set_of_rest_of_guesser_hand) + return hint, rest_of_hinter_hand, rest_of_guesser_hand + + def mutual_exclusive(rngs): + hint_rng, hinter_hand_rngs, guesser_hand_rngs = rngs + + hand_level_sample = jax.vmap(card_level_sample, in_axes=(0, None)) + replace = False + + target = card_level_sample(target_rng, self.feature_tree) + hinter_hand_rngs = jax.random.split(hinter_hand_rng, self.hand_size) + guesser_hand_rngs = jax.random.split(guesser_hand_rng, self.hand_size) + rngs = [hint_rng, hinter_hand_rngs, guesser_hand_rngs] + + print(target, exact_match(rngs)) + + # card_set = jnp.arange(self.num_cards) + # target_flat_id = jax.random.choice(rng, self.num_cards) + # target_multi_id = jnp.unravel_index(target_flat_id, self.num_classes_per_feature) + + + + + + # eval_fns = [exact_match, similar_match, mutual_exclusive, exclusive_match] + # return jax.lax.switch(eval_mode, eval_fns, rng) + @partial(jax.jit, static_argnums=[0]) def step_env(self, rng, state, actions): @@ -210,3 +273,10 @@ def get_onehot_encodings(self): [jnp.concatenate(combination) for combination in list(product(*encodings))] ) return encodings + + +if __name__ == "__main__": + jax.config.update("jax_disable_jit", True) + env = HintGuessGame() + rng = jax.random.PRNGKey(0) + env.reset_for_eval(rng) From bd4fa012b2195438055263d03b2dd1b7ee60c1d5 Mon Sep 17 00:00:00 2001 From: collinfeng Date: Mon, 1 Apr 2024 14:26:45 +0000 Subject: [PATCH 11/13] ready_for_matteo's review, also need help on JIT --- jaxmarl/environments/hanabi/hint_guess.py | 188 ++++++++++++++++------ 1 file changed, 136 insertions(+), 52 deletions(-) diff --git a/jaxmarl/environments/hanabi/hint_guess.py b/jaxmarl/environments/hanabi/hint_guess.py index 6645fbc3..c848adfa 100644 --- a/jaxmarl/environments/hanabi/hint_guess.py +++ b/jaxmarl/environments/hanabi/hint_guess.py @@ -93,66 +93,149 @@ def reset(self, rng): ) return jax.lax.stop_gradient(self.get_obs(state)), state - @partial(jax.jit, static_argnums=[0]) - def reset_for_eval(self, rng): - - target_rng, hint_rng, hinter_hand_rng, guesser_hand_rng = jax.random.split(rng, 4) - - def feature_level_sample(rng, subarray): - return jax.random.choice(rng, subarray, shape=(1,)) - - def card_level_sample(rng, feature_tree): - rng = jax.random.split(rng, self.num_features) - return jnp.array([feature_level_sample(rng, subarray) for rng, subarray in zip(rng, feature_tree)]) - - def remove_feature(feature, subarray): - return jnp.delete(subarray, feature) + @partial(jax.jit, static_argnums=[0, 2]) + def reset_for_eval(self, rng, reset_mode="exact_match"): - def exact_match(rngs): - _, hinter_hand_rngs, guesser_hand_rngs = rngs - set_of_rest_of_hinter_and_guesser_hand = [remove_feature(feature, subarray) for feature, subarray in zip(target, self.feature_tree)] - hint = target - rest_of_hinter_hand = hand_level_sample(hinter_hand_rngs, set_of_rest_of_hinter_and_guesser_hand) - rest_of_guesser_hand = hand_level_sample(guesser_hand_rngs, set_of_rest_of_hinter_and_guesser_hand) - return hint, rest_of_hinter_hand, rest_of_guesser_hand + def exact_match(card_multi_set): + card_flat_set = card_multi_set.flatten() + hint_flat_id = target_flat_id + hinter_and_guesser_flat_hand_set = jnp.delete(card_flat_set, target_flat_id, assume_unique_indices=True) + hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, + hinter_and_guesser_flat_hand_set, + shape=(self.hand_size-1,)) + guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, + hinter_and_guesser_flat_hand_set, + shape=(self.hand_size-1,)) + return hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand - def similar_match(rngs): - hint_rng, hinter_hand_rngs, guesser_hand_rngs = rngs - feature_at_interest = jax.random.choice(rng, self.num_features) - set_of_hints = copy.deepcopy(self.feature_tree) - set_of_hints[feature_at_interest] = [target[feature_at_interest]] - set_of_rest_of_hinter_hand = copy.deepcopy(self.feature_tree) - set_of_rest_of_hinter_hand = [remove_feature(feature, subarray) for feature, subarray in zip(target, self.feature_tree)] - set_of_rest_of_guesser_hand = set_of_rest_of_hinter_hand - - hint = card_level_sample(hint_rng, set_of_hints) - rest_of_hinter_hand = hand_level_sample(hinter_hand_rngs, set_of_rest_of_hinter_hand) - rest_of_guesser_hand = hand_level_sample(guesser_hand_rngs, set_of_rest_of_guesser_hand) - return hint, rest_of_hinter_hand, rest_of_guesser_hand - - def mutual_exclusive(rngs): - hint_rng, hinter_hand_rngs, guesser_hand_rngs = rngs - - hand_level_sample = jax.vmap(card_level_sample, in_axes=(0, None)) - replace = False - - target = card_level_sample(target_rng, self.feature_tree) - hinter_hand_rngs = jax.random.split(hinter_hand_rng, self.hand_size) - guesser_hand_rngs = jax.random.split(guesser_hand_rng, self.hand_size) - rngs = [hint_rng, hinter_hand_rngs, guesser_hand_rngs] + def similarity_match(card_multi_set): + feature_of_interest = jax.random.choice(hint_rng, self.num_features) + target_index_of_interest = target_multi_id[feature_of_interest] + # note this hint_set also include target card, need to be removed + # print(feature_of_interest, target_index_of_interest) + hint_set = jax.lax.dynamic_index_in_dim(card_multi_set, + target_index_of_interest, + feature_of_interest, + keepdims=False) + # find the target card id in hint set after slicing from the feature of interest + target_id = jnp.concatenate((target_multi_id[:feature_of_interest], target_multi_id[feature_of_interest+1:])) + hint_flat_set = jnp.delete(hint_set, target_id, assume_unique_indices=True).flatten() + + non_similar_hand_set = copy.deepcopy(card_multi_set) + for feature_dim in range(self.num_features): + non_similar_hand_set = jnp.delete(non_similar_hand_set, + target_multi_id[feature_dim], + axis=feature_dim, + assume_unique_indices=True) + + non_similar_flat_hand_set = non_similar_hand_set.flatten() + hinter_flat_id = jax.random.choice(hint_rng, + hint_flat_set, + shape=(1,)) + hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, + non_similar_flat_hand_set, + shape=(self.hand_size-1,)) + guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, + non_similar_flat_hand_set, + shape=(self.hand_size-1,)) + return hinter_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand - print(target, exact_match(rngs)) + def mutual_exclusive(card_multi_set): + non_similar_hand_set = copy.deepcopy(card_multi_set) + for feature_dim in range(self.num_features): + non_similar_hand_set = jnp.delete(non_similar_hand_set, + target_multi_id[feature_dim], + axis=feature_dim, + assume_unique_indices=True) + + non_similar_flat_hand_set = non_similar_hand_set.flatten() + hint_flat_id = jax.random.choice(hint_rng, + non_similar_flat_hand_set, + shape=(1,)) + hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(non_similar_flat_hand_set, hint_flat_id, assume_unique_indices=True) + # note the rest of hand of both players are the same, so use either of the rngs + hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, + hinter_and_guesser_flat_rest_of_hand_set, + shape=(self.hand_size-1,)) + return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand - # card_set = jnp.arange(self.num_cards) - # target_flat_id = jax.random.choice(rng, self.num_cards) - # target_multi_id = jnp.unravel_index(target_flat_id, self.num_classes_per_feature) + def mutual_exclusice_similarity(card_multi_set): + # the target will be included by the first slice, thus need to be removed + similar_cards_of_the_first_feature = jax.lax.dynamic_index_in_dim(card_multi_set, + target_multi_id[0], + 0, + keepdims=False) + target_id = target_multi_id[1:] + hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(similar_cards_of_the_first_feature, target_id, assume_unique_indices=True).flatten() + + # later slices does include the target card + for feature_dim in range(1, self.num_features): + similar_cards = jax.lax.dynamic_index_in_dim(card_multi_set, + target_multi_id[feature_dim], + feature_dim, + keepdims=False) + hinter_and_guesser_flat_rest_of_hand_set = jnp.append(hinter_and_guesser_flat_rest_of_hand_set, + similar_cards.flatten()) + card_multi_set = jnp.delete(card_multi_set, + target_multi_id[feature_dim], + axis=feature_dim, + assume_unique_indices=True) + + hint_flat_set = card_multi_set.flatten() + hint_flat_id = jax.random.choice(hint_rng, hint_flat_set, shape=(1,)) + # note the rest of hand of both players are the same, so use either of the rngs + hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, + hinter_and_guesser_flat_rest_of_hand_set, + shape=(self.hand_size-1,)) + return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand + + def shuffle_and_index(rng, players_hands): + def set_single_hand(hand, index): + empty_hands = jnp.zeros(5, dtype=jnp.int32) + return empty_hands.at[index].set(hand) + """ + generates a permutation mapping for the hands of the players such that the target_card and hint_card are tractable after the permutation + returns permuted hands, hint_card_index and target_card_index in the permuted hands + """ + rngs = jax.random.split(rng, 2) + permutation_index = jax.vmap(jax.random.permutation, in_axes=(0, None))(rngs, 5) + permuted_hands = jax.vmap(set_single_hand, in_axes=(0, 0))(players_hands, permutation_index) + return permuted_hands, permutation_index[0, 0], permutation_index[0, 1] + + target_rng, hint_rng, hinter_hand_rngs, guesser_hand_rngs = jax.random.split(rng, 4) + # constants + target_flat_id = jax.random.choice(target_rng, self.num_cards) + target_multi_id = jnp.array(jnp.unravel_index(target_flat_id, self.num_classes_per_feature)) + card_multi_set = jnp.arange(self.num_cards).reshape(self.num_classes_per_feature) + + if reset_mode == "exact_match": + hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = exact_match(card_multi_set) + elif reset_mode == "similarity_match": + hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = similarity_match(card_multi_set) + elif reset_mode == "mutual_exclusive": + hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusive(card_multi_set) + elif reset_mode == "mutual_exclusive_similarity": + hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusice_similarity(card_multi_set) + else: + raise ValueError("reset_mode is not supported") + hinter_hand = jnp.append(hint_flat_id, hinter_flat_rest_of_hand) + guesser_hand = jnp.append(target_flat_id, guesser_flat_rest_of_hand) + + player_hands = jnp.stack((hinter_hand, guesser_hand)) + print(player_hands.shape) + rngs = jnp.stack((hinter_hand_rngs, guesser_hand_rngs)) + permuted_hands, hints, targets = jax.vmap(shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0))(rngs, player_hands) + state = State( + player_hands=permuted_hands, target=target_flat_id, hint=-1, guess=-1, turn=0 + ) + + return jax.lax.stop_gradient(self.get_obs(state)), state, hints, targets + - # eval_fns = [exact_match, similar_match, mutual_exclusive, exclusive_match] - # return jax.lax.switch(eval_mode, eval_fns, rng) @partial(jax.jit, static_argnums=[0]) def step_env(self, rng, state, actions): @@ -279,4 +362,5 @@ def get_onehot_encodings(self): jax.config.update("jax_disable_jit", True) env = HintGuessGame() rng = jax.random.PRNGKey(0) - env.reset_for_eval(rng) + _, state, _, _ = env.reset_for_eval(rng, reset_mode="exact_match") + print(state) From feb82dca8b0b13d52d8cd033504240da2a4e00ec Mon Sep 17 00:00:00 2001 From: collinfeng Date: Mon, 1 Apr 2024 19:24:03 +0000 Subject: [PATCH 12/13] fixed_jit_issue for eval cases generation --- jaxmarl/environments/hanabi/hint_guess.py | 219 +++++++++++----------- 1 file changed, 105 insertions(+), 114 deletions(-) diff --git a/jaxmarl/environments/hanabi/hint_guess.py b/jaxmarl/environments/hanabi/hint_guess.py index c848adfa..dceebc94 100644 --- a/jaxmarl/environments/hanabi/hint_guess.py +++ b/jaxmarl/environments/hanabi/hint_guess.py @@ -47,7 +47,7 @@ def __init__( self.num_classes_per_feature = num_classes_per_feature self.num_cards = np.prod(self.num_classes_per_feature) self.matrix_obs = matrix_obs - self.feature_tree = [np.arange(n_c) for n_c in num_classes_per_feature] + self.card_feature_space = jnp.array(list(product(*[np.arange(n_c) for n_c in self.num_classes_per_feature]))) # generate the deck of one-hot encoded cards if card_encoding == "onehot": @@ -93,101 +93,39 @@ def reset(self, rng): ) return jax.lax.stop_gradient(self.get_obs(state)), state - @partial(jax.jit, static_argnums=[0, 2]) - def reset_for_eval(self, rng, reset_mode="exact_match"): + @partial(jax.jit, static_argnums=[0, 2, 3]) + def reset_for_eval(self, rng, reset_mode="exact_match", replace=True): - def exact_match(card_multi_set): - card_flat_set = card_multi_set.flatten() - hint_flat_id = target_flat_id - hinter_and_guesser_flat_hand_set = jnp.delete(card_flat_set, target_flat_id, assume_unique_indices=True) - hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - hinter_and_guesser_flat_hand_set, - shape=(self.hand_size-1,)) - guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, - hinter_and_guesser_flat_hand_set, - shape=(self.hand_size-1,)) - return hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand + def p_exact_match(masks): + target_mask, non_target_mask, _, _, _ = masks + p_hint = target_mask/jnp.sum(target_mask) + p_hinter_and_guesser_rest_of_hand = non_target_mask/jnp.sum(non_target_mask) + return p_hint, p_hinter_and_guesser_rest_of_hand - def similarity_match(card_multi_set): - feature_of_interest = jax.random.choice(hint_rng, self.num_features) - target_index_of_interest = target_multi_id[feature_of_interest] - # note this hint_set also include target card, need to be removed - # print(feature_of_interest, target_index_of_interest) - hint_set = jax.lax.dynamic_index_in_dim(card_multi_set, - target_index_of_interest, - feature_of_interest, - keepdims=False) - # find the target card id in hint set after slicing from the feature of interest - target_id = jnp.concatenate((target_multi_id[:feature_of_interest], target_multi_id[feature_of_interest+1:])) - hint_flat_set = jnp.delete(hint_set, target_id, assume_unique_indices=True).flatten() - - non_similar_hand_set = copy.deepcopy(card_multi_set) - for feature_dim in range(self.num_features): - non_similar_hand_set = jnp.delete(non_similar_hand_set, - target_multi_id[feature_dim], - axis=feature_dim, - assume_unique_indices=True) - - non_similar_flat_hand_set = non_similar_hand_set.flatten() - hinter_flat_id = jax.random.choice(hint_rng, - hint_flat_set, - shape=(1,)) - hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - non_similar_flat_hand_set, - shape=(self.hand_size-1,)) - guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, - non_similar_flat_hand_set, - shape=(self.hand_size-1,)) - return hinter_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand + def p_similarity_match(masks): + _, _, random_similar_feature_exclude_target_mask, _, non_similar_feature_mask = masks + hint_p = random_similar_feature_exclude_target_mask/jnp.sum(random_similar_feature_exclude_target_mask) + p_hinter_and_guesser_rest_of_hand = non_similar_feature_mask/jnp.sum(non_similar_feature_mask) + return hint_p, p_hinter_and_guesser_rest_of_hand - def mutual_exclusive(card_multi_set): - non_similar_hand_set = copy.deepcopy(card_multi_set) - for feature_dim in range(self.num_features): - non_similar_hand_set = jnp.delete(non_similar_hand_set, - target_multi_id[feature_dim], - axis=feature_dim, - assume_unique_indices=True) - - non_similar_flat_hand_set = non_similar_hand_set.flatten() + def p_mutual_exclusive(masks): + _, _, _, _, non_similar_feature_mask = masks + p_non_sim = non_similar_feature_mask/jnp.sum(non_similar_feature_mask) hint_flat_id = jax.random.choice(hint_rng, - non_similar_flat_hand_set, - shape=(1,)) - hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(non_similar_flat_hand_set, hint_flat_id, assume_unique_indices=True) - # note the rest of hand of both players are the same, so use either of the rngs - hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - hinter_and_guesser_flat_rest_of_hand_set, - shape=(self.hand_size-1,)) - return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand + card_space, + shape=(1,), + p=p_non_sim) + hint_mask = jax.nn.one_hot(x=hint_flat_id, num_classes=self.num_cards).flatten() # note this is also p_hint, as the chosen card has p=1 + hinter_and_guesser_rest_of_hand_mask = jnp.logical_and(non_similar_feature_mask, jnp.logical_not(hint_mask)) + p_hinter_and_guesser_rest_of_hand = hinter_and_guesser_rest_of_hand_mask/jnp.sum(hinter_and_guesser_rest_of_hand_mask) + return hint_mask, p_hinter_and_guesser_rest_of_hand - def mutual_exclusice_similarity(card_multi_set): - # the target will be included by the first slice, thus need to be removed - similar_cards_of_the_first_feature = jax.lax.dynamic_index_in_dim(card_multi_set, - target_multi_id[0], - 0, - keepdims=False) - target_id = target_multi_id[1:] - hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(similar_cards_of_the_first_feature, target_id, assume_unique_indices=True).flatten() - - # later slices does include the target card - for feature_dim in range(1, self.num_features): - similar_cards = jax.lax.dynamic_index_in_dim(card_multi_set, - target_multi_id[feature_dim], - feature_dim, - keepdims=False) - hinter_and_guesser_flat_rest_of_hand_set = jnp.append(hinter_and_guesser_flat_rest_of_hand_set, - similar_cards.flatten()) - card_multi_set = jnp.delete(card_multi_set, - target_multi_id[feature_dim], - axis=feature_dim, - assume_unique_indices=True) - - hint_flat_set = card_multi_set.flatten() - hint_flat_id = jax.random.choice(hint_rng, hint_flat_set, shape=(1,)) - # note the rest of hand of both players are the same, so use either of the rngs - hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - hinter_and_guesser_flat_rest_of_hand_set, - shape=(self.hand_size-1,)) - return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand + def p_mutual_exclusice_similarity(masks): + _, _, _, similar_feature_exclude_target_mask, non_similar_feature_mask = masks + p_hint = non_similar_feature_mask/jnp.sum(non_similar_feature_mask) + p_hinter_and_guesser_rest_of_hand = similar_feature_exclude_target_mask/jnp.sum(similar_feature_exclude_target_mask) + print(similar_feature_exclude_target_mask, p_hinter_and_guesser_rest_of_hand) + return p_hint, p_hinter_and_guesser_rest_of_hand def shuffle_and_index(rng, players_hands): def set_single_hand(hand, index): @@ -195,47 +133,96 @@ def set_single_hand(hand, index): return empty_hands.at[index].set(hand) """ generates a permutation mapping for the hands of the players such that the target_card and hint_card are tractable after the permutation - returns permuted hands, hint_card_index and target_card_index in the permuted hands + returns permuted hands, hint_card_index and target_card_index in the permuted hands of hinter and guesser """ rngs = jax.random.split(rng, 2) permutation_index = jax.vmap(jax.random.permutation, in_axes=(0, None))(rngs, 5) permuted_hands = jax.vmap(set_single_hand, in_axes=(0, 0))(players_hands, permutation_index) - return permuted_hands, permutation_index[0, 0], permutation_index[0, 1] + return permuted_hands, permutation_index[0, 0], permutation_index[1, 0] target_rng, hint_rng, hinter_hand_rngs, guesser_hand_rngs = jax.random.split(rng, 4) - # constants + # target randomisation target_flat_id = jax.random.choice(target_rng, self.num_cards) target_multi_id = jnp.array(jnp.unravel_index(target_flat_id, self.num_classes_per_feature)) - card_multi_set = jnp.arange(self.num_cards).reshape(self.num_classes_per_feature) - - if reset_mode == "exact_match": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = exact_match(card_multi_set) - elif reset_mode == "similarity_match": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = similarity_match(card_multi_set) - elif reset_mode == "mutual_exclusive": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusive(card_multi_set) - elif reset_mode == "mutual_exclusive_similarity": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusice_similarity(card_multi_set) + #copy card space to ensure env is not modified + card_space = jnp.arange(self.num_cards) + card_feature_space = self.card_feature_space + + # generate mask for exact match and non_exact match + target_mask = jnp.where(target_flat_id + == card_space, + 1, + 0).flatten() + non_target_mask = 1 - target_mask + + # generate mask for similar cards for a randomly selected feature + random_feature_of_interest = jax.random.choice(hint_rng, self.num_features) + random_similar_feature_mask = jnp.where(card_feature_space[:, random_feature_of_interest] + == target_multi_id[random_feature_of_interest], + 1, + 0).flatten() + random_similar_feature_exclude_target_mask = non_target_mask * random_similar_feature_mask + + # generate mask for all non-similar cards for all features + similar_feature_mask = jnp.zeros(self.num_cards) + non_similar_feature_mask = jnp.ones(self.num_cards) + for feature_dim in range(self.num_features): + # + is logical or operation, * is logical and operation + similar_feature_mask = similar_feature_mask + jnp.where(card_feature_space[:, feature_dim] + == target_multi_id[feature_dim], + 1, + 0).flatten() + non_similar_feature_mask = non_similar_feature_mask * jnp.where(card_feature_space[:, feature_dim] + != target_multi_id[feature_dim], + 1, + 0).flatten() + similar_feature_mask_exculde_target = similar_feature_mask * non_target_mask + + masks = (target_mask, non_target_mask, random_similar_feature_exclude_target_mask, similar_feature_mask_exculde_target, non_similar_feature_mask) + p_reset_modes = { + "exact_match": p_exact_match, + "similarity_match": p_similarity_match, + "mutual_exclusive": p_mutual_exclusive, + "mutual_exclusive_similarity": p_mutual_exclusice_similarity, + } + if reset_mode in p_reset_modes: + p_hint, p_other = p_reset_modes[reset_mode](masks) else: raise ValueError("reset_mode is not supported") + + hinter_flat_id = jax.random.choice(hint_rng, + card_space, + shape=(1,), + replace=replace, + p=p_hint) + + hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, + card_space, + shape=(self.hand_size-1,), + replace=replace, + p=p_other) + if reset_mode == "mutual_exclusive" or reset_mode == "mutual_exclusive_similarity": + guesser_flat_rest_of_hand = hinter_flat_rest_of_hand + else: + guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, + card_space, + shape=(self.hand_size-1,), + replace=replace, + p=p_other) - hinter_hand = jnp.append(hint_flat_id, hinter_flat_rest_of_hand) + hinter_hand = jnp.append(hinter_flat_id, hinter_flat_rest_of_hand) guesser_hand = jnp.append(target_flat_id, guesser_flat_rest_of_hand) player_hands = jnp.stack((hinter_hand, guesser_hand)) - print(player_hands.shape) rngs = jnp.stack((hinter_hand_rngs, guesser_hand_rngs)) - permuted_hands, hints, targets = jax.vmap(shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0))(rngs, player_hands) + permuted_hands, hint_indices, target_indices = jax.vmap(shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0))(rngs, player_hands) state = State( player_hands=permuted_hands, target=target_flat_id, hint=-1, guess=-1, turn=0 ) - return jax.lax.stop_gradient(self.get_obs(state)), state, hints, targets - - - + return jax.lax.stop_gradient(self.get_obs(state)), state, hint_indices, target_indices @partial(jax.jit, static_argnums=[0]) def step_env(self, rng, state, actions): @@ -359,8 +346,12 @@ def get_onehot_encodings(self): if __name__ == "__main__": - jax.config.update("jax_disable_jit", True) + # jax.config.update("jax_disable_jit", True) env = HintGuessGame() - rng = jax.random.PRNGKey(0) - _, state, _, _ = env.reset_for_eval(rng, reset_mode="exact_match") + rng = jax.random.PRNGKey(10) + # reset_modes: exact_match, similarity_match, mutual_exclusive, mutual_exclusive_similarity + _, state, hints, targets = env.reset_for_eval(rng, reset_mode="similarity_match", replace=True) + print(jnp.arange(9).reshape(3, 3)) print(state) + print(hints) + print(targets) From 0c42d25475d6f62face32ec309eee2e60ad9530e Mon Sep 17 00:00:00 2001 From: collinfeng Date: Mon, 1 Apr 2024 19:32:44 +0000 Subject: [PATCH 13/13] personal_tests --- .gitignore | 2 ++ Makefile | 2 +- baselines/IPPO/config/ippo_ff_hint_guess.yaml | 4 ++-- baselines/IPPO/ippo_ff_hint_guess.py | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 17a7d47a..59ad1f64 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ tmp/ wandb/ outputs/ models/ +.devcontainer/ +.gitignore diff --git a/Makefile b/Makefile index 3ebd7fcb..0267f49a 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ endif BASE_FLAGS=-it --rm -v ${PWD}:/home/workdir --shm-size 20G RUN_FLAGS=$(GPUS) $(BASE_FLAGS) -DOCKER_IMAGE_NAME = jaxmarl +DOCKER_IMAGE_NAME = jaxmarl-cf IMAGE = $(DOCKER_IMAGE_NAME):latest DOCKER_RUN=docker run $(RUN_FLAGS) $(IMAGE) USE_CUDA = $(if $(GPUS),true,false) diff --git a/baselines/IPPO/config/ippo_ff_hint_guess.yaml b/baselines/IPPO/config/ippo_ff_hint_guess.yaml index f4ec6c62..03814267 100644 --- a/baselines/IPPO/config/ippo_ff_hint_guess.yaml +++ b/baselines/IPPO/config/ippo_ff_hint_guess.yaml @@ -21,5 +21,5 @@ # WandB Params "WANDB_MODE": "online" -"ENTITY": "mttga" -"PROJECT": "hint_guess" \ No newline at end of file +"ENTITY": "clf26" +"PROJECT": "action-feature" \ No newline at end of file diff --git a/baselines/IPPO/ippo_ff_hint_guess.py b/baselines/IPPO/ippo_ff_hint_guess.py index 8b2ea67c..8ce215eb 100644 --- a/baselines/IPPO/ippo_ff_hint_guess.py +++ b/baselines/IPPO/ippo_ff_hint_guess.py @@ -423,8 +423,8 @@ def wrapped_make_train(): @hydra.main(version_base=None, config_path="config", config_name="ippo_ff_hint_guess") def main(config): config = OmegaConf.to_container(config) - #single_run(config) - tune(config) + single_run(config) + # tune(config) if __name__ == "__main__":