diff --git a/README.md b/README.md
index be6377e0..11ab3651 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
 | IQL | [Paper](https://arxiv.org/abs/1312.5602v1) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
 | VDN | [Paper](https://arxiv.org/abs/1706.05296) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
 | QMIX | [Paper](https://arxiv.org/abs/1803.11485) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
+| TransfQMix | [Paper](https://www.southampton.ac.uk/~eg/AAMAS2023/pdfs/p1679.pdf) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
 | SHAQ | [Paper](https://arxiv.org/abs/2105.15013) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
 
 Installation 🧗

diff --git a/baselines/QLearning/README.md b/baselines/QLearning/README.md
index c9128b9d..47fb3afa 100644
--- a/baselines/QLearning/README.md
+++ b/baselines/QLearning/README.md
@@ -5,6 +5,7 @@ Pure JAX implementations of:
 * IQL (Independent Q-Learners)
 * VDN (Value Decomposition Network)
 * QMIX
+* TransfQMix (Transformers for Leveraging the Graph Structure of MARL Problems)
 * SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)
 
 The first three follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning)
@@ -26,12 +27,12 @@ pip install -r requirements/requirements-qlearning.txt
 - Hanabi
 ```
 
-## 🔎 Implementation Details
+## ⚙️ Implementation Details
 
 General features:
 
 - Agents are controlled by a single RNN architecture.
-- You can choose whether to share parameters between agents or not.
+- You can choose whether to share parameters between agents or not (not available for TransfQMix).
 - Also works with non-homogeneous agents (different observation/action spaces).
 - Experience replay is a simple buffer with uniform sampling.
 - Uses Double Q-Learning with a target agent network (hard-updated).
@@ -60,8 +61,8 @@ python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
 python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
 # QMix with SMAX
 python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
-# QMix with hanabi
-python baselines/QLearning/qmix.py +alg=qmix_hanabi +env=hanabi
+# VDN with Hanabi
+python baselines/QLearning/vdn.py +alg=qlearn_hanabi +env=hanabi
 # QMix against pretrained agents
 python baselines/QLearning/qmix_pretrained.py +alg=qmix_mpe +env=mpe_tag_pretrained
 # TransfQMix
@@ -75,44 +76,6 @@ Notice that with Hydra, you can modify parameters on the go in this way:
 python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_spread alg.PARAMETERS_SHARING=False
 ```
 
-It is often useful to run these scripts manually in a notebook or in another script.
-
-```python
-from jaxmarl import make
-from baselines.QLearning.qmix import make_train
-
-env = make("MPE_simple_spread_v3")
-
-config = {
-    "NUM_ENVS": 8,
-    "BUFFER_SIZE": 5000,
-    "BUFFER_BATCH_SIZE": 32,
-    "TOTAL_TIMESTEPS": 2050000,
-    "AGENT_HIDDEN_DIM": 64,
-    "AGENT_INIT_SCALE": 2.0,
-    "PARAMETERS_SHARING": True,
-    "EPSILON_START": 1.0,
-    "EPSILON_FINISH": 0.05,
-    "EPSILON_ANNEAL_TIME": 100000,
-    "MIXER_EMBEDDING_DIM": 32,
-    "MIXER_HYPERNET_HIDDEN_DIM": 64,
-    "MIXER_INIT_SCALE": 0.00001,
-    "MAX_GRAD_NORM": 25,
-    "TARGET_UPDATE_INTERVAL": 200,
-    "LR": 0.005,
-    "LR_LINEAR_DECAY": True,
-    "EPS_ADAM": 0.001,
-    "WEIGHT_DECAY_ADAM": 0.00001,
-    "TD_LAMBDA_LOSS": True,
-    "TD_LAMBDA": 0.6,
-    "GAMMA": 0.9,
-    "VERBOSE": False,
-    "WANDB_ONLINE_REPORT": False,
-    "NUM_TEST_EPISODES": 32,
-    "TEST_INTERVAL": 50000,
-}
-
-rng = jax.random.PRNGKey(42)
-train_vjit = jax.jit(make_train(config, env))
-outs = train_vjit(rng)
-```
\ No newline at end of file
+## 🎯 Hyperparameter tuning
+
+Please refer to the `tune` function in the [transf_qmix.py](transf_qmix.py) script for an example of hyperparameter tuning with WandB.
\ No newline at end of file
diff --git a/baselines/QLearning/transf_qmix.py b/baselines/QLearning/transf_qmix.py
index 3d131bb9..5d0e0f3e 100644
--- a/baselines/QLearning/transf_qmix.py
+++ b/baselines/QLearning/transf_qmix.py
@@ -6,7 +6,7 @@
 - Added the possibility to perform $n$ training updates of the network at each update step.
 
 Currently supports only MPE_spread and SMAX. Remember that to use the transformers in your environment you need
-to reshape the observations and states into matrices. See: jaxmarl.wrappers.transformers
+to reshape the observations and states to matrices. See: jaxmarl.wrappers.transformers
 """
 
 import os
@@ -1026,9 +1026,6 @@ def wrapped_make_train():
         train_vjit = jax.jit(jax.vmap(make_train(config["alg"], env)))
         outs = jax.block_until_ready(train_vjit(rngs))
 
-    n_updates = (
-        default_config["alg"]["TOTAL_TIMESTEPS"] // default_config["alg"]["NUM_STEPS"] // default_config["alg"]["NUM_ENVS"]
-    )*default_config["NUM_SEEDS"] # 2 seeds will log double, 3 seeds will log triple and so on
     sweep_config = {
         'method': 'bayes',
         'metric': {
@@ -1036,11 +1033,11 @@
             'goal': 'maximize',
         },
         'parameters':{
-            'LR':{'values':[0.01, 0.005, 0.001, 0.0005]},
-            'LR_EXP_DECAY_RATE':{'values':[0.02, 0.002, 0.0002, 0.00002]},
+            'LR':{'values':[0.005, 0.001, 0.0005]},
             'EPS_ADAM':{'values':[0.0001, 0.0000001, 0.0000000001]},
-            'NUM_ENVS':{'values':[16, 32]},
-            'N_MINI_UPDATES':{'values':[2, 4, 8]},
+            'SCALE_INPUTS':{'values':[True, False]},
+            'NUM_ENVS':{'values':[8, 16]},
+            'N_MINI_UPDATES':{'values':[1, 2, 4]},
         },
     }
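
A few illustrative sketches for the diff above. First, for the "Uses Double Q-Learning with a target agent network (hard-updated)" bullet in the implementation-details list: a minimal sketch of the standard double-Q target computation. Shapes and names are illustrative, not the baselines' actual code.

```python
import jax.numpy as jnp

def double_q_targets(q_online_next, q_target_next, rewards, dones, gamma=0.99):
    """Double Q-learning targets: the online network selects the greedy
    action; the hard-updated target network evaluates it."""
    best_actions = jnp.argmax(q_online_next, axis=-1)  # selection (online net)
    q_eval = jnp.take_along_axis(
        q_target_next, best_actions[..., None], axis=-1
    ).squeeze(-1)                                      # evaluation (target net)
    return rewards + gamma * (1.0 - dones) * q_eval

# Hard update: every TARGET_UPDATE_INTERVAL steps, the online parameters are
# copied wholesale into the target parameters (no soft/Polyak averaging).
```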
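Second, for the docstring's note that observations and states must be reshaped into matrices for the transformers (see jaxmarl.wrappers.transformers): a sketch of the expected layout. The entity count and feature size below are hypothetical placeholders, not values from any real environment.

```python
import jax.numpy as jnp

# Hypothetical sizes: the agent observes 5 entities with 6 features each.
NUM_ENTITIES, ENTITY_FEATURES = 5, 6

def obs_to_matrix(flat_obs: jnp.ndarray) -> jnp.ndarray:
    """Reshape a flat observation vector into an (entities, features)
    matrix, one row per entity, so the transformer can attend over rows."""
    return flat_obs.reshape(NUM_ENTITIES, ENTITY_FEATURES)

obs_matrix = obs_to_matrix(jnp.zeros(NUM_ENTITIES * ENTITY_FEATURES))
assert obs_matrix.shape == (NUM_ENTITIES, ENTITY_FEATURES)
```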
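Finally, since the new README section defers to the `tune` function for hyperparameter tuning: a trimmed-down sketch of how a Bayesian sweep over the parameters in the diff is typically launched with WandB. The metric name (an unchanged line that falls between the two hunks), the project name, and the trial body are assumptions; the real wiring lives in `tune` in transf_qmix.py.

```python
import wandb

sweep_config = {
    "method": "bayes",
    "metric": {"name": "test_returns", "goal": "maximize"},  # metric name assumed
    "parameters": {
        "LR": {"values": [0.005, 0.001, 0.0005]},
        "SCALE_INPUTS": {"values": [True, False]},
        "NUM_ENVS": {"values": [8, 16]},
        "N_MINI_UPDATES": {"values": [1, 2, 4]},
    },
}

def run_one_trial():
    # Each agent call receives one hyperparameter combination via run.config.
    run = wandb.init()
    lr = run.config["LR"]  # merge sampled values into the training config here
    # ... build, jit, and run training, logging test returns to the run ...
    run.finish()

sweep_id = wandb.sweep(sweep_config, project="jaxmarl-qlearning")  # placeholder project
wandb.agent(sweep_id, function=run_one_trial, count=10)  # e.g. 10 trials
```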