Skip to content

Commit

Permalink
Merge branch 'hanabi_obl_aligned' into hanabi_obl_pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
ravihammond authored Mar 5, 2024
2 parents 768b9b2 + a0e312c commit 9dda54a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions jaxmarl/environments/hanabi/hanabi.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def _legal_moves(aidx: int, state: State) -> chex.Array:
hands = state.player_hands
fireworks = state.fireworks
info_tokens = state.info_tokens
# discard always legal
#TODO: This is incorrect, discard is only legal
# discard is legal only when the info tokens are not at their maximum
is_not_max_info_tokens = jnp.sum(state.info_tokens) < 8
legal_moves = legal_moves.at[move_idx:move_idx + self.hand_size].set(
is_not_max_info_tokens
Expand Down Expand Up @@ -165,7 +164,7 @@ def _get_hints_for_hand(carry, unused):

_, valid_hints = lax.scan(_get_hints_for_hand, (0, other_hands), None, self.num_agents - 1)
# make other player positions relative to current player
valid_hints = jnp.roll(valid_hints, aidx, axis=0)
valid_hints = jnp.roll(valid_hints, -aidx, axis=0)
# include valid hints in legal moves
num_hints = (self.num_agents - 1) * (self.num_colors + self.num_ranks)
valid_hints = jnp.concatenate(valid_hints, axis=0)
Expand Down

0 comments on commit 9dda54a

Please sign in to comment.