
Commit

update readmes
mttga committed Feb 12, 2024
1 parent c8f3dbe commit 5b94ef7
Showing 3 changed files with 14 additions and 53 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -65,6 +65,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
| IQL | [Paper](https://arxiv.org/abs/1312.5602v1) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| VDN | [Paper](https://arxiv.org/abs/1706.05296) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| QMIX | [Paper](https://arxiv.org/abs/1803.11485) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
+| TransfQMIX | [Paper](https://www.southampton.ac.uk/~eg/AAMAS2023/pdfs/p1679.pdf) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| SHAQ | [Paper](https://arxiv.org/abs/2105.15013) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |

<h2 name="install" id="install">Installation 🧗 </h2>
53 changes: 8 additions & 45 deletions baselines/QLearning/README.md
@@ -5,6 +5,7 @@ Pure JAX implementations of:
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
+* TransfQMix (Transformers for Leveraging the Graph Structure of MARL Problems)
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)

The first three follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning).
@@ -26,12 +27,12 @@ pip install -r requirements/requirements-qlearning.txt
- Hanabi
```

-## 🔎 Implementation Details
+## ⚙️ Implementation Details

General features:

- Agents are controlled by a single RNN architecture.
-- You can choose whether to share parameters between agents or not.
+- You can choose whether to share parameters between agents or not (not available on TransfQMix).
- Also works with non-homogeneous agents (different observation/action spaces).
- Experience replay is a simple buffer with uniform sampling.
- Uses Double Q-Learning with a target agent network (hard-updated); a minimal sketch follows this list.
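A minimal sketch of the Double Q-Learning target and the hard target-network update described above. Function names, shapes, and the discount default are illustrative, not JaxMARL's actual internals:

```python
import jax
import jax.numpy as jnp

def double_q_targets(online_q_next, target_q_next, rewards, dones, gamma=0.99):
    # Double Q-Learning: the online network picks the greedy next action...
    next_actions = jnp.argmax(online_q_next, axis=-1)          # (batch,)
    # ...and the target network evaluates that action.
    next_q = jnp.take_along_axis(
        target_q_next, next_actions[..., None], axis=-1
    ).squeeze(-1)                                              # (batch,)
    return rewards + gamma * (1.0 - dones) * next_q

def hard_update(target_params, online_params):
    # Hard update: overwrite the target parameters with the online ones
    # every TARGET_UPDATE_INTERVAL steps (no soft/Polyak mixing).
    return jax.tree_util.tree_map(lambda _, o: o, target_params, online_params)
```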
@@ -60,8 +61,8 @@ python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
# QMix with SMAX
python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
-# QMix with hanabi
-python baselines/QLearning/qmix.py +alg=qmix_hanabi +env=hanabi
+# VDN with hanabi
+python baselines/QLearning/vdn.py +alg=qlearn_hanabi +env=hanabi
# QMix against pretrained agents
python baselines/QLearning/qmix_pretrained.py +alg=qmix_mpe +env=mpe_tag_pretrained
# TransfQMix
@@ -75,44 +76,6 @@ Notice that with Hydra, you can modify parameters on the go in this way:
python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_spread alg.PARAMETERS_SHARING=False
```

-It is often useful to run these scripts manually in a notebook or in another script.
-
-```python
-from jaxmarl import make
-from baselines.QLearning.qmix import make_train
-
-env = make("MPE_simple_spread_v3")
-
-config = {
-    "NUM_ENVS": 8,
-    "BUFFER_SIZE": 5000,
-    "BUFFER_BATCH_SIZE": 32,
-    "TOTAL_TIMESTEPS": 2050000,
-    "AGENT_HIDDEN_DIM": 64,
-    "AGENT_INIT_SCALE": 2.0,
-    "PARAMETERS_SHARING": True,
-    "EPSILON_START": 1.0,
-    "EPSILON_FINISH": 0.05,
-    "EPSILON_ANNEAL_TIME": 100000,
-    "MIXER_EMBEDDING_DIM": 32,
-    "MIXER_HYPERNET_HIDDEN_DIM": 64,
-    "MIXER_INIT_SCALE": 0.00001,
-    "MAX_GRAD_NORM": 25,
-    "TARGET_UPDATE_INTERVAL": 200,
-    "LR": 0.005,
-    "LR_LINEAR_DECAY": True,
-    "EPS_ADAM": 0.001,
-    "WEIGHT_DECAY_ADAM": 0.00001,
-    "TD_LAMBDA_LOSS": True,
-    "TD_LAMBDA": 0.6,
-    "GAMMA": 0.9,
-    "VERBOSE": False,
-    "WANDB_ONLINE_REPORT": False,
-    "NUM_TEST_EPISODES": 32,
-    "TEST_INTERVAL": 50000,
-}
-
-rng = jax.random.PRNGKey(42)
-train_vjit = jax.jit(make_train(config, env))
-outs = train_vjit(rng)
-```
+## 🎯 Hyperparameter tuning
+
+Please refer to the `tune` function in the [transf_qmix.py](transf_qmix.py) script for an example of hyperparameter tuning using WANDB.
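A rough sketch of how such a WANDB sweep is typically wired up. The sweep dictionary mirrors the one visible in the transf_qmix.py diff below; the project name and training entry point are placeholders, not JaxMARL's actual ones:

```python
import wandb

sweep_config = {
    "method": "bayes",
    "metric": {"name": "test_returns", "goal": "maximize"},
    "parameters": {"LR": {"values": [0.005, 0.001, 0.0005]}},
}

def run_one():
    # Each agent call starts a run whose hyperparameters are injected by wandb.
    run = wandb.init()
    lr = wandb.config.LR
    # ... build the algorithm config with this LR and call the training function ...
    run.finish()

sweep_id = wandb.sweep(sweep_config, project="jaxmarl-qlearning")  # placeholder project
wandb.agent(sweep_id, function=run_one, count=20)
```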
13 changes: 5 additions & 8 deletions baselines/QLearning/transf_qmix.py
@@ -6,7 +6,7 @@
- Adds the possibility to perform $n$ training updates of the network at each update step.
Currently supports only MPE_spread and SMAX. Remember that to use the transformers in your environment you need
-to reshape the observations and states into matrices. See: jaxmarl.wrappers.transformers
+to reshape the observations and states to matrices. See: jaxmarl.wrappers.transformers
"""

import os
@@ -1026,21 +1026,18 @@ def wrapped_make_train():
train_vjit = jax.jit(jax.vmap(make_train(config["alg"], env)))
outs = jax.block_until_ready(train_vjit(rngs))
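# Assumed context (not shown in this hunk): `rngs` is a batch of PRNG keys,
# one per seed, so jax.vmap trains NUM_SEEDS independent runs in a single
# jit-compiled call, e.g.:
#     rngs = jax.random.split(jax.random.PRNGKey(seed), num_seeds)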

-n_updates = (
-    default_config["alg"]["TOTAL_TIMESTEPS"] // default_config["alg"]["NUM_STEPS"] // default_config["alg"]["NUM_ENVS"]
-)*default_config["NUM_SEEDS"] # 2 seeds will log double, 3 seeds will log triple and so on
sweep_config = {
'method': 'bayes',
'metric': {
'name': 'test_returns',
'goal': 'maximize',
},
'parameters':{
-        'LR':{'values':[0.01, 0.005, 0.001, 0.0005]},
-        'LR_EXP_DECAY_RATE':{'values':[0.02, 0.002, 0.0002, 0.00002]},
+        'LR':{'values':[0.005, 0.001, 0.0005]},
        'EPS_ADAM':{'values':[0.0001, 0.0000001, 0.0000000001]},
-        'NUM_ENVS':{'values':[16, 32]},
-        'N_MINI_UPDATES':{'values':[2, 4, 8]},
+        'SCALE_INPUTS':{'values':[True, False]},
+        'NUM_ENVS':{'values':[8, 16]},
+        'N_MINI_UPDATES':{'values':[1, 2, 4]},
},
}

