Skip to content

Commit

Permalink
ready for test
Browse files Browse the repository at this point in the history
  • Loading branch information
mttga committed Mar 15, 2024
1 parent bed7b60 commit 475ead0
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions jaxmarl/environments/hanabi/obl/obl_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def __init__(self, weight_file, device="cuda:0"):
'c0':torch.zeros(1, 2, 2, 512).to(device),
}

def act(self, env, obs, legal_moves, curr_player):
obs = self._batchify(env, obs)
legal_moves = self._batchify(env, legal_moves)
legal_moves = jnp.roll(legal_moves, -1, axis=1)
def act(self, obs, legal_moves, curr_player):
obs = self._batchify(obs)
legal_moves = self._batchify(legal_moves)
#legal_moves = jnp.roll(legal_moves, -1, axis=1)

torch_obs = {
'priv_s':torch.tensor(np.array(obs[..., 125:])).to(self._device),
Expand All @@ -31,14 +31,15 @@ def act(self, env, obs, legal_moves, curr_player):
'c0': self._hid['c0'].to(self._device),
'legal_move': torch.tensor(np.array(legal_moves)).to(self._device),
}

act_result = self._agent.act(torch_obs)
actions = act_result.pop('a').detach().numpy()
self._hid = act_result
actions = np.where(actions+1==21, 0, actions+1)
#actions = np.where(actions+1==21, 0, actions+1)
return actions

def _batchify(self, env, x):
return jnp.stack([x[a] for a in env.agents])
def _batchify(self, x_dict):
return jnp.stack([x_dict[agent] for agent in sorted(x_dict)])


def load_agent_from_file(weight_file, device, sad_legacy=False, iql_legacy=False):
Expand Down

0 comments on commit 475ead0

Please sign in to comment.