Skip to content

Commit

Permalink
ugly fix
Browse files Browse the repository at this point in the history
  • Loading branch information
amacrutherford committed Mar 27, 2024
1 parent 7090e00 commit 3e6ddc6
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions baselines/MAPPO/mappo_rnn_hanabi.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def world_state(self, obs, state):
hands = state.player_hands.reshape((self._env.num_agents, -1))
return jnp.concatenate((all_obs, hands), axis=1)


@partial(jax.jit, static_argnums=0)
def world_state_size(self):

return self._env.observation_space(self._env.agents[0]).n + 125 # NOTE hardcoded hand size
Expand Down Expand Up @@ -193,7 +193,7 @@ def linear_schedule(count):
def train(rng):
# INIT NETWORK
actor_network = ActorRNN(env.action_space(env.agents[0]).n, config=config)
critic_network = CriticRNN()
critic_network = CriticRNN(config=config)
rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
ac_init_x = (
jnp.zeros((1, config["NUM_ENVS"], env.observation_space(env.agents[0]).n)),
Expand All @@ -202,9 +202,9 @@ def train(rng):
)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
actor_network_params = actor_network.init(_rng_actor, ac_init_hstate, ac_init_x)

print('ac init x',ac_init_x)
cr_init_x = (
jnp.zeros((1, config["NUM_ENVS"], env.world_state_size(),)),
jnp.zeros((1, config["NUM_ENVS"], 658+125,)), # NOTE hardcoded
jnp.zeros((1, config["NUM_ENVS"])),
)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
Expand Down Expand Up @@ -272,6 +272,7 @@ def _env_step(runner_state, unused):
env_act = unbatchify(
action, env.agents, config["NUM_ENVS"], env.num_agents
)
env_act = jax.tree_map(lambda x: x.squeeze(), env_act)

# VALUE
world_state = last_obs["world_state"].reshape((config["NUM_ACTORS"],-1))
Expand Down

0 comments on commit 3e6ddc6

Please sign in to comment.