shaq to one file, readme updated, wandb disabled by default

Alex Rutherford committed Dec 3, 2023
1 parent 71a4ea7 commit 8ccde30
Showing 3 changed files with 99 additions and 10 deletions.
13 changes: 10 additions & 3 deletions baselines/QLearning/README.md
@@ -1,6 +1,13 @@
# QLearning Baselines

*Pure Jax implementation of **IQL** (Independent Q-Learners), **VDN** (Value Decomposition Network), and **QMix**. These implementations follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase.*

Pure JAX implementations of:
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)

The first three follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning).

```
⚠️ The implementations were tested with Python 3.9 and Jax 0.4.11.
```

@@ -50,9 +57,9 @@ If you have cloned JaxMARL and you are in the repository root, you can run the a

```
python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
# VDN with MPE spread
python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
# QMix with SMAX
# QMIX with SMAX
python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
# QMix against pretrained agents
# QMIX against pretrained agents
python baselines/QLearning/qmix_pretrained.py +alg=qmix_mpe +env=mpe_tag_pretrained
```

2 changes: 1 addition & 1 deletion baselines/QLearning/config/config.yaml
@@ -5,7 +5,7 @@
# wandb params
"ENTITY": ""
"PROJECT": "jaxMARL"
"WANDB_MODE": "online"
"WANDB_MODE": "disabled"

# where to save the params (if None, will not save)
"SAVE_PATH": "baselines/QLearning/checkpoints"
94 changes: 88 additions & 6 deletions baselines/QLearning/shaq.py
@@ -23,8 +23,9 @@
import jax
import jax.numpy as jnp
from functools import partial
from typing import NamedTuple, Dict, Union
import numpy as np

import chex
import optax
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
@@ -33,11 +34,44 @@
import wandb
import hydra
from omegaconf import OmegaConf
import flashbax as fbx
from safetensors.flax import save_file
from flax.traverse_util import flatten_dict


from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from baselines.QLearning.utils import CTRolloutManager, EpsilonGreedy, Transition, UniformBuffer, ScannedRNN, save_params
from jaxmarl.wrappers.baselines import CTRolloutManager

class ScannedRNN(nn.Module):

    @partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module."""
        rnn_state = carry
        ins, resets = x
        hidden_size = ins.shape[-1]
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(hidden_size, *ins.shape[:-1]),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(hidden_size)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(hidden_size, *batch_size):
        # Use a dummy key since the default state init fn is just zeros.
        return nn.GRUCell(hidden_size, parent=None).initialize_carry(
            jax.random.PRNGKey(0), (*batch_size, hidden_size)
        )
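
For reference, a minimal usage sketch of the ScannedRNN module above; the sequence length, batch size, and hidden size below are illustrative assumptions, not values taken from the training loop.

```
import jax
import jax.numpy as jnp

# Assumed toy shapes: 5 timesteps, batch of 4, hidden size 16.
time_steps, batch, hidden = 5, 4, 16
xs = jnp.ones((time_steps, batch, hidden))           # per-step inputs
resets = jnp.zeros((time_steps, batch), dtype=bool)  # episode-boundary flags

rnn = ScannedRNN()
init_carry = ScannedRNN.initialize_carry(hidden, batch)
params = rnn.init(jax.random.PRNGKey(0), init_carry, (xs, resets))
final_carry, ys = rnn.apply(params, init_carry, (xs, resets))  # ys: (5, 4, 16)
```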

class AgentRNN(nn.Module):
    # homogeneous agent for parameter sharing, assumes all agents have the same obs and action dim
@@ -230,6 +264,41 @@ def __call__(self, q_vals, states, max_filter, target, manual_alpha_estimates=No
        # if the agent took the max action then alpha = 1; otherwise, the agent uses the learned alpha
        return jnp.sum((alpha_estimates * non_max_filter + max_filter) * q_vals, axis=0)
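
To make that weighting concrete, here is a tiny numeric illustration of the sum; all values are made up for the example.

```
import jax.numpy as jnp

q_vals = jnp.array([1.0, 2.0, 3.0])           # per-agent Q-values of the chosen actions
max_filter = jnp.array([0.0, 0.0, 1.0])       # agent 2 took the greedy action, so its weight is fixed to 1
non_max_filter = 1.0 - max_filter
alpha_estimates = jnp.array([0.5, 1.5, 9.9])  # agent 2's alpha is masked out by the filters
q_tot = jnp.sum((alpha_estimates * non_max_filter + max_filter) * q_vals, axis=0)
# q_tot = 0.5*1.0 + 1.5*2.0 + 1.0*3.0 = 6.5
```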

class EpsilonGreedy:
    """Epsilon Greedy action selection"""

    def __init__(self, start_e: float, end_e: float, duration: int):
        self.start_e = start_e
        self.end_e = end_e
        self.duration = duration
        self.slope = (end_e - start_e) / duration

    @partial(jax.jit, static_argnums=0)
    def get_epsilon(self, t: int):
        e = self.slope * t + self.start_e
        return jnp.clip(e, self.end_e)

    @partial(jax.jit, static_argnums=0)
    def choose_actions(self, q_vals: dict, t: int, rng: chex.PRNGKey):

        def explore(q, eps, key):
            key_a, key_e = jax.random.split(key, 2)  # a key for sampling random actions and one for picking
            greedy_actions = jnp.argmax(q, axis=-1)  # get the greedy actions
            random_actions = jax.random.randint(key_a, shape=greedy_actions.shape, minval=0, maxval=q.shape[-1])  # sample random actions
            pick_random = jax.random.uniform(key_e, greedy_actions.shape) < eps  # pick which actions should be random
            chosen_actions = jnp.where(pick_random, random_actions, greedy_actions)
            return chosen_actions

        eps = self.get_epsilon(t)
        keys = dict(zip(q_vals.keys(), jax.random.split(rng, len(q_vals))))  # get a key for each agent
        chosen_actions = jax.tree_map(lambda q, k: explore(q, eps, k), q_vals, keys)
        return chosen_actions

class Transition(NamedTuple):
    obs: dict
    actions: dict
    rewards: dict
    dones: dict
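
A small usage sketch for the EpsilonGreedy selector defined above; the agent names, Q-value shapes, schedule parameters, and timestep are assumptions for illustration.

```
import jax
import jax.numpy as jnp

explorer = EpsilonGreedy(start_e=1.0, end_e=0.05, duration=10_000)
q_vals = {
    "agent_0": jnp.zeros((8, 5)),  # (num_envs, num_actions)
    "agent_1": jnp.zeros((8, 5)),
}
actions = explorer.choose_actions(q_vals, t=100, rng=jax.random.PRNGKey(0))
# each entry has shape (num_envs,), e.g. actions["agent_0"].shape == (8,)
```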

def make_train(config, env):

@@ -260,8 +329,15 @@ def _env_sample_step(env_state, unused):
            _env_sample_step, env_state, None, config["NUM_STEPS"]
        )
        sample_traj_unbatched = jax.tree_map(lambda x: x[:, 0], sample_traj)  # remove the NUM_ENV dim
        buffer = UniformBuffer(parallel_envs=config["NUM_ENVS"], batch_size=config["BUFFER_BATCH_SIZE"], max_size=config["BUFFER_SIZE"])
        buffer_state = buffer.reset(sample_traj_unbatched)
        buffer = fbx.make_flat_buffer(
            max_length=config['BUFFER_SIZE'],
            min_length=config['BUFFER_BATCH_SIZE'],
            sample_batch_size=config['BUFFER_BATCH_SIZE'],
            add_sequences=True,
            add_batch_size=None,
        )
        buffer_state = buffer.init(sample_traj_unbatched)


        # INIT NETWORK
        # init agent
@@ -410,7 +486,8 @@ def _env_step(step_state, unused):
            )

            # BUFFER UPDATE: save the collected trajectory in the buffer
            buffer_state = buffer.add(buffer_state, traj_batch)
            buffer_traj_batch = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch)  # put the batch size (num envs) in the first axis
            buffer_state = buffer.add(buffer_state, buffer_traj_batch)

            # LEARN PHASE
            def q_of_action(q, u):
@@ -505,7 +582,7 @@ def _td_lambda_target(ret, values):

            # sample a batched trajectory from the buffer and set the time step dim in first axis
            rng, _rng = jax.random.split(rng)
            _, learn_traj = buffer.sample(buffer_state, _rng)  # (batch_size, max_time_steps, ...)
            learn_traj = buffer.sample(buffer_state, _rng).experience.first  # (batch_size, max_time_steps, ...)
            learn_traj = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), learn_traj)  # (max_time_steps, batch_size, ...)
            if config["PARAMETERS_SHARING"]:
                init_hs = ScannedRNN.initialize_carry(config['AGENT_HIDDEN_DIM'], len(env.agents)*config["BUFFER_BATCH_SIZE"])  # (n_agents*batch_size, hs_size)
@@ -706,6 +783,11 @@ def main(config):

    # save params
    if config['SAVE_PATH'] is not None:

        def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:
            flattened_dict = flatten_dict(params, sep=',')
            save_file(flattened_dict, filename)

        model_state = outs['runner_state'][0]
        params = jax.tree_map(lambda x: x[0], model_state.params)  # save only the params of the first run
        save_dir = os.path.join(config['SAVE_PATH'], env_name)
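
For completeness, a hedged sketch of how parameters saved this way could be read back; the file path is illustrative, not one produced by this script.

```
from safetensors.flax import load_file
from flax.traverse_util import unflatten_dict

# Hypothetical checkpoint path; save_params above flattens with sep=',',
# so the loaded dict is unflattened with the same separator.
flat_params = load_file("baselines/QLearning/checkpoints/some_env/shaq.safetensors")
params = unflatten_dict(flat_params, sep=',')
```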
