Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into hanabi_obl_aligned
Browse files Browse the repository at this point in the history
  • Loading branch information
mttga committed Mar 22, 2024
2 parents 1fb473a + 60b9fd2 commit fcbae7a
Show file tree
Hide file tree
Showing 21 changed files with 796 additions and 1,291 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
| IQL | [Paper](https://arxiv.org/abs/1312.5602v1) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| VDN | [Paper](https://arxiv.org/abs/1706.05296) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| QMIX | [Paper](https://arxiv.org/abs/1803.11485) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| TransfQMIX | [Peper](https://www.southampton.ac.uk/~eg/AAMAS2023/pdfs/p1679.pdf) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| SHAQ | [Paper](https://arxiv.org/abs/2105.15013) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |

<h2 name="install" id="install">Installation 🧗 </h2>
Expand Down
55 changes: 10 additions & 45 deletions baselines/QLearning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Pure JAX implementations of:
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
* TransfQMix (Transformers for Leveraging the Graph Structure of MARL Problems)
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)

The first three are follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning)
Expand All @@ -26,12 +27,12 @@ pip install -r requirements/requirements-qlearning.txt
- Hanabi
```

## 🔎 Implementation Details
## ⚙️ Implementation Details

General features:

- Agents are controlled by a single RNN architecture.
- You can choose whether to share parameters between agents or not.
- You can choose whether to share parameters between agents or not (not available on TransfQMix).
- Works also with non-homogeneous agents (different observation/action spaces).
- Experience replay is a simple buffer with uniform sampling.
- Uses Double Q-Learning with a target agent network (hard-updated).
Expand Down Expand Up @@ -60,8 +61,8 @@ python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
# QMix with SMAX
python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
# QMix with hanabi
python baselines/QLearning/qmix.py +alg=qmix_hanabi +env=hanabi
# VDN with hanabi
python baselines/QLearning/vdn.py +alg=qlearn_hanabi +env=hanabi
# QMix against pretrained agents
python baselines/QLearning/qmix_pretrained.py +alg=qmix_mpe +env=mpe_tag_pretrained
# TransfQMix
Expand All @@ -75,44 +76,8 @@ Notice that with Hydra, you can modify parameters on the go in this way:
python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_spread alg.PARAMETERS_SHARING=False
```

It is often useful to run these scripts manually in a notebook or in another script.

```python
from jaxmarl import make
from baselines.QLearning.qmix import make_train

env = make("MPE_simple_spread_v3")

config = {
"NUM_ENVS": 8,
"BUFFER_SIZE": 5000,
"BUFFER_BATCH_SIZE": 32,
"TOTAL_TIMESTEPS": 2050000,
"AGENT_HIDDEN_DIM": 64,
"AGENT_INIT_SCALE": 2.0,
"PARAMETERS_SHARING": True,
"EPSILON_START": 1.0,
"EPSILON_FINISH": 0.05,
"EPSILON_ANNEAL_TIME": 100000,
"MIXER_EMBEDDING_DIM": 32,
"MIXER_HYPERNET_HIDDEN_DIM": 64,
"MIXER_INIT_SCALE": 0.00001,
"MAX_GRAD_NORM": 25,
"TARGET_UPDATE_INTERVAL": 200,
"LR": 0.005,
"LR_LINEAR_DECAY": True,
"EPS_ADAM": 0.001,
"WEIGHT_DECAY_ADAM": 0.00001,
"TD_LAMBDA_LOSS": True,
"TD_LAMBDA": 0.6,
"GAMMA": 0.9,
"VERBOSE": False,
"WANDB_ONLINE_REPORT": False,
"NUM_TEST_EPISODES": 32,
"TEST_INTERVAL": 50000,
}

rng = jax.random.PRNGKey(42)
train_vjit = jax.jit(make_train(config, env))
outs = train_vjit(rng)
```
**❗Note on Transformers**: TransfQMix currently supports only MPE_Spread and SMAX. You will need to wrap the observation vectors into matrices to use transformers in other environments. See: ```jaxmarl.wrappers.transformers```

## 🎯 Hyperparameter tuning

Please refer to the ```tune``` function in the [transf_qmix.py](transf_qmix.py) script for an example of hyperparameter tuning using WANDB.
Binary file not shown.
8 changes: 4 additions & 4 deletions baselines/QLearning/config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# experiment params
"NUM_SEEDS": 1
"NUM_SEEDS": 2
"SEED": 0

# wandb params
"ENTITY": "mttga"
"PROJECT": "transf_qmix_tuning"
"WANDB_MODE": "online"
"ENTITY": ""
"PROJECT": ""
"WANDB_MODE": "disabled"

# where to save the params (if None, will not save)
"SAVE_PATH": "baselines/QLearning/checkpoints"
4 changes: 2 additions & 2 deletions baselines/QLearning/config/env/mpe_spread.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"ENV_NAME": "MPE_simple_spread_v3"
"ENV_KWARGS":
"num_agents": 5
"num_landmarks": 5
"num_agents": 3
"num_landmarks": 3
104 changes: 71 additions & 33 deletions jaxmarl/environments/smax/heuristic_enemy_smax_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
from jaxmarl.environments.smax.smax_env import SMAX
from jaxmarl.environments.smax.smax_env import State as SMAXState
from jaxmarl.environments.smax.heuristic_enemy import (
create_heuristic_policy,
get_heuristic_policy_initial_state,
Expand Down Expand Up @@ -59,30 +61,74 @@ def get_enemy_actions(self, key, enemy_policy_state, enemy_obs):
def get_enemy_policy_initial_state(self, key):
raise NotImplementedError

@partial(jax.jit, static_argnums=(0,))
def step_env(self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]):
@partial(jax.jit, static_argnums=(0, 4))
def step_env(
self,
key: chex.PRNGKey,
state: State,
actions: Dict[str, chex.Array],
get_state_sequence=False,
):
jaxmarl_state = state.state
obs = self._env.get_obs(jaxmarl_state)
enemy_obs = jnp.array([obs[agent] for agent in self.enemy_agents])
enemy_obs = self._env.get_obs_unit_list(jaxmarl_state)
enemy_obs = jnp.array([enemy_obs[agent] for agent in self.enemy_agents])
key, action_key = jax.random.split(key)
enemy_actions, enemy_policy_state = self.get_enemy_actions(
action_key, state.enemy_policy_state, enemy_obs
)
enemy_actions = jnp.array([enemy_actions[i] for i in self.enemy_agents])
actions = jnp.array([actions[i] for i in self.agents])
enemy_movement_actions, enemy_attack_actions = (
self._env._decode_discrete_actions(enemy_actions)
)
if self._env.action_type == "continuous":
cont_actions = jnp.zeros((len(self.all_agents), 4))
cont_actions = cont_actions.at[: self.num_allies].set(actions)
key, action_key = jax.random.split(key)
ally_movement_actions, ally_attack_actions = (
self._env._decode_continuous_actions(
action_key, jaxmarl_state, cont_actions
)
)
ally_movement_actions = ally_movement_actions[: self.num_allies]
ally_attack_actions = ally_attack_actions[: self.num_allies]
else:
ally_movement_actions, ally_attack_actions = (
self._env._decode_discrete_actions(actions)
)

actions = {k: v.squeeze() for k, v in actions.items()}
actions = {**enemy_actions, **actions}
obs, jaxmarl_state, rewards, dones, infos = self._env.step_env(
key, jaxmarl_state, actions
movement_actions = jnp.concat(
[ally_movement_actions, enemy_movement_actions], axis=0
)
new_obs = {agent: obs[agent] for agent in self.agents}
new_obs["world_state"] = obs["world_state"]
rewards = {agent: rewards[agent] for agent in self.agents}
all_done = dones["__all__"]
dones = {agent: dones[agent] for agent in self.agents}
dones["__all__"] = all_done
attack_actions = jnp.concat([ally_attack_actions, enemy_attack_actions], axis=0)

if not get_state_sequence:
obs, jaxmarl_state, rewards, dones, infos = self._env.step_env_no_decode(
key,
jaxmarl_state,
(movement_actions, attack_actions),
get_state_sequence=get_state_sequence,
)
new_obs = {agent: obs[agent] for agent in self.agents}
new_obs["world_state"] = obs["world_state"]
rewards = {agent: rewards[agent] for agent in self.agents}
all_done = dones["__all__"]
dones = {agent: dones[agent] for agent in self.agents}
dones["__all__"] = all_done

state = state.replace(enemy_policy_state=enemy_policy_state, state=jaxmarl_state)
return new_obs, state, rewards, dones, infos
state = state.replace(
enemy_policy_state=enemy_policy_state, state=jaxmarl_state
)
return new_obs, state, rewards, dones, infos
else:
states = self._env.step_env_no_decode(
key,
jaxmarl_state,
(movement_actions, attack_actions),
get_state_sequence=get_state_sequence,
)
return states

@partial(jax.jit, static_argnums=(0,))
def get_avail_actions(self, state: State):
Expand Down Expand Up @@ -110,28 +156,20 @@ def expand_state_seq(self, state_seq):
# it's not exposed to the user so we can't ask them to store it. Not a problem
# for now but will have to get creative in the future potentially.
for key, state, actions in state_seq:
agents = self.all_agents
# There is a split in the step function of MultiAgentEnv
# We call split here so that the action key is the same.
key, _ = jax.random.split(key)
key, key_action = jax.random.split(key)
obs = self.get_all_unit_obs(state)
obs = jnp.array([obs[agent] for agent in self.enemy_agents])
enemy_actions, _ = self.get_enemy_actions(
key_action, state.enemy_policy_state, obs
states = self.step_env(key, state, actions, get_state_sequence=True)
states = list(map(SMAXState, *dataclasses.astuple(states)))
viz_actions = {
agent: states[-1].prev_attack_actions[i]
for i, agent in enumerate(self.all_agents)
}

expanded_state_seq.append((key, state.state, viz_actions))
expanded_state_seq.extend(
zip([key] * len(states), states, [viz_actions] * len(states))
)
actions = {k: v.squeeze() for k, v in actions.items()}
actions = {**enemy_actions, **actions}
for _ in range(self.world_steps_per_env_step):
expanded_state_seq.append((key, state.state, actions))
world_actions = jnp.array([actions[i] for i in agents])
key, step_key = jax.random.split(key)
_state = state.state
_state = self._env._world_step(step_key, _state, world_actions)
_state = self._env._kill_agents_touching_walls(_state)
_state = self._env._update_dead_agents(_state)
_state = self._env._push_units_away(_state)
state = state.replace(state=_state)
state = state.replace(
state=state.state.replace(terminal=self.is_terminal(state))
)
Expand Down
Loading

0 comments on commit fcbae7a

Please sign in to comment.