diff --git a/baselines/IPPO/ippo_rnn_smax_split_networks.py b/baselines/IPPO/ippo_rnn_smax_split_networks.py
index b90c8070..f05b5721 100644
--- a/baselines/IPPO/ippo_rnn_smax_split_networks.py
+++ b/baselines/IPPO/ippo_rnn_smax_split_networks.py
@@ -1,5 +1,8 @@
 """
-Based on PureJaxRL Implementation of PPO
+Based on PureJaxRL Implementation of PPO.
+
+NOTE: ppo loss functions currently using `done` masking when computing means.
+
 """
 
 import jax
@@ -102,7 +105,7 @@ def __call__(self, hidden, x):
         rnn_in = (embedding, dones)
         hidden, embedding = ScannedRNN()(hidden, rnn_in)
 
-        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
+        critic = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
             embedding
         )
         critic = nn.relu(critic)
@@ -297,10 +300,7 @@ def _env_step(runner_state: RNNRunnerState, unused):
                 episode_done = done["__all__"]
                 done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
 
-                ## term - mask out timesteps when the agent didn't act
-                # if agent was done in the last timestep and the episode is not over then this action is invalid
-                # but if the episode ends, then this action is valid
-
+                ## term - mask out timesteps when the agent didn't act
                 last_done = last_done.reshape((config["NUM_ENVS"], -1))
                 last_ep_done = jnp.all(last_done, axis=1)
                 term = last_done & ~last_ep_done[:, None]
@@ -316,7 +316,7 @@ def _env_step(runner_state: RNNRunnerState, unused):
                     obs_batch,
                     info,
                     avail_actions,
-                    term
+                    term,
                 )
                 hstates = HiddenStates(actor_hstate, critic_hstate)
@@ -340,7 +340,6 @@ def _env_step(runner_state: RNNRunnerState, unused):
             _, last_val = critic_network.apply(params.critic_params, hstates.critic_hstate, ac_in)
             last_val = last_val.squeeze()
             # mava here masks out the terminal states but surely unnecessary?
             last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val)
-            # TODO mava's masking
             def _calculate_gae(traj_batch, last_val):
                 def _get_advantages(gae_and_next_value, transition):
@@ -393,7 +392,7 @@ def _actor_loss_fn(
 
                         # CALCULATE ACTOR LOSS
                         ratio = jnp.exp(log_prob - traj_batch.log_prob)
-                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
+                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)  # NOTE should we also done mask this mean and std..?
                         loss_actor1 = ratio * gae
                         loss_actor2 = (
                             jnp.clip(
@@ -404,7 +403,7 @@ def _actor_loss_fn(
                             * gae
                         )
                         # TODO add back in done masking on the mean
-                        print('loss actor shape', loss_actor1.shape, 'term shape', traj_batch.terminated.shape)
+                        # print('loss actor shape', loss_actor1.shape, 'term shape', traj_batch.terminated.shape)
                         loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                         loss_actor = loss_actor.mean(where=(1-traj_batch.done))
                         entropy = pi.entropy().mean(where=(1-traj_batch.done))
@@ -483,9 +482,6 @@ def _critic_loss_fn(critic_params: FrozenDict,
                rng, shuffle_rng = jax.random.split(rng)
                # adding an additional "fake" dimensionality to perform minibatching correctly
-               # NOTE: mava does this differently, reshaping into "recurrent chunk size" and then permuting
-               # TODO: look at the shapes of mava and jaxmarl
-               # NOTE: mava also handles hstates differently, use the one stored in traj_batch
                init_hstates = jax.tree_map(
                    lambda x: jnp.reshape(
                        x, (1, config["NUM_ACTORS"], -1)
                    )
@@ -499,8 +495,7 @@ def _critic_loss_fn(critic_params: FrozenDict,
                    targets.squeeze(),
                )
                permutation = jax.random.permutation(shuffle_rng, config["NUM_ACTORS"])
-               print('batch advantage shape', batch[2].shape)
-               # advantage shape is (num_steps, num_actors)
+
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=1), batch
                )
@@ -517,8 +512,6 @@ def _critic_loss_fn(critic_params: FrozenDict,
                    ),
                    shuffled_batch,
                )
-               print('minibatch advantage shape', minibatches[2].shape)
-               # advantage shape is (num_minibatches, num_steps, -1)
                (params, opt_states), loss_info = jax.lax.scan(
                    _update_minbatch, (params, opt_states), minibatches
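
# Not part of the diff: a minimal sketch of the done-masked advantage
# normalisation questioned in the NOTE above. The helper name
# `masked_standardize` is hypothetical; it assumes `done` has the same shape
# as `gae`, matching the where=(1 - traj_batch.done) masking already used for
# the actor loss and entropy means.
import jax.numpy as jnp

def masked_standardize(gae, done, eps=1e-8):
    # keep only timesteps where the agent had not already terminated
    valid = done == 0
    mean = jnp.mean(gae, where=valid)
    std = jnp.std(gae, where=valid)
    return (gae - mean) / (std + eps)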