diff --git a/baselines/IPPO/ippo_rnn_mpe.py b/baselines/IPPO/ippo_rnn_mpe.py
index 1a3ba809..93ad7ac5 100644
--- a/baselines/IPPO/ippo_rnn_mpe.py
+++ b/baselines/IPPO/ippo_rnn_mpe.py
@@ -57,7 +57,7 @@ class ActorCriticRNN(nn.Module):
     def __call__(self, hidden, x):
         obs, dones = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(obs)
         embedding = nn.relu(embedding)
@@ -74,7 +74,7 @@ def __call__(self, hidden, x):
         pi = distrax.Categorical(logits=actor_mean)
-        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        critic = nn.Dense(self.config["FC_DIM_SIZE"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         critic = nn.relu(critic)
diff --git a/baselines/MAPPO/config/mappo_homogenous_rnn_hanabi.yaml b/baselines/MAPPO/config/mappo_homogenous_rnn_hanabi.yaml
index 68609e98..4c7e4eab 100644
--- a/baselines/MAPPO/config/mappo_homogenous_rnn_hanabi.yaml
+++ b/baselines/MAPPO/config/mappo_homogenous_rnn_hanabi.yaml
@@ -1,7 +1,9 @@
 "LR": 5.0e-4
 "NUM_ENVS": 128
-"NUM_STEPS": 128 # must be 128
+"NUM_STEPS": 128
 "TOTAL_TIMESTEPS": 1e10
+"FC_DIM_SIZE": 128
+"GRU_HIDDEN_DIM": 128
 "UPDATE_EPOCHS": 4
 "NUM_MINIBATCHES": 4
 "GAMMA": 0.99
diff --git a/baselines/MAPPO/config/mappo_homogenous_rnn_mpe.yaml b/baselines/MAPPO/config/mappo_homogenous_rnn_mpe.yaml
index dd7db172..109a075a 100644
--- a/baselines/MAPPO/config/mappo_homogenous_rnn_mpe.yaml
+++ b/baselines/MAPPO/config/mappo_homogenous_rnn_mpe.yaml
@@ -1,7 +1,9 @@
 "LR": 2e-3
 "NUM_ENVS": 128
-"NUM_STEPS": 128 # must be 128
+"NUM_STEPS": 128
 "TOTAL_TIMESTEPS": 2e6
+"FC_DIM_SIZE": 128
+"GRU_HIDDEN_DIM": 128
 "UPDATE_EPOCHS": 4
 "NUM_MINIBATCHES": 4
 "GAMMA": 0.99
diff --git a/baselines/MAPPO/config/mappo_homogenous_rnn_smax.yaml b/baselines/MAPPO/config/mappo_homogenous_rnn_smax.yaml
index d1b91b3f..9682dc8d 100644
--- a/baselines/MAPPO/config/mappo_homogenous_rnn_smax.yaml
+++ b/baselines/MAPPO/config/mappo_homogenous_rnn_smax.yaml
@@ -1,7 +1,9 @@
 "LR": 0.002
 "NUM_ENVS": 64
-"NUM_STEPS": 128 # must be 128
+"NUM_STEPS": 128
 "TOTAL_TIMESTEPS": 1e7
+"FC_DIM_SIZE": 128
+"GRU_HIDDEN_DIM": 128
 "UPDATE_EPOCHS": 4
 "NUM_MINIBATCHES": 4
 "GAMMA": 0.99
diff --git a/baselines/MAPPO/mappo_rnn_hanabi.py b/baselines/MAPPO/mappo_rnn_hanabi.py
index e267a4f9..0ad55e0a 100644
--- a/baselines/MAPPO/mappo_rnn_hanabi.py
+++ b/baselines/MAPPO/mappo_rnn_hanabi.py
@@ -98,7 +98,7 @@ class ActorRNN(nn.Module):
     def __call__(self, hidden, x):
         obs, dones, avail_actions = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(obs)
         embedding = nn.relu(embedding)
@@ -121,12 +121,13 @@ def __call__(self, hidden, x):
 class CriticRNN(nn.Module):
+    config: Dict
     @nn.compact
     def __call__(self, hidden, x):
         world_state, dones = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(world_state)
         embedding = nn.relu(embedding)
@@ -290,7 +291,7 @@ def _env_step(runner_state, unused):
                 done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                 transition = Transition(
                     jnp.tile(done["__all__"], env.num_agents),
-                    done_batch,
+                    last_done,
                     action.squeeze(),
                     value.squeeze(),
                     batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
diff --git a/baselines/MAPPO/mappo_rnn_mpe.py b/baselines/MAPPO/mappo_rnn_mpe.py
index 0dcd5261..bd41fd95 100644
--- a/baselines/MAPPO/mappo_rnn_mpe.py
+++ b/baselines/MAPPO/mappo_rnn_mpe.py
@@ -102,14 +102,14 @@ class ActorRNN(nn.Module):
     def __call__(self, hidden, x):
         obs, dones = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(obs)
         embedding = nn.relu(embedding)
         rnn_in = (embedding, dones)
         hidden, embedding = ScannedRNN()(hidden, rnn_in)
-        actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         actor_mean = nn.relu(actor_mean)
@@ -123,19 +123,20 @@ def __call__(self, hidden, x):
 class CriticRNN(nn.Module):
+    config: Dict
     @nn.compact
     def __call__(self, hidden, x):
         world_state, dones = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(world_state)
         embedding = nn.relu(embedding)
         rnn_in = (embedding, dones)
         hidden, embedding = ScannedRNN()(hidden, rnn_in)
-        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        critic = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         critic = nn.relu(critic)
@@ -192,7 +193,7 @@ def linear_schedule(count):
     def train(rng):
         # INIT NETWORK
         actor_network = ActorRNN(env.action_space(env.agents[0]).n, config=config)
-        critic_network = CriticRNN()
+        critic_network = CriticRNN(config=config)
         rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
         ac_init_x = (
             jnp.zeros((1, config["NUM_ENVS"], env.observation_space(env.agents[0]).shape[0])),
@@ -285,7 +286,7 @@ def _env_step(runner_state, unused):
                 done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                 transition = Transition(
                     jnp.tile(done["__all__"], env.num_agents),
-                    done_batch,
+                    last_done,
                     action.squeeze(),
                     value.squeeze(),
                     batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
diff --git a/baselines/MAPPO/mappo_rnn_smax.py b/baselines/MAPPO/mappo_rnn_smax.py
index e7026b6d..18aef63d 100644
--- a/baselines/MAPPO/mappo_rnn_smax.py
+++ b/baselines/MAPPO/mappo_rnn_smax.py
@@ -115,14 +115,14 @@ class ActorRNN(nn.Module):
     def __call__(self, hidden, x):
         obs, dones, avail_actions = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(obs)
         embedding = nn.relu(embedding)
         rnn_in = (embedding, dones)
         hidden, embedding = ScannedRNN()(hidden, rnn_in)
-        actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         actor_mean = nn.relu(actor_mean)
@@ -138,19 +138,20 @@ def __call__(self, hidden, x):
 class CriticRNN(nn.Module):
+    config: Dict
     @nn.compact
     def __call__(self, hidden, x):
         world_state, dones = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(world_state)
         embedding = nn.relu(embedding)
         rnn_in = (embedding, dones)
         hidden, embedding = ScannedRNN()(hidden, rnn_in)
-        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        critic = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         critic = nn.relu(critic)
@@ -214,7 +215,7 @@ def linear_schedule(count):
     def train(rng):
         # INIT NETWORK
         actor_network = ActorRNN(env.action_space(env.agents[0]).n, config=config)
-        critic_network = CriticRNN()
+        critic_network = CriticRNN(config=config)
         rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
         ac_init_x = (
             jnp.zeros((1, config["NUM_ENVS"], env.observation_space(env.agents[0]).shape[0])),
@@ -317,7 +318,7 @@ def _env_step(runner_state, unused):
                 done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                 transition = Transition(
                     jnp.tile(done["__all__"], env.num_agents),
-                    done_batch,
+                    last_done,
                     action.squeeze(),
                     value.squeeze(),
                     batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
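For reference (not part of the diff): the recurring pattern above is that layer widths move from hard-coded 128s to the new "FC_DIM_SIZE" / "GRU_HIDDEN_DIM" config keys, and the critic now receives the config at construction time (CriticRNN(config=config)). The following is a minimal stand-alone sketch of that pattern, assuming Flax/JAX as used in the repo; the TinyCritic module, its shapes, and the dropped ScannedRNN layer are illustrative simplifications, not the repo's actual CriticRNN.

    # sketch of config-driven layer widths, mirroring the diff's pattern
    from typing import Dict

    import jax
    import jax.numpy as jnp
    import numpy as np
    import flax.linen as nn
    from flax.linen.initializers import constant, orthogonal


    class TinyCritic(nn.Module):
        config: Dict  # expects the "FC_DIM_SIZE" and "GRU_HIDDEN_DIM" keys added to the YAML configs

        @nn.compact
        def __call__(self, world_state):
            # first hidden layer sized by FC_DIM_SIZE instead of a hard-coded 128
            x = nn.Dense(
                self.config["FC_DIM_SIZE"],
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(world_state)
            x = nn.relu(x)
            # second hidden layer sized by GRU_HIDDEN_DIM (the real module puts a ScannedRNN before this)
            x = nn.Dense(
                self.config["GRU_HIDDEN_DIM"],
                kernel_init=orthogonal(2),
                bias_init=constant(0.0),
            )(x)
            x = nn.relu(x)
            return nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(x)


    # construction mirrors critic_network = CriticRNN(config=config) in the diff
    config = {"FC_DIM_SIZE": 128, "GRU_HIDDEN_DIM": 128}
    critic = TinyCritic(config=config)
    params = critic.init(jax.random.PRNGKey(0), jnp.zeros((4, 32)))

Because both keys default to 128 in the updated YAML files, this parameterization leaves the baseline architectures unchanged while letting sweeps override the widths from config.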