Skip to content

Commit

Permalink
tidying
Browse files Browse the repository at this point in the history
  • Loading branch information
amacrutherford committed Mar 21, 2024
1 parent 15b6788 commit 13961d5
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions baselines/IPPO/ippo_rnn_smax_split_networks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""
Based on PureJaxRL Implementation of PPO
Based on PureJaxRL Implementation of PPO.
NOTE: the PPO loss functions currently use `done` masking when computing means.
"""

import jax
Expand Down Expand Up @@ -102,7 +105,7 @@ def __call__(self, hidden, x):
rnn_in = (embedding, dones)
hidden, embedding = ScannedRNN()(hidden, rnn_in)

critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
critic = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
embedding
)
critic = nn.relu(critic)
Expand Down Expand Up @@ -297,10 +300,7 @@ def _env_step(runner_state: RNNRunnerState, unused):
episode_done = done["__all__"]
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()

## term - mask out timesteps when the agent didn't act
# if agent was done in the last timestep and the episode is not over then this action is invalid
# but if the episode ends, then this action is valid

## term - mask out timesteps when the agent didn't act
last_done = last_done.reshape((config["NUM_ENVS"], -1))
last_ep_done = jnp.all(last_done, axis=1)
term = last_done & ~last_ep_done[:, None]
Expand All @@ -316,7 +316,7 @@ def _env_step(runner_state: RNNRunnerState, unused):
obs_batch,
info,
avail_actions,
term
term,
)

hstates = HiddenStates(actor_hstate, critic_hstate)
Expand All @@ -340,7 +340,6 @@ def _env_step(runner_state: RNNRunnerState, unused):
_, last_val = critic_network.apply(params.critic_params, hstates.critic_hstate, ac_in)
last_val = last_val.squeeze() # mava here masks out the terminal states but surely unnecessary?
last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val)
# TODO mava's masking

def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition):
Expand Down Expand Up @@ -393,7 +392,7 @@ def _actor_loss_fn(

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
gae = (gae - gae.mean()) / (gae.std() + 1e-8) # NOTE: should this mean and std also be done-masked?
loss_actor1 = ratio * gae
loss_actor2 = (
jnp.clip(
Expand All @@ -404,7 +403,7 @@ def _actor_loss_fn(
* gae
)
# TODO add back in done masking on the mean
print('loss actor shape', loss_actor1.shape, 'term shape', traj_batch.terminated.shape)
# print('loss actor shape', loss_actor1.shape, 'term shape', traj_batch.terminated.shape)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean(where=(1-traj_batch.done))
entropy = pi.entropy().mean(where=(1-traj_batch.done))
Expand Down Expand Up @@ -483,9 +482,6 @@ def _critic_loss_fn(critic_params: FrozenDict,
rng, shuffle_rng = jax.random.split(rng)

# adding an additional "fake" dimensionality to perform minibatching correctly
# NOTE: mava does this differently, reshaping into "recurrent chunk size" and then permuting
# TODO: look at the shapes of mava and jaxmarl
# NOTE: mava also handles hstates differently, use the one stored in traj_batch
init_hstates = jax.tree_map(
lambda x: jnp.reshape(
x, (1, config["NUM_ACTORS"], -1)
Expand All @@ -499,8 +495,7 @@ def _critic_loss_fn(critic_params: FrozenDict,
targets.squeeze(),
)
permutation = jax.random.permutation(shuffle_rng, config["NUM_ACTORS"])
print('batch advantage shape', batch[2].shape)
# advantage shape is (num_steps, num_actors)

shuffled_batch = jax.tree_util.tree_map(
lambda x: jnp.take(x, permutation, axis=1), batch
)
Expand All @@ -517,8 +512,6 @@ def _critic_loss_fn(critic_params: FrozenDict,
),
shuffled_batch,
)
print('minibatch advantage shape', minibatches[2].shape)
# advantage shape is (num_minibatches, num_steps, -1)

(params, opt_states), loss_info = jax.lax.scan(
_update_minbatch, (params, opt_states), minibatches
Expand Down

0 comments on commit 13961d5

Please sign in to comment.