done fix for mappo
amacrutherford committed Mar 27, 2024
1 parent a8fb4ca commit 0863270
Showing 7 changed files with 29 additions and 20 deletions.
4 changes: 2 additions & 2 deletions baselines/IPPO/ippo_rnn_mpe.py
@@ -57,7 +57,7 @@ class ActorCriticRNN(nn.Module):
    def __call__(self, hidden, x):
        obs, dones = x
        embedding = nn.Dense(
-           128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+           self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        embedding = nn.relu(embedding)

@@ -74,7 +74,7 @@ def __call__(self, hidden, x):

        pi = distrax.Categorical(logits=actor_mean)

-       critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+       critic = nn.Dense(self.config["FC_DIM_SIZE"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
            embedding
        )
        critic = nn.relu(critic)
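The same pattern recurs throughout this commit: a hard-coded width of 128 is replaced by a value read from the module's config dict. Below is a minimal, self-contained sketch of that idea, not the repository's exact class (which also contains the recurrent core and the policy/value heads); the class name here is illustrative only.

from typing import Dict

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen.initializers import constant, orthogonal


class ConfigSizedEmbedding(nn.Module):
    """Toy stand-in for the embedding layer: its width comes from the config."""
    config: Dict

    @nn.compact
    def __call__(self, obs):
        embedding = nn.Dense(
            self.config["FC_DIM_SIZE"],
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(obs)
        return nn.relu(embedding)


# Initialise with the default width used by the configs in this commit.
module = ConfigSizedEmbedding(config={"FC_DIM_SIZE": 128})
params = module.init(jax.random.PRNGKey(0), jnp.zeros((1, 10)))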
4 changes: 3 additions & 1 deletion baselines/MAPPO/config/mappo_homogenous_rnn_hanabi.yaml
@@ -1,7 +1,9 @@
"LR": 5.0e-4
"NUM_ENVS": 128
"NUM_STEPS": 128 # must be 128
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 1e10
"FC_DIM_SIZE": 128
"GRU_HIDDEN_DIM": 128
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
"GAMMA": 0.99
4 changes: 3 additions & 1 deletion baselines/MAPPO/config/mappo_homogenous_rnn_mpe.yaml
@@ -1,7 +1,9 @@
"LR": 2e-3
"NUM_ENVS": 128
"NUM_STEPS": 128 # must be 128
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 2e6
"FC_DIM_SIZE": 128
"GRU_HIDDEN_DIM": 128
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
"GAMMA": 0.99
4 changes: 3 additions & 1 deletion baselines/MAPPO/config/mappo_homogenous_rnn_smax.yaml
@@ -1,7 +1,9 @@
"LR": 0.002
"NUM_ENVS": 64
"NUM_STEPS": 128 # must be 128
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 1e7
"FC_DIM_SIZE": 128
"GRU_HIDDEN_DIM": 128
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
"GAMMA": 0.99
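All three MAPPO configs gain the same two keys with the same default value. A hedged sketch of reading them with plain PyYAML follows; the baselines wire configs through their own entry points, so the path and loading code here are illustrative only.

import yaml

# Path taken from the file headers above; any of the three MAPPO configs works.
with open("baselines/MAPPO/config/mappo_homogenous_rnn_mpe.yaml") as f:
    config = yaml.safe_load(f)

# The two keys added in this commit, both 128 by default.
print(config["FC_DIM_SIZE"])     # width of the pre-RNN Dense layers
print(config["GRU_HIDDEN_DIM"])  # width used for the Dense layers after the RNN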
7 changes: 4 additions & 3 deletions baselines/MAPPO/mappo_rnn_hanabi.py
@@ -98,7 +98,7 @@ class ActorRNN(nn.Module):
    def __call__(self, hidden, x):
        obs, dones, avail_actions = x
        embedding = nn.Dense(
-           128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+           self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        embedding = nn.relu(embedding)

@@ -121,12 +121,13 @@ def __call__(self, hidden, x):


class CriticRNN(nn.Module):
+   config: Dict

    @nn.compact
    def __call__(self, hidden, x):
        world_state, dones = x
        embedding = nn.Dense(
-           128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+           self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(world_state)
        embedding = nn.relu(embedding)

@@ -290,7 +291,7 @@ def _env_step(runner_state, unused):
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    jnp.tile(done["__all__"], env.num_agents),
-                   done_batch,
+                   last_done,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
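The transition now stores last_done instead of the freshly computed done_batch. One plausible reading (hedged, since the surrounding loop body is not shown in this hunk) is that last_done is the flag carried over from the previous step, i.e. the one that accompanies the observation the policy just acted on, so storing it keeps obs, action, and done aligned within each stored time step. The toy scan below illustrates that alignment with placeholder policy and environment functions; none of it is the repository's code.

import jax
import jax.numpy as jnp


def env_step(carry, _):
    last_obs, last_done, t = carry
    action = -last_obs                      # placeholder "policy"
    obs = last_obs + 1.0                    # placeholder environment transition
    done = (t % 4) == 3                     # episode of length 4, for illustration
    # Store the done that belongs to the acting observation, not the new one.
    transition = (last_obs, last_done, action)
    return (obs, done, t + 1), transition


init = (jnp.zeros(3), jnp.array(False), jnp.array(0))
_, trajectory = jax.lax.scan(env_step, init, None, length=8)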
13 changes: 7 additions & 6 deletions baselines/MAPPO/mappo_rnn_mpe.py
@@ -102,14 +102,14 @@ class ActorRNN(nn.Module):
    def __call__(self, hidden, x):
        obs, dones = x
        embedding = nn.Dense(
-           128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+           self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        embedding = nn.relu(embedding)

        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)

-       actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+       actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
            embedding
        )
        actor_mean = nn.relu(actor_mean)
@@ -123,19 +123,20 @@ def __call__(self, hidden, x):


class CriticRNN(nn.Module):
+   config: Dict

    @nn.compact
    def __call__(self, hidden, x):
        world_state, dones = x
        embedding = nn.Dense(
-           128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+           self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(world_state)
        embedding = nn.relu(embedding)

        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)

-       critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+       critic = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
            embedding
        )
        critic = nn.relu(critic)
@@ -192,7 +193,7 @@ def linear_schedule(count):
    def train(rng):
        # INIT NETWORK
        actor_network = ActorRNN(env.action_space(env.agents[0]).n, config=config)
-       critic_network = CriticRNN()
+       critic_network = CriticRNN(config=config)
        rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
        ac_init_x = (
            jnp.zeros((1, config["NUM_ENVS"], env.observation_space(env.agents[0]).shape[0])),
@@ -285,7 +286,7 @@ def _env_step(runner_state, unused):
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    jnp.tile(done["__all__"], env.num_agents),
-                   done_batch,
+                   last_done,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
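Because config is now a declared field of CriticRNN, the construction site changes from CriticRNN() to CriticRNN(config=config). The sketch below shows that consequence with a simplified stand-in module (the class and its layers are illustrative, not the repository's); Flax modules are dataclasses, so a field without a default must be supplied at instantiation.

from typing import Dict

import flax.linen as nn
import jax
import jax.numpy as jnp


class SimpleCritic(nn.Module):
    """Illustrative stand-in for CriticRNN; the recurrent core is omitted."""
    config: Dict

    @nn.compact
    def __call__(self, world_state):
        h = nn.relu(nn.Dense(self.config["FC_DIM_SIZE"])(world_state))
        h = nn.relu(nn.Dense(self.config["GRU_HIDDEN_DIM"])(h))
        return nn.Dense(1)(h)


config = {"FC_DIM_SIZE": 128, "GRU_HIDDEN_DIM": 128}
critic_network = SimpleCritic(config=config)   # mirrors CriticRNN(config=config)
params = critic_network.init(jax.random.PRNGKey(0), jnp.zeros((1, 16)))
# SimpleCritic() with no argument would raise a TypeError, which is why the
# training setup must now pass the config explicitly.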
13 changes: 7 additions & 6 deletions baselines/MAPPO/mappo_rnn_smax.py
@@ -115,14 +115,14 @@ class ActorRNN(nn.Module):
    def __call__(self, hidden, x):
        obs, dones, avail_actions = x
        embedding = nn.Dense(
-           128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+           self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        embedding = nn.relu(embedding)

        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)

-       actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+       actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
            embedding
        )
        actor_mean = nn.relu(actor_mean)
@@ -138,19 +138,20 @@ def __call__(self, hidden, x):


class CriticRNN(nn.Module):
+   config: Dict

    @nn.compact
    def __call__(self, hidden, x):
        world_state, dones = x
        embedding = nn.Dense(
-           128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+           self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(world_state)
        embedding = nn.relu(embedding)

        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)

-       critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+       critic = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
            embedding
        )
        critic = nn.relu(critic)
@@ -214,7 +215,7 @@ def linear_schedule(count):
    def train(rng):
        # INIT NETWORK
        actor_network = ActorRNN(env.action_space(env.agents[0]).n, config=config)
-       critic_network = CriticRNN()
+       critic_network = CriticRNN(config=config)
        rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
        ac_init_x = (
            jnp.zeros((1, config["NUM_ENVS"], env.observation_space(env.agents[0]).shape[0])),
@@ -317,7 +318,7 @@ def _env_step(runner_state, unused):
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    jnp.tile(done["__all__"], env.num_agents),
-                   done_batch,
+                   last_done,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
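For reference, the two done signals stored at the top of each transition in the hunks above are a global flag tiled across agents and the per-agent flag from the previous step. The shapes in this sketch are illustrative; NUM_ACTORS is assumed to be num_agents * num_envs, which is not shown in this diff.

import jax.numpy as jnp

num_agents, num_envs = 3, 4

# Global termination flag per environment, as returned by the env step.
done_all = jnp.zeros(num_envs, dtype=bool)

# Per-agent done carried over from the previous step (flattened over agents/envs).
last_done = jnp.zeros(num_agents * num_envs, dtype=bool)

global_done = jnp.tile(done_all, num_agents)  # matches jnp.tile(done["__all__"], env.num_agents)
assert global_done.shape == last_done.shape == (num_agents * num_envs,)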
