Skip to content

Commit

Permalink
terminated over done, additional debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
amacrutherford committed Mar 24, 2024
1 parent 8e9eb90 commit d423f96
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
2 changes: 1 addition & 1 deletion baselines/IPPO/config/ippo_rnn_smax_mava.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"MAX_GRAD_NORM": 0.5
"ACTIVATION": "relu"
"ENV_NAME": "HeuristicEnemySMAX"
"MAP_NAME": "3s5z"
"MAP_NAME": "27m_vs_30m"
"SEED": 42
"ANNEAL_LR": False
"ADD_AGENT_ID": True
Expand Down
30 changes: 20 additions & 10 deletions baselines/IPPO/ippo_rnn_smax_split_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,6 @@ def _update_minbatch(train_state: Tuple, batch_info: Tuple):

def _actor_loss_fn(
actor_params: FrozenDict,
actor_opt_state: OptState,
actor_init_hstate: chex.Array,
traj_batch,
gae: chex.Array
Expand All @@ -391,8 +390,9 @@ def _actor_loss_fn(
log_prob = pi.log_prob(traj_batch.action)

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
gae = (gae - gae.mean()) / (gae.std() + 1e-8) # NOTE should we also done mask this mean and std..?
logratio = log_prob - traj_batch.log_prob
ratio = jnp.exp(logratio)
gae = (gae - gae.mean(where=(1-traj_batch.terminated))) / (gae.std(where=(1-traj_batch.terminated)) + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
jnp.clip(
Expand All @@ -405,14 +405,21 @@ def _actor_loss_fn(
# TODO add back in done masking on the mean
# print('loss actor shape', loss_actor1.shape, 'term shape', traj_batch.terminated.shape)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean(where=(1-traj_batch.done))
entropy = pi.entropy().mean(where=(1-traj_batch.done))
loss_actor = loss_actor.mean(where=(1-traj_batch.terminated))
entropy = pi.entropy().mean(where=(1-traj_batch.terminated))

# debugging
approx_kl = jax.lax.stop_gradient(
((ratio - 1) - logratio).mean(where=(1-traj_batch.terminated))
)
clip_frac = jax.lax.stop_gradient(
(jnp.abs(ratio - 1) > config["CLIP_EPS"]).mean(where=(1-traj_batch.terminated))
)

total_loss = loss_actor - config["ENT_COEF"] * entropy
return total_loss, (loss_actor, entropy, ratio)
return total_loss, (loss_actor, entropy, ratio, approx_kl, clip_frac)

def _critic_loss_fn(critic_params: FrozenDict,
critic_opt_state: OptState,
critic_init_hstate: chex.Array,
traj_batch,
targets: chex.Array):
Expand All @@ -431,19 +438,19 @@ def _critic_loss_fn(critic_params: FrozenDict,
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = 0.5 * jnp.maximum(
value_losses, value_losses_clipped
).mean(where=(1-traj_batch.done))
).mean(where=(1-traj_batch.terminated))

total_loss = config["VF_COEF"] * value_loss
return total_loss, (value_loss)

actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
actor_loss_info, actor_grads = actor_grad_fn(
params.actor_params, opt_states.actor_opt_state, init_hstates.actor_hstate, traj_batch, advantages
params.actor_params, init_hstates.actor_hstate, traj_batch, advantages
)

critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
critic_loss_info, critic_grads = critic_grad_fn(
params.critic_params, opt_states.critic_opt_state, init_hstates.critic_hstate, traj_batch, targets
params.critic_params, init_hstates.critic_hstate, traj_batch, targets
)

actor_updates, actor_new_opt_state = actor_optim.update(
Expand All @@ -466,6 +473,8 @@ def _critic_loss_fn(critic_params: FrozenDict,
"actor_loss": actor_loss_info[1][0],
"entropy": actor_loss_info[1][1],
"ratio": actor_loss_info[1][2],
"approx_kl": actor_loss_info[1][3],
"clip_frac": actor_loss_info[1][4],
}

return (new_params, new_opt_state), loss_info
Expand Down Expand Up @@ -573,6 +582,7 @@ def callback(metric):

metric["update_steps"] = update_step
metric["loss_info"] = jax.tree_map(lambda x: x.mean(), loss_info)
metric["loss_info"]["ratio_0"] = loss_info["ratio"].at[0, 0].get().mean()
jax.experimental.io_callback(callback, None, metric)
runner_state = RNNRunnerState(
update_step + 1,
Expand Down

0 comments on commit d423f96

Please sign in to comment.