Showing 2 changed files with 434 additions and 8 deletions.
@@ -0,0 +1,391 @@
""" | ||
Based on PureJaxRL Implementation of PPO | ||
""" | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import flax.linen as nn | ||
import numpy as np | ||
import optax | ||
from flax.linen.initializers import constant, orthogonal | ||
from typing import Sequence, NamedTuple, Any, Dict | ||
from flax.training.train_state import TrainState | ||
import distrax | ||
from jaxmarl.wrappers.baselines import LogWrapper | ||
import jaxmarl | ||
import wandb | ||
import functools | ||
import matplotlib.pyplot as plt | ||
import hydra | ||
from omegaconf import OmegaConf | ||
|
||
|
||
class EncoderBlock(nn.Module):
    hidden_dim: int = 32  # input dimension must equal the output dimension (residual connection)
    num_heads: int = 4
    dim_feedforward: int = 256
    dropout_prob: float = 0.0

    def setup(self):
        # Attention layer
        self.self_attn = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dropout_rate=self.dropout_prob,
            kernel_init=nn.initializers.xavier_uniform(),
            use_bias=False,
        )
        # Two-layer MLP
        self.linear = [
            nn.Dense(self.dim_feedforward, kernel_init=nn.initializers.xavier_uniform(), bias_init=constant(0.0)),
            nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform(), bias_init=constant(0.0)),
        ]
        # Layers applied in between the main layers
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(self.dropout_prob)

    def __call__(self, x, mask=None, deterministic=True):
        # Attention part
        if mask is not None:
            mask = jnp.repeat(nn.make_attention_mask(mask, mask), self.num_heads, axis=-3)
        attended = self.self_attn(inputs_q=x, inputs_kv=x, mask=mask, deterministic=deterministic)

        x = self.norm1(attended + x)
        x = x + self.dropout(x, deterministic=deterministic)

        # MLP part
        feedforward = self.linear[0](x)
        feedforward = nn.relu(feedforward)
        feedforward = self.linear[1](feedforward)

        x = self.norm2(feedforward + x)
        x = x + self.dropout(x, deterministic=deterministic)

        return x


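# ActorCritic below treats each decision as a small token sequence: the observation is
# embedded into a 32-dim token, every discrete action id gets a learned 32-dim embedding,
# and the concatenated sequence is passed through a stack of EncoderBlocks. The attended
# observation token feeds the critic head, while each attended action token is projected
# to a single scalar that becomes that action's logit.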
class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    config: Dict
    transf_layers: int = 2

    @nn.compact
    def __call__(self, x):
        obs, dones, avail_actions = x
        obs_embedding = nn.Dense(
            32, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        obs_embedding = nn.relu(obs_embedding)

        actions_embeddings = nn.Embed(num_embeddings=self.action_dim, features=32)(jnp.arange(self.action_dim))
        actions_embeddings = jnp.tile(actions_embeddings, (*obs_embedding.shape[:-1], 1, 1))
        embeddings = jnp.concatenate((obs_embedding[..., np.newaxis, :], actions_embeddings), axis=-2)

        for _ in range(self.transf_layers):
            embeddings = EncoderBlock()(embeddings)

        obs_embedding_post = embeddings[..., 0, :]
        actions_embeddings_post = embeddings[..., 1:, :]

        # actor_mean = nn.Dense(512, kernel_init=orthogonal(2), bias_init=constant(0.0))(actions_embeddings_post)
        # actor_mean = nn.relu(actor_mean)
        actor_mean = nn.Dense(1, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actions_embeddings_post)
        actor_mean = jnp.squeeze(actor_mean, axis=-1)

        unavail_actions = 1 - avail_actions
        action_logits = actor_mean - (unavail_actions * 1e10)
        pi = distrax.Categorical(logits=action_logits)

        critic = nn.Dense(512, kernel_init=orthogonal(2), bias_init=constant(0.0))(obs_embedding_post)
        critic = nn.relu(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray
    avail_actions: jnp.ndarray

def batchify(x: dict, agent_list, num_actors):
    x = jnp.stack([x[a] for a in agent_list])
    return x.reshape((num_actors, -1))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    x = x.reshape((num_actors, num_envs, -1))
    return {a: x[i] for i, a in enumerate(agent_list)}

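# batchify/unbatchify above convert between per-agent dicts and a flat (num_actors, ...) array,
# so all agents across all parallel envs are processed as one batch. make_train builds the
# environment from the config, derives NUM_ACTORS (= num_agents * NUM_ENVS), NUM_UPDATES and
# MINIBATCH_SIZE, and returns a train function that can be jit-compiled.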
def make_train(config):
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )

    env = LogWrapper(env)

    def linear_schedule(count):
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):

        # INIT NETWORK
        network = ActorCritic(env.action_space(env.agents[0]).n, config=config)
        rng, _rng = jax.random.split(rng)
        init_x = (
            jnp.zeros(
                (1, config["NUM_ENVS"], env.observation_space(env.agents[0]).n)
            ),
            jnp.zeros((1, config["NUM_ENVS"])),
            jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n)),
        )
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0))(reset_rng)

        # TRAIN LOOP
        def _update_step(update_runner_state, unused):
            # COLLECT TRAJECTORIES
            runner_state, update_steps = update_runner_state

            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, last_done, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                avail_actions = jax.vmap(env.get_legal_moves)(env_state.env_state)
                avail_actions = jax.lax.stop_gradient(
                    batchify(avail_actions, env.agents, config["NUM_ACTORS"])
                )
                obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
                ac_in = (obs_batch[np.newaxis, :], last_done[np.newaxis, :], avail_actions[np.newaxis, :])
                pi, value = network.apply(train_state.params, ac_in)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))(
                    rng_step, env_state, env_act
                )
                info = jax.tree_util.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    done_batch,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
                    log_prob.squeeze(),
                    obs_batch,
                    info,
                    avail_actions,
                )
                runner_state = (train_state, env_state, obsv, done_batch, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, last_done, rng = runner_state
            last_obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
            avail_actions = jnp.ones(
                (config["NUM_ACTORS"], env.action_space(env.agents[0]).n)
            )
            ac_in = (last_obs_batch[np.newaxis, :], last_done[np.newaxis, :], avail_actions)
            _, last_val = network.apply(train_state.params, ac_in)
            last_val = last_val.squeeze()

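            # Generalised Advantage Estimation, scanned backwards over the trajectory:
            #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            # Value targets are recovered as A_t + V(s_t).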
            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

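                    # PPO loss on one minibatch: the value loss is clipped around the old
                    # value predictions, and the policy loss is the clipped surrogate
                    #   L = -min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)
                    # with ratio = exp(log_prob_new - log_prob_old), advantages normalised
                    # per minibatch, and an entropy bonus weighted by ENT_COEF.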
                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(
                            params,
                            (traj_batch.obs, traj_batch.done, traj_batch.avail_actions),
                        )
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)

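                # Shuffle along the actor axis (axis=1 of the (NUM_STEPS, NUM_ACTORS, ...)
                # arrays) and split the actors into NUM_MINIBATCHES groups, so each minibatch
                # keeps full-length trajectories for a subset of actors.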
                batch = (traj_batch, advantages.squeeze(), targets.squeeze())
                permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=1), batch
                )

                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.swapaxes(
                        jnp.reshape(
                            x,
                            [x.shape[0], config["NUM_MINIBATCHES"], -1]
                            + list(x.shape[2:]),
                        ),
                        1,
                        0,
                    ),
                    shuffled_batch,
                )

                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

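            # Logging runs on the host: the metrics pytree is tagged with the current update
            # count and handed to wandb.log through io_callback, because wandb.log is a Python
            # side effect and cannot be traced inside the jitted scan.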
            def callback(metric):
                wandb.log(
                    {
                        "returns": metric["returned_episode_returns"][-1, :].mean(),
                        "env_step": metric["update_steps"]
                        * config["NUM_ENVS"]
                        * config["NUM_STEPS"],
                    }
                )

            metric["update_steps"] = update_steps
            jax.experimental.io_callback(callback, None, metric)
            update_steps = update_steps + 1
            runner_state = (train_state, env_state, last_obs, last_done, rng)
            return (runner_state, update_steps), None

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, jnp.zeros((config["NUM_ACTORS"]), dtype=bool), _rng)
        runner_state, _ = jax.lax.scan(
            _update_step, (runner_state, 0), None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state}

    return train


@hydra.main(version_base=None, config_path="config", config_name="ippo_ff_hanabi")
def main(config):
    config = OmegaConf.to_container(config)

    wandb.init(
        entity=config["ENTITY"],
        project=config["PROJECT"],
        tags=["IPPO", "FF", config["ENV_NAME"]],
        config=config,
        mode=config["WANDB_MODE"],
    )

    rng = jax.random.PRNGKey(50)
    train_jit = jax.jit(make_train(config), device=jax.devices()[0])
    out = train_jit(rng)


if __name__ == "__main__":
    main()

'''results = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)
jnp.save('hanabi_results', results)
plt.plot(results)
plt.xlabel("Update Step")
plt.ylabel("Return")
plt.savefig(f'IPPO_{config["ENV_NAME"]}.png')'''