diff --git a/README.md b/README.md
index be6377e0..11ab3651 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
| IQL | [Paper](https://arxiv.org/abs/1312.5602v1) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| VDN | [Paper](https://arxiv.org/abs/1706.05296) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| QMIX | [Paper](https://arxiv.org/abs/1803.11485) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
+| TransfQMIX | [Paper](https://www.southampton.ac.uk/~eg/AAMAS2023/pdfs/p1679.pdf) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| SHAQ | [Paper](https://arxiv.org/abs/2105.15013) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
Installation 🧗
diff --git a/baselines/QLearning/README.md b/baselines/QLearning/README.md
index c9128b9d..47fb3afa 100644
--- a/baselines/QLearning/README.md
+++ b/baselines/QLearning/README.md
@@ -5,6 +5,7 @@ Pure JAX implementations of:
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
+* TransfQMix (Transformers for Leveraging the Graph Structure of MARL Problems)
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)
The first three follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning).
@@ -26,12 +27,12 @@ pip install -r requirements/requirements-qlearning.txt
- Hanabi
```
-## 🔎 Implementation Details
+## ⚙️ Implementation Details
General features:
- Agents are controlled by a single RNN architecture.
-- You can choose whether to share parameters between agents or not.
+- You can choose whether to share parameters between agents or not (TransfQMix always shares parameters).
- Also works with non-homogeneous agents (different observation/action spaces).
- Experience replay is a simple buffer with uniform sampling.
- Uses Double Q-Learning with a target agent network (hard-updated); see the sketch below.
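+
+A minimal sketch of the double-Q target computation (the helper name and array shapes are illustrative, not the exact code):
+
+```python
+import jax.numpy as jnp
+
+def double_q_targets(q_online_next, q_target_next, rewards, dones, gamma=0.9):
+    # The online network selects the greedy next action...
+    best_actions = jnp.argmax(q_online_next, axis=-1)
+    # ...and the hard-updated target network evaluates it.
+    q_eval = jnp.take_along_axis(q_target_next, best_actions[..., None], axis=-1)
+    # One-step TD target, masked at episode ends.
+    return rewards + gamma * (1.0 - dones) * q_eval.squeeze(-1)
+```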
@@ -60,8 +61,8 @@ python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
# QMix with SMAX
python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
-# QMix with hanabi
-python baselines/QLearning/qmix.py +alg=qmix_hanabi +env=hanabi
+# VDN with Hanabi
+python baselines/QLearning/vdn.py +alg=qlearn_hanabi +env=hanabi
# QMix against pretrained agents
python baselines/QLearning/qmix_pretrained.py +alg=qmix_mpe +env=mpe_tag_pretrained
# TransfQMix
@@ -75,44 +76,6 @@ Notice that with Hydra, you can modify parameters on the go in this way:
python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_spread alg.PARAMETERS_SHARING=False
```
-It is often useful to run these scripts manually in a notebook or in another script.
-
-```python
-from jaxmarl import make
-from baselines.QLearning.qmix import make_train
-
-env = make("MPE_simple_spread_v3")
-
-config = {
- "NUM_ENVS": 8,
- "BUFFER_SIZE": 5000,
- "BUFFER_BATCH_SIZE": 32,
- "TOTAL_TIMESTEPS": 2050000,
- "AGENT_HIDDEN_DIM": 64,
- "AGENT_INIT_SCALE": 2.0,
- "PARAMETERS_SHARING": True,
- "EPSILON_START": 1.0,
- "EPSILON_FINISH": 0.05,
- "EPSILON_ANNEAL_TIME": 100000,
- "MIXER_EMBEDDING_DIM": 32,
- "MIXER_HYPERNET_HIDDEN_DIM": 64,
- "MIXER_INIT_SCALE": 0.00001,
- "MAX_GRAD_NORM": 25,
- "TARGET_UPDATE_INTERVAL": 200,
- "LR": 0.005,
- "LR_LINEAR_DECAY": True,
- "EPS_ADAM": 0.001,
- "WEIGHT_DECAY_ADAM": 0.00001,
- "TD_LAMBDA_LOSS": True,
- "TD_LAMBDA": 0.6,
- "GAMMA": 0.9,
- "VERBOSE": False,
- "WANDB_ONLINE_REPORT": False,
- "NUM_TEST_EPISODES": 32,
- "TEST_INTERVAL": 50000,
-}
-
-rng = jax.random.PRNGKey(42)
-train_vjit = jax.jit(make_train(config, env))
-outs = train_vjit(rng)
-```
\ No newline at end of file
+## 🎯 Hyperparameter tuning
+
+Please refer to the `tune` function in the [transf_qmix.py](transf_qmix.py) script for an example of hyperparameter tuning with wandb.
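+
+A minimal sketch of that pattern, assuming an entry point that reads the sampled hyperparameters from `wandb.config` (the function, metric, and project names here are illustrative):
+
+```python
+import wandb
+
+def sweep_run():
+    # Each agent invocation starts a fresh run; the sampled values live in run.config.
+    run = wandb.init()
+    lr = run.config.LR
+    # ... merge `lr` into the training config and launch training here ...
+    run.finish()
+
+sweep_config = {
+    "method": "bayes",
+    "metric": {"name": "test_returns", "goal": "maximize"},
+    "parameters": {"LR": {"values": [0.005, 0.001, 0.0005]}},
+}
+
+sweep_id = wandb.sweep(sweep_config, project="jaxmarl-qlearning")
+wandb.agent(sweep_id, function=sweep_run, count=20)
+```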
\ No newline at end of file
diff --git a/baselines/QLearning/transf_qmix.py b/baselines/QLearning/transf_qmix.py
index 3d131bb9..5d0e0f3e 100644
--- a/baselines/QLearning/transf_qmix.py
+++ b/baselines/QLearning/transf_qmix.py
@@ -6,7 +6,7 @@
- Adds the possibility of performing $n$ training updates of the network at each update step.
Currently supports only MPE_spread and SMAX. Remember that to use the transformers in your environment you need
-to reshape the observations and states into matrices. See: jaxmarl.wrappers.transformers
+to reshape the observations and states into matrices; see jaxmarl.wrappers.transformers and the illustration below.
"""
import os
@@ -1026,9 +1026,6 @@ def wrapped_make_train():
train_vjit = jax.jit(jax.vmap(make_train(config["alg"], env)))
outs = jax.block_until_ready(train_vjit(rngs))
- n_updates = (
- default_config["alg"]["TOTAL_TIMESTEPS"] // default_config["alg"]["NUM_STEPS"] // default_config["alg"]["NUM_ENVS"]
- )*default_config["NUM_SEEDS"] # 2 seeds will log double, 3 seeds will log triple and so on
sweep_config = {
'method': 'bayes',
'metric': {
@@ -1036,11 +1033,11 @@ def wrapped_make_train():
'goal': 'maximize',
},
'parameters':{
- 'LR':{'values':[0.01, 0.005, 0.001, 0.0005]},
- 'LR_EXP_DECAY_RATE':{'values':[0.02, 0.002, 0.0002, 0.00002]},
+ 'LR':{'values':[0.005, 0.001, 0.0005]},
'EPS_ADAM':{'values':[0.0001, 0.0000001, 0.0000000001]},
- 'NUM_ENVS':{'values':[16, 32]},
- 'N_MINI_UPDATES':{'values':[2, 4, 8]},
+ 'SCALE_INPUTS':{'values':[True, False]},
+ 'NUM_ENVS':{'values':[8, 16]},
+ 'N_MINI_UPDATES':{'values':[1, 2, 4]},
},
}