done fix for IPPO
amacrutherford committed Mar 27, 2024
1 parent a3b0c25 commit a8fb4ca
Showing 6 changed files with 86 additions and 24 deletions.
2 changes: 2 additions & 0 deletions baselines/IPPO/config/ippo_rnn_hanabi.yaml
@@ -2,6 +2,8 @@
"NUM_ENVS": 1024
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 1e10
"FC_DIM_SIZE": 128
"GRU_HIDDEN_DIM": 128
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
"GAMMA": 0.99
2 changes: 2 additions & 0 deletions baselines/IPPO/config/ippo_rnn_mpe.yaml
@@ -2,6 +2,8 @@
"LR": 5e-4
"NUM_ENVS": 24
"NUM_STEPS": 128
"FC_DIM_SIZE": 128
"GRU_HIDDEN_DIM": 128
"TOTAL_TIMESTEPS": 2e6
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 2
5 changes: 3 additions & 2 deletions baselines/IPPO/config/ippo_rnn_smax.yaml
@@ -2,6 +2,7 @@
"NUM_ENVS": 128
"NUM_STEPS": 128
"GRU_HIDDEN_DIM": 256
"FC_DIM_SIZE": 128
"TOTAL_TIMESTEPS": 1e7
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
@@ -23,6 +24,6 @@
"ANNEAL_LR": True

# WandB Params
"ENTITY": ${oc.env:WANDB_ENTITY}
"ENTITY": "amacrutherford"
"PROJECT": "jaxmarl-smax"
"WANDB_MODE" : "disabled"
"WANDB_MODE" : "online"
40 changes: 30 additions & 10 deletions baselines/IPPO/ippo_rnn_hanabi.py
@@ -65,14 +65,14 @@ class ActorCriticRNN(nn.Module):
def __call__(self, hidden, x):
obs, dones, avail_actions = x
embedding = nn.Dense(
128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(obs)
embedding = nn.relu(embedding)

rnn_in = (embedding, dones)
hidden, embedding = ScannedRNN()(hidden, rnn_in)

actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
embedding
)
actor_mean = nn.relu(actor_mean)
@@ -83,7 +83,7 @@ def __call__(self, hidden, x):
action_logits = actor_mean - (unavail_actions * 1e10)
pi = distrax.Categorical(logits=action_logits)

critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
critic = nn.Dense(self.config["FC_DIM_SIZE"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
embedding
)
critic = nn.relu(critic)
@@ -95,6 +95,7 @@


class Transition(NamedTuple):
global_done: jnp.ndarray
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
@@ -146,7 +147,7 @@ def train(rng):
jnp.zeros((1, config["NUM_ENVS"])),
jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n))
)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
network_params = network.init(_rng, init_hstate, init_x)
if config["ANNEAL_LR"]:
tx = optax.chain(
@@ -165,7 +166,7 @@
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])

# TRAIN LOOP
def _update_step(update_runner_state, unused):
Expand All @@ -186,7 +187,7 @@ def _env_step(runner_state, unused):
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)
env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)

env_act = jax.tree_map(lambda x: x.squeeze(), env_act)
# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["NUM_ENVS"])
@@ -196,7 +197,8 @@ def _env_step(runner_state, unused):
info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
done_batch,
jnp.tile(done["__all__"], env.num_agents),
last_done,
action.squeeze(),
value.squeeze(),
batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -227,7 +229,7 @@ def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition):
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.global_done,
transition.value,
transition.reward,
)
@@ -271,7 +273,8 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
)

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
logratio = log_prob - traj_batch.log_prob
ratio = jnp.exp(logratio)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
@@ -285,13 +288,17 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()

# debug
approx_kl = ((ratio - 1) - logratio).mean()
clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])

total_loss = (
loss_actor
+ config["VF_COEF"] * value_loss
- config["ENT_COEF"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)
return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
@@ -337,6 +344,18 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
)
train_state = update_state[0]
metric = traj_batch.info
ratio_0 = loss_info[1][3].at[0,0].get().mean()
loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
metric["loss"] = {
"total_loss": loss_info[0],
"value_loss": loss_info[1][0],
"actor_loss": loss_info[1][1],
"entropy": loss_info[1][2],
"ratio": loss_info[1][3],
"ratio_0": ratio_0,
"approx_kl": loss_info[1][4],
"clip_frac": loss_info[1][5],
}
rng = update_state[-1]

def callback(metric):
@@ -346,6 +365,7 @@ def callback(metric):
"env_step": metric["update_steps"]
* config["NUM_ENVS"]
* config["NUM_STEPS"],
**metric["loss"],
}
)
metric["update_steps"] = update_steps
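
Besides threading the config dims through ActorCriticRNN, the substantive change in this file is to the trajectory: Transition now stores both the pre-step per-agent done (last_done, which the network uses to reset the recurrent state when replaying the trajectory) and the tiled episode-level done["__all__"] (global_done), and _calculate_gae bootstraps on global_done. A standalone sketch of that GAE recursion, written as a plain Python loop instead of the baseline's jax.lax.scan (gamma and lam stand in for GAMMA and GAE_LAMBDA):

import jax.numpy as jnp

def calculate_gae(rewards, values, global_dones, last_val, gamma=0.99, lam=0.95):
    """Generalized advantage estimation over a [T, N] rollout, masking the bootstrap
    across episode boundaries with the shared done flag (cf. transition.global_done)."""
    gae = jnp.zeros_like(last_val)
    next_value = last_val
    advantages = []
    for t in reversed(range(rewards.shape[0])):
        nonterminal = 1.0 - global_dones[t]
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        gae = delta + gamma * lam * nonterminal * gae
        next_value = values[t]
        advantages.append(gae)
    advantages = jnp.stack(advantages[::-1])
    return advantages, advantages + values  # (advantages, value targets)
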
30 changes: 24 additions & 6 deletions baselines/IPPO/ippo_rnn_mpe.py
@@ -64,7 +64,7 @@ def __call__(self, hidden, x):
rnn_in = (embedding, dones)
hidden, embedding = ScannedRNN()(hidden, rnn_in)

actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
embedding
)
actor_mean = nn.relu(actor_mean)
Expand Down Expand Up @@ -142,7 +142,7 @@ def train(rng):
),
jnp.zeros((1, config["NUM_ENVS"])),
)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
network_params = network.init(_rng, init_hstate, init_x)
if config["ANNEAL_LR"]:
tx = optax.chain(
@@ -164,7 +164,7 @@ def train(rng):
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])

# TRAIN LOOP
def _update_step(update_runner_state, unused):
@@ -199,7 +199,7 @@ def _env_step(runner_state, unused):
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
jnp.tile(done["__all__"], env.num_agents),
done_batch,
last_done,
action.squeeze(),
value.squeeze(),
batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -276,7 +276,8 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
).mean(where=(1 - traj_batch.done))

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
logratio = log_prob - traj_batch.log_prob
ratio = jnp.exp(logratio)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
@@ -291,12 +292,16 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
loss_actor = loss_actor.mean(where=(1 - traj_batch.done))
entropy = pi.entropy().mean(where=(1 - traj_batch.done))

# debug
approx_kl = ((ratio - 1) - logratio).mean()
clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])

total_loss = (
loss_actor
+ config["VF_COEF"] * value_loss
- config["ENT_COEF"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)
return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
@@ -376,6 +381,18 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
),
traj_batch.info,
)
ratio_0 = loss_info[1][3].at[0,0].get().mean()
loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
metric["loss"] = {
"total_loss": loss_info[0],
"value_loss": loss_info[1][0],
"actor_loss": loss_info[1][1],
"entropy": loss_info[1][2],
"ratio": loss_info[1][3],
"ratio_0": ratio_0,
"approx_kl": loss_info[1][4],
"clip_frac": loss_info[1][5],
}
rng = update_state[-1]

def callback(metric):
@@ -389,6 +406,7 @@ def callback(metric):
"env_step": metric["update_steps"]
* config["NUM_ENVS"]
* config["NUM_STEPS"],
**metric["loss"],
}
)

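
The loss function in all three scripts now also returns PPO health statistics in its aux output: the raw importance ratio, an approximate KL between the old and updated policy, and the fraction of samples hitting the clip range. A sketch of just those diagnostics, using the same low-variance estimator mean((ratio - 1) - log ratio) that the diff adds:

import jax.numpy as jnp

def ppo_diagnostics(log_prob, old_log_prob, clip_eps=0.2):
    """Debug statistics computed alongside the clipped PPO actor loss."""
    logratio = log_prob - old_log_prob
    ratio = jnp.exp(logratio)
    # Estimates KL(old || new); each term (r - 1) - log r is non-negative, so this is
    # lower-variance and better behaved than the naive -mean(logratio).
    approx_kl = ((ratio - 1.0) - logratio).mean()
    # Share of samples whose ratio falls outside the clipping band [1 - eps, 1 + eps].
    clip_frac = jnp.mean(jnp.abs(ratio - 1.0) > clip_eps)
    return ratio, approx_kl, clip_frac
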
31 changes: 25 additions & 6 deletions baselines/IPPO/ippo_rnn_smax.py
@@ -58,7 +58,7 @@ class ActorCriticRNN(nn.Module):
def __call__(self, hidden, x):
obs, dones, avail_actions = x
embedding = nn.Dense(
128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(obs)
embedding = nn.relu(embedding)

@@ -77,7 +77,7 @@ def __call__(self, hidden, x):

pi = distrax.Categorical(logits=action_logits)

critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
critic = nn.Dense(self.config["FC_DIM_SIZE"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
embedding
)
critic = nn.relu(critic)
@@ -208,7 +208,7 @@ def _env_step(runner_state, unused):
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
jnp.tile(done["__all__"], env.num_agents),
done_batch,
last_done,
action.squeeze(),
value.squeeze(),
batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -290,7 +290,8 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
).mean(where=(1 - traj_batch.done))

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
logratio = log_prob - traj_batch.log_prob
ratio = jnp.exp(logratio)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
@@ -305,12 +306,16 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
loss_actor = loss_actor.mean(where=(1 - traj_batch.done))
entropy = pi.entropy().mean(where=(1 - traj_batch.done))

# debug
approx_kl = ((ratio - 1) - logratio).mean()
clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])

total_loss = (
loss_actor
+ config["VF_COEF"] * value_loss
- config["ENT_COEF"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)
return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
@@ -373,7 +378,7 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):

update_state = (
train_state,
init_hstate,
initial_hstate,
traj_batch,
advantages,
targets,
@@ -390,6 +395,19 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
),
traj_batch.info,
)
ratio_0 = loss_info[1][3].at[0,0].get().mean()
loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
metric["loss"] = {
"total_loss": loss_info[0],
"value_loss": loss_info[1][0],
"actor_loss": loss_info[1][1],
"entropy": loss_info[1][2],
"ratio": loss_info[1][3],
"ratio_0": ratio_0,
"approx_kl": loss_info[1][4],
"clip_frac": loss_info[1][5],
}

rng = update_state[-1]

def callback(metric):
@@ -406,6 +424,7 @@ def callback(metric):
"env_step": metric["update_steps"]
* config["NUM_ENVS"]
* config["NUM_STEPS"],
**metric["loss"],
}
)

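
Each script also logs ratio_0, the mean importance ratio on the first minibatch of the first update epoch. Since no gradient step has been taken at that point, it should sit at 1.0, which makes it a cheap sanity check for stale stored log-probs or hidden-state mismatches between rollout and update. A small sketch of the indexing, assuming the aux leaves are stacked as [UPDATE_EPOCHS, NUM_MINIBATCHES, ...] by the two nested scans:

import jax.numpy as jnp

# Stand-in for the stacked aux output loss_info[1][3]: one ratio per epoch / minibatch / sample.
ratios = jnp.ones((4, 4, 128))          # [UPDATE_EPOCHS, NUM_MINIBATCHES, minibatch size]

ratio_0 = ratios.at[0, 0].get().mean()  # equivalent to ratios[0, 0].mean()
assert jnp.allclose(ratio_0, 1.0)       # before any gradient step the new and old policies agree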