From a8fb4ca07ecc1e7c097e645bd55c92443d54aff6 Mon Sep 17 00:00:00 2001
From: alex
Date: Wed, 27 Mar 2024 12:41:03 +0000
Subject: [PATCH] done fix for IPPO

---
 baselines/IPPO/config/ippo_rnn_hanabi.yaml |  2 ++
 baselines/IPPO/config/ippo_rnn_mpe.yaml    |  2 ++
 baselines/IPPO/config/ippo_rnn_smax.yaml   |  5 +--
 baselines/IPPO/ippo_rnn_hanabi.py          | 40 ++++++++++++++++------
 baselines/IPPO/ippo_rnn_mpe.py             | 30 ++++++++++++----
 baselines/IPPO/ippo_rnn_smax.py            | 31 +++++++++++++----
 6 files changed, 86 insertions(+), 24 deletions(-)

diff --git a/baselines/IPPO/config/ippo_rnn_hanabi.yaml b/baselines/IPPO/config/ippo_rnn_hanabi.yaml
index b118ef88..4c224ce8 100644
--- a/baselines/IPPO/config/ippo_rnn_hanabi.yaml
+++ b/baselines/IPPO/config/ippo_rnn_hanabi.yaml
@@ -2,6 +2,8 @@
 "NUM_ENVS": 1024
 "NUM_STEPS": 128
 "TOTAL_TIMESTEPS": 1e10
+"FC_DIM_SIZE": 128
+"GRU_HIDDEN_DIM": 128
 "UPDATE_EPOCHS": 4
 "NUM_MINIBATCHES": 4
 "GAMMA": 0.99
diff --git a/baselines/IPPO/config/ippo_rnn_mpe.yaml b/baselines/IPPO/config/ippo_rnn_mpe.yaml
index 87bfde63..0946a937 100644
--- a/baselines/IPPO/config/ippo_rnn_mpe.yaml
+++ b/baselines/IPPO/config/ippo_rnn_mpe.yaml
@@ -2,6 +2,8 @@
 "LR": 5e-4
 "NUM_ENVS": 24
 "NUM_STEPS": 128
+"FC_DIM_SIZE": 128
+"GRU_HIDDEN_DIM": 128
 "TOTAL_TIMESTEPS": 2e6
 "UPDATE_EPOCHS": 4
 "NUM_MINIBATCHES": 2
diff --git a/baselines/IPPO/config/ippo_rnn_smax.yaml b/baselines/IPPO/config/ippo_rnn_smax.yaml
index 9160b61d..a059fb5d 100644
--- a/baselines/IPPO/config/ippo_rnn_smax.yaml
+++ b/baselines/IPPO/config/ippo_rnn_smax.yaml
@@ -2,6 +2,7 @@
 "NUM_ENVS": 128
 "NUM_STEPS": 128
 "GRU_HIDDEN_DIM": 256
+"FC_DIM_SIZE": 128
 "TOTAL_TIMESTEPS": 1e7
 "UPDATE_EPOCHS": 4
 "NUM_MINIBATCHES": 4
@@ -23,6 +24,6 @@
 "ANNEAL_LR": True

 # WandB Params
-"ENTITY": ${oc.env:WANDB_ENTITY}
+"ENTITY": "amacrutherford"
 "PROJECT": "jaxmarl-smax"
-"WANDB_MODE" : "disabled"
+"WANDB_MODE" : "online"
diff --git a/baselines/IPPO/ippo_rnn_hanabi.py b/baselines/IPPO/ippo_rnn_hanabi.py
index ed185f58..359f5259 100644
--- a/baselines/IPPO/ippo_rnn_hanabi.py
+++ b/baselines/IPPO/ippo_rnn_hanabi.py
@@ -65,14 +65,14 @@ class ActorCriticRNN(nn.Module):
     def __call__(self, hidden, x):
         obs, dones, avail_actions = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(obs)
         embedding = nn.relu(embedding)

         rnn_in = (embedding, dones)
         hidden, embedding = ScannedRNN()(hidden, rnn_in)

-        actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         actor_mean = nn.relu(actor_mean)
@@ -83,7 +83,7 @@ def __call__(self, hidden, x):
         action_logits = actor_mean - (unavail_actions * 1e10)
         pi = distrax.Categorical(logits=action_logits)

-        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        critic = nn.Dense(self.config["FC_DIM_SIZE"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         critic = nn.relu(critic)
@@ -95,6 +95,7 @@ def __call__(self, hidden, x):


 class Transition(NamedTuple):
+    global_done: jnp.ndarray
     done: jnp.ndarray
     action: jnp.ndarray
     value: jnp.ndarray
@@ -146,7 +147,7 @@ def train(rng):
             jnp.zeros((1, config["NUM_ENVS"])),
             jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n))
         )
-        init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
+        init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
         network_params = network.init(_rng, init_hstate, init_x)
         if config["ANNEAL_LR"]:
             tx = optax.chain(
@@ -165,7 +166,7 @@ def train(rng):
         rng, _rng = jax.random.split(rng)
         reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
         obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
-        init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
+        init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])

         # TRAIN LOOP
         def _update_step(update_runner_state, unused):
@@ -186,7 +187,7 @@ def _env_step(runner_state, unused):
                 action = pi.sample(seed=_rng)
                 log_prob = pi.log_prob(action)
                 env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)
-
+                env_act = jax.tree_map(lambda x: x.squeeze(), env_act)
                 # STEP ENV
                 rng, _rng = jax.random.split(rng)
                 rng_step = jax.random.split(_rng, config["NUM_ENVS"])
@@ -196,7 +197,8 @@ def _env_step(runner_state, unused):
                 info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
                 done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                 transition = Transition(
-                    done_batch,
+                    jnp.tile(done["__all__"], env.num_agents),
+                    last_done,
                     action.squeeze(),
                     value.squeeze(),
                     batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -227,7 +229,7 @@ def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
-                        transition.done,
+                        transition.global_done,
                        transition.value,
                        transition.reward,
                    )
@@ -271,7 +273,8 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        )

                        # CALCULATE ACTOR LOSS
-                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
+                        logratio = log_prob - traj_batch.log_prob
+                        ratio = jnp.exp(logratio)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
@@ -285,13 +288,17 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()
+
+                        # debug
+                        approx_kl = ((ratio - 1) - logratio).mean()
+                        clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
-                        return total_loss, (value_loss, loss_actor, entropy)
+                        return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
@@ -337,6 +344,18 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
            )
            train_state = update_state[0]
            metric = traj_batch.info
+            ratio_0 = loss_info[1][3].at[0,0].get().mean()
+            loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
+            metric["loss"] = {
+                "total_loss": loss_info[0],
+                "value_loss": loss_info[1][0],
+                "actor_loss": loss_info[1][1],
+                "entropy": loss_info[1][2],
+                "ratio": loss_info[1][3],
+                "ratio_0": ratio_0,
+                "approx_kl": loss_info[1][4],
+                "clip_frac": loss_info[1][5],
+            }
            rng = update_state[-1]

            def callback(metric):
@@ -346,6 +365,7 @@ def callback(metric):
                        "env_step": metric["update_steps"]
                        * config["NUM_ENVS"]
                        * config["NUM_STEPS"],
+                        **metric["loss"],
                    }
                )
            metric["update_steps"] = update_steps
diff --git a/baselines/IPPO/ippo_rnn_mpe.py b/baselines/IPPO/ippo_rnn_mpe.py
index a5a81d1f..1a3ba809 100644
--- a/baselines/IPPO/ippo_rnn_mpe.py
+++ b/baselines/IPPO/ippo_rnn_mpe.py
@@ -64,7 +64,7 @@ def __call__(self, hidden, x):
         rnn_in = (embedding, dones)
         hidden, embedding = ScannedRNN()(hidden, rnn_in)

-        actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         actor_mean = nn.relu(actor_mean)
@@ -142,7 +142,7 @@ def train(rng):
            ),
            jnp.zeros((1, config["NUM_ENVS"])),
        )
-        init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
+        init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
        network_params = network.init(_rng, init_hstate, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
@@ -164,7 +164,7 @@ def train(rng):
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
-        init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
+        init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])

        # TRAIN LOOP
        def _update_step(update_runner_state, unused):
@@ -199,7 +199,7 @@ def _env_step(runner_state, unused):
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    jnp.tile(done["__all__"], env.num_agents),
-                    done_batch,
+                    last_done,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -276,7 +276,8 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        ).mean(where=(1 - traj_batch.done))

                        # CALCULATE ACTOR LOSS
-                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
+                        logratio = log_prob - traj_batch.log_prob
+                        ratio = jnp.exp(logratio)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
@@ -291,12 +292,16 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        loss_actor = loss_actor.mean(where=(1 - traj_batch.done))
                        entropy = pi.entropy().mean(where=(1 - traj_batch.done))

+                        # debug
+                        approx_kl = ((ratio - 1) - logratio).mean()
+                        clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])
+
                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
-                        return total_loss, (value_loss, loss_actor, entropy)
+                        return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
@@ -376,6 +381,18 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                ),
                traj_batch.info,
            )
+            ratio_0 = loss_info[1][3].at[0,0].get().mean()
+            loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
+            metric["loss"] = {
+                "total_loss": loss_info[0],
+                "value_loss": loss_info[1][0],
+                "actor_loss": loss_info[1][1],
+                "entropy": loss_info[1][2],
+                "ratio": loss_info[1][3],
+                "ratio_0": ratio_0,
+                "approx_kl": loss_info[1][4],
+                "clip_frac": loss_info[1][5],
+            }
            rng = update_state[-1]

            def callback(metric):
@@ -389,6 +406,7 @@ def callback(metric):
                        "env_step": metric["update_steps"]
                        * config["NUM_ENVS"]
                        * config["NUM_STEPS"],
+                        **metric["loss"],
                    }
                )

diff --git a/baselines/IPPO/ippo_rnn_smax.py b/baselines/IPPO/ippo_rnn_smax.py
index d92325b0..5c987e23 100644
--- a/baselines/IPPO/ippo_rnn_smax.py
+++ b/baselines/IPPO/ippo_rnn_smax.py
@@ -58,7 +58,7 @@ class ActorCriticRNN(nn.Module):
     def __call__(self, hidden, x):
         obs, dones, avail_actions = x
         embedding = nn.Dense(
-            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+            self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
         )(obs)
         embedding = nn.relu(embedding)

@@ -77,7 +77,7 @@ def __call__(self, hidden, x):

         pi = distrax.Categorical(logits=action_logits)

-        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        critic = nn.Dense(self.config["FC_DIM_SIZE"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         critic = nn.relu(critic)
@@ -208,7 +208,7 @@ def _env_step(runner_state, unused):
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    jnp.tile(done["__all__"], env.num_agents),
-                    done_batch,
+                    last_done,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -290,7 +290,8 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        ).mean(where=(1 - traj_batch.done))

                        # CALCULATE ACTOR LOSS
-                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
+                        logratio = log_prob - traj_batch.log_prob
+                        ratio = jnp.exp(logratio)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
@@ -305,12 +306,16 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        loss_actor = loss_actor.mean(where=(1 - traj_batch.done))
                        entropy = pi.entropy().mean(where=(1 - traj_batch.done))

+                        # debug
+                        approx_kl = ((ratio - 1) - logratio).mean()
+                        clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])
+
                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
-                        return total_loss, (value_loss, loss_actor, entropy)
+                        return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
@@ -373,7 +378,7 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):

            update_state = (
                train_state,
-                init_hstate,
+                initial_hstate,
                traj_batch,
                advantages,
                targets,
@@ -390,6 +395,19 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                ),
                traj_batch.info,
            )
+            ratio_0 = loss_info[1][3].at[0,0].get().mean()
+            loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
+            metric["loss"] = {
+                "total_loss": loss_info[0],
+                "value_loss": loss_info[1][0],
+                "actor_loss": loss_info[1][1],
+                "entropy": loss_info[1][2],
+                "ratio": loss_info[1][3],
+                "ratio_0": ratio_0,
+                "approx_kl": loss_info[1][4],
+                "clip_frac": loss_info[1][5],
+            }
+
            rng = update_state[-1]

            def callback(metric):
@@ -406,6 +424,7 @@ def callback(metric):
                        "env_step": metric["update_steps"]
                        * config["NUM_ENVS"]
                        * config["NUM_STEPS"],
+                        **metric["loss"],
                    }
                )
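
For readers who want to sanity-check the new diagnostics outside the training loop, the following is a minimal, self-contained JAX sketch of the approx_kl and clip_frac expressions this patch adds to _loss_fn. The function name ppo_diagnostics and its arguments are illustrative only and do not exist in the repository.

# Minimal sketch of the PPO debug metrics added by this patch (assumed names).
import jax.numpy as jnp

def ppo_diagnostics(new_log_prob, old_log_prob, clip_eps=0.2):
    # Importance ratio between the updated policy and the policy that collected the data.
    logratio = new_log_prob - old_log_prob
    ratio = jnp.exp(logratio)
    # Low-variance KL estimate, matching ((ratio - 1) - logratio).mean() in the diff.
    approx_kl = ((ratio - 1) - logratio).mean()
    # Fraction of samples whose ratio falls outside the PPO clipping band.
    clip_frac = jnp.mean(jnp.abs(ratio - 1) > clip_eps)
    return approx_kl, clip_frac

In the patch itself these values are returned as auxiliaries from _loss_fn, averaged with jax.tree_map, and logged to wandb under metric["loss"].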