vanilla ippo transformer
mttga committed Mar 10, 2024
1 parent aae9f99 commit 3b5b12b
Showing 2 changed files with 434 additions and 8 deletions.
391 changes: 391 additions & 0 deletions baselines/IPPO/ippo_transf_hanabi.py
@@ -0,0 +1,391 @@
"""
Based on PureJaxRL Implementation of PPO
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Dict
from flax.training.train_state import TrainState
import distrax
from jaxmarl.wrappers.baselines import LogWrapper
import jaxmarl
import wandb
import functools
import matplotlib.pyplot as plt
import hydra
from omegaconf import OmegaConf


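# Post-norm transformer encoder block: multi-head self-attention followed by a two-layer
# MLP, each combined with a residual connection, LayerNorm and dropout.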
class EncoderBlock(nn.Module):
    hidden_dim : int = 32  # the input dimension must equal this output dimension because of the residual connections
num_heads : int = 4
dim_feedforward : int = 256
dropout_prob : float = 0.

def setup(self):
# Attention layer
self.self_attn = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
dropout_rate=self.dropout_prob,
kernel_init=nn.initializers.xavier_uniform(),
use_bias=False,
)
# Two-layer MLP
self.linear = [
nn.Dense(self.dim_feedforward, kernel_init=nn.initializers.xavier_uniform(), bias_init=constant(0.0)),
nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform(), bias_init=constant(0.0))
]
# Layers to apply in between the main layers
self.norm1 = nn.LayerNorm()
self.norm2 = nn.LayerNorm()
self.dropout = nn.Dropout(self.dropout_prob)

def __call__(self, x, mask=None, deterministic=True):

# Attention part
        if mask is not None:
            # expand the (batch, tokens) padding mask into one attention mask per head
            mask = jnp.repeat(nn.make_attention_mask(mask, mask), self.num_heads, axis=-3)
attended = self.self_attn(inputs_q=x, inputs_kv=x, mask=mask, deterministic=deterministic)

x = self.norm1(attended + x)
x = x + self.dropout(x, deterministic=deterministic)

# MLP part
feedforward = self.linear[0](x)
feedforward = nn.relu(feedforward)
feedforward = self.linear[1](feedforward)

x = self.norm2(feedforward+x)
x = x + self.dropout(x, deterministic=deterministic)

return x


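# Actor-critic over a short token sequence: the encoded observation is one token and each
# discrete action contributes one learned embedding token. The sequence is processed jointly
# by the encoder blocks; the actor head scores every action token with a shared Dense(1),
# while the critic reads the state value from the observation token.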
class ActorCritic(nn.Module):
    action_dim: int
config: Dict
transf_layers: int = 2

@nn.compact
def __call__(self, x):
obs, dones, avail_actions = x
obs_embedding = nn.Dense(
32, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(obs)
obs_embedding = nn.relu(obs_embedding)

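        # one learned 32-dim embedding per action id, tiled over the leading (time, batch)
        # dimensions and concatenated after the observation token -> (..., 1 + action_dim, 32)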
actions_embeddings = nn.Embed(num_embeddings=self.action_dim, features=32)(jnp.arange(self.action_dim))
actions_embeddings = jnp.tile(actions_embeddings, (*obs_embedding.shape[:-1],1, 1))
embeddings = jnp.concatenate((obs_embedding[..., np.newaxis, :], actions_embeddings), axis=-2)

for _ in range(self.transf_layers):
embeddings = EncoderBlock()(embeddings)

obs_embedding_post = embeddings[..., 0, :]
actions_embeddings_post = embeddings[..., 1:, :]

#actor_mean = nn.Dense(512, kernel_init=orthogonal(2), bias_init=constant(0.0))(actions_embeddings_post)
#actor_mean = nn.relu(actor_mean)
actor_mean = nn.Dense(1, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actions_embeddings_post)
actor_mean = jnp.squeeze(actor_mean, axis=-1)

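        # push the logits of unavailable actions towards -inf so they are never sampled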
unavail_actions = 1 - avail_actions
action_logits = actor_mean - (unavail_actions * 1e10)
pi = distrax.Categorical(logits=action_logits)

critic = nn.Dense(512, kernel_init=orthogonal(2), bias_init=constant(0.0))(obs_embedding_post)
critic = nn.relu(critic)
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
critic
)

return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
reward: jnp.ndarray
log_prob: jnp.ndarray
obs: jnp.ndarray
info: jnp.ndarray
avail_actions: jnp.ndarray

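# Helpers to flatten the per-agent dicts returned by the environment into a single
# (num_agents * num_envs, ...) array for the network, and to rebuild the dict for env.step.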
def batchify(x: dict, agent_list, num_actors):
x = jnp.stack([x[a] for a in agent_list])
return x.reshape((num_actors, -1))

def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
x = x.reshape((num_actors, num_envs, -1))
return {a: x[i] for i, a in enumerate(agent_list)}

def make_train(config):
env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
config["NUM_UPDATES"] = (
config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
)
config["MINIBATCH_SIZE"] = (
config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
)

env = LogWrapper(env)

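    # Linear learning-rate decay: `count` is the number of minibatch gradient steps taken,
    # so the rate anneals from LR towards 0 over the whole run.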
def linear_schedule(count):
frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
return config["LR"] * frac

def train(rng):

# INIT NETWORK
network = ActorCritic(env.action_space(env.agents[0]).n, config=config)
rng, _rng = jax.random.split(rng)
init_x = (
jnp.zeros(
(1, config["NUM_ENVS"], env.observation_space(env.agents[0]).n)
),
jnp.zeros((1, config["NUM_ENVS"])),
jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n))
)
network_params = network.init(_rng, init_x)
if config["ANNEAL_LR"]:
tx = optax.chain(
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)

# INIT ENV
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state = jax.vmap(env.reset, in_axes=(0))(reset_rng)

# TRAIN LOOP
def _update_step(update_runner_state, unused):
# COLLECT TRAJECTORIES
runner_state, update_steps = update_runner_state
def _env_step(runner_state, unused):
train_state, env_state, last_obs, last_done, rng = runner_state

# SELECT ACTION
rng, _rng = jax.random.split(rng)
avail_actions = jax.vmap(env.get_legal_moves)(env_state.env_state)
avail_actions = jax.lax.stop_gradient(
batchify(avail_actions, env.agents, config["NUM_ACTORS"])
)
obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
ac_in = (obs_batch[np.newaxis, :], last_done[np.newaxis, :], avail_actions[np.newaxis, :])
pi, value = network.apply(train_state.params, ac_in)
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)
env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)

# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))(
rng_step, env_state, env_act
)
info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
done_batch,
action.squeeze(),
value.squeeze(),
batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
log_prob.squeeze(),
obs_batch,
info,
avail_actions
)
runner_state = (train_state, env_state, obsv, done_batch, rng)
return runner_state, transition

runner_state, traj_batch = jax.lax.scan(
_env_step, runner_state, None, config["NUM_STEPS"]
)

# CALCULATE ADVANTAGE
train_state, env_state, last_obs, last_done, rng = runner_state
last_obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
avail_actions = jnp.ones(
(config["NUM_ACTORS"], env.action_space(env.agents[0]).n)
)
ac_in = (last_obs_batch[np.newaxis, :], last_done[np.newaxis, :], avail_actions)
_, last_val = network.apply(train_state.params, ac_in)
last_val = last_val.squeeze()

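            # Generalised Advantage Estimation via a reverse scan over the trajectory:
            #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}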
def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition):
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
delta = reward + config["GAMMA"] * next_value * (1 - done) - value
gae = (
delta
+ config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
)
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)

# UPDATE NETWORK
def _update_epoch(update_state, unused):
def _update_minbatch(train_state, batch_info):
traj_batch, advantages, targets = batch_info

def _loss_fn(params, traj_batch, gae, targets):
# RERUN NETWORK
pi, value = network.apply(params,
(traj_batch.obs, traj_batch.done, traj_batch.avail_actions))
log_prob = pi.log_prob(traj_batch.action)

# CALCULATE VALUE LOSS
value_pred_clipped = traj_batch.value + (
value - traj_batch.value
).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
value_losses = jnp.square(value - targets)
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = (
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
)

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
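                        # normalise advantages within the minibatch before the clipped surrogate loss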
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
jnp.clip(
ratio,
1.0 - config["CLIP_EPS"],
1.0 + config["CLIP_EPS"],
)
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()

total_loss = (
loss_actor
+ config["VF_COEF"] * value_loss
- config["ENT_COEF"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
train_state.params, traj_batch, advantages, targets
)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss

train_state, traj_batch, advantages, targets, rng = update_state
rng, _rng = jax.random.split(rng)

batch = (traj_batch, advantages.squeeze(), targets.squeeze())
permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

shuffled_batch = jax.tree_util.tree_map(
lambda x: jnp.take(x, permutation, axis=1), batch
)

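                # split the shuffled actor axis into NUM_MINIBATCHES chunks:
                # (time, actors, ...) -> (n_minibatches, time, actors // n_minibatches, ...)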
minibatches = jax.tree_util.tree_map(
lambda x: jnp.swapaxes(
jnp.reshape(
x,
[x.shape[0], config["NUM_MINIBATCHES"], -1]
+ list(x.shape[2:]),
),
1,
0,
),
shuffled_batch,
)

train_state, total_loss = jax.lax.scan(
_update_minbatch, train_state, minibatches
)
update_state = (train_state, traj_batch, advantages, targets, rng)
return update_state, total_loss

update_state = (train_state, traj_batch, advantages, targets, rng)
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config["UPDATE_EPOCHS"]
)
train_state = update_state[0]
metric = traj_batch.info
rng = update_state[-1]

def callback(metric):
wandb.log(
{
"returns": metric["returned_episode_returns"][-1, :].mean(),
"env_step": metric["update_steps"]
* config["NUM_ENVS"]
* config["NUM_STEPS"],
}
)
metric["update_steps"] = update_steps
jax.experimental.io_callback(callback, None, metric)
update_steps = update_steps + 1
runner_state = (train_state, env_state, last_obs, last_done, rng)
return (runner_state, update_steps), None

rng, _rng = jax.random.split(rng)
runner_state = (train_state, env_state, obsv, jnp.zeros((config["NUM_ACTORS"]), dtype=bool), _rng)
runner_state, _ = jax.lax.scan(
_update_step, (runner_state, 0), None, config["NUM_UPDATES"]
)
return {"runner_state": runner_state}

return train

@hydra.main(version_base=None, config_path="config", config_name="ippo_ff_hanabi")
def main(config):
config = OmegaConf.to_container(config)

wandb.init(
entity=config["ENTITY"],
project=config["PROJECT"],
tags=["IPPO", "FF", config["ENV_NAME"]],
config=config,
mode=config["WANDB_MODE"],
)

rng = jax.random.PRNGKey(50)
train_jit = jax.jit(make_train(config), device=jax.devices()[0])
out = train_jit(rng)


if __name__ == "__main__":
main()
'''results = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)
jnp.save('hanabi_results', results)
plt.plot(results)
plt.xlabel("Update Step")
plt.ylabel("Return")
plt.savefig(f'IPPO_{config["ENV_NAME"]}.png')'''
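
# Example invocation (a sketch; assumes the Hydra config at baselines/IPPO/config/ippo_ff_hanabi.yaml
# defines the keys read above, e.g. ENV_NAME, NUM_ENVS, NUM_STEPS, TOTAL_TIMESTEPS, LR):
#   python baselines/IPPO/ippo_transf_hanabi.py TOTAL_TIMESTEPS=1e7 WANDB_MODE=disabled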