Commit 1fb473a

hanabi ready to be merged
mttga committed Mar 22, 2024
1 parent bf641e0 commit 1fb473a
Showing 22 changed files with 10,348 additions and 1,038 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ __pycache__/
 *.ipynb
 *.DS_Store
 .vscode/
+.ipynb_checkpoints/
 docker/*
 *.pickle
 results/
1 change: 1 addition & 0 deletions baselines/IPPO/ippo_ff_hanabi.py
@@ -138,6 +138,7 @@ def _env_step(runner_state, unused):
     action = pi.sample(seed=_rng)
     log_prob = pi.log_prob(action)
     env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)
+    env_act = jax.tree_map(lambda x: x.squeeze(), env_act)

     # STEP ENV
     rng, _rng = jax.random.split(rng)
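The added squeeze is the only functional change in this baseline: unbatchify leaves a trailing singleton axis on each per-agent action array, and the squeeze presumably removes it to match the action shapes the environment step expects. A minimal sketch of the effect, with illustrative agent names and shapes (not taken from the file):

import jax
import jax.numpy as jnp

# Illustrative pytree: one (num_envs, 1) action array per agent, as if
# returned by unbatchify with num_envs = 8.
env_act = {
    "agent_0": jnp.zeros((8, 1), dtype=jnp.int32),
    "agent_1": jnp.ones((8, 1), dtype=jnp.int32),
}

# Squeeze every leaf of the pytree so each agent's actions have shape (num_envs,).
env_act = jax.tree_map(lambda x: x.squeeze(), env_act)

print({k: v.shape for k, v in env_act.items()})  # {'agent_0': (8,), 'agent_1': (8,)}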
19 changes: 8 additions & 11 deletions baselines/QLearning/transf_qmix.py
@@ -6,7 +6,7 @@
 - It's added the possibility to perform $n$ training updates of the network at each update step.
 Currently supports only MPE_spread and SMAX. Remember that to use the transformers in your environment you need
-to reshape the observations and states into matrices. See: jaxmarl.wrappers.transformers
+to reshape the observations and states to matrices. See: jaxmarl.wrappers.transformers
 """

 import os
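As background for the docstring line above: the transformer variants consume observations as entity matrices rather than flat vectors, which is what the wrappers in jaxmarl.wrappers.transformers provide. A minimal sketch of such a reshape, with hypothetical entity and feature counts:

import jax.numpy as jnp

# Hypothetical sizes: a flat observation of length n_entities * n_features
# is reshaped into an (n_entities, n_features) matrix that a transformer
# can attend over, one row per entity.
n_entities, n_features = 8, 6
flat_obs = jnp.arange(n_entities * n_features, dtype=jnp.float32)
matrix_obs = flat_obs.reshape(n_entities, n_features)  # shape (8, 6)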
@@ -928,7 +928,7 @@ def callback(timestep, val):
     return train


-def signle_run(config):
+def single_run(config):
     """Perform a single run with multiple parallel seeds in one env."""
     config = OmegaConf.to_container(config)
@@ -1026,21 +1026,18 @@ def wrapped_make_train():
     train_vjit = jax.jit(jax.vmap(make_train(config["alg"], env)))
     outs = jax.block_until_ready(train_vjit(rngs))

-    n_updates = (
-        default_config["alg"]["TOTAL_TIMESTEPS"] // default_config["alg"]["NUM_STEPS"] // default_config["alg"]["NUM_ENVS"]
-    )*default_config["NUM_SEEDS"] # 2 seeds will log double, 3 seeds will log triple and so on
     sweep_config = {
         'method': 'bayes',
         'metric': {
             'name': 'test_returns',
             'goal': 'maximize',
         },
         'parameters':{
-            'LR':{'values':[0.01, 0.005, 0.001, 0.0005]},
-            'LR_EXP_DECAY_RATE':{'values':[0.02, 0.002, 0.0002, 0.00002]},
+            'LR':{'values':[0.005, 0.001, 0.0005]},
+            'EPS_ADAM':{'values':[0.0001, 0.0000001, 0.0000000001]},
-            'NUM_ENVS':{'values':[16, 32]},
-            'N_MINI_UPDATES':{'values':[2, 4, 8]},
             'SCALE_INPUTS':{'values':[True, False]},
+            'NUM_ENVS':{'values':[8, 16]},
+            'N_MINI_UPDATES':{'values':[1, 2, 4]},
         },
     }

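The updated sweep still runs Bayesian search maximizing test_returns, but it narrows the learning-rate grid, drops LR_EXP_DECAY_RATE, and searches over Adam's epsilon plus smaller NUM_ENVS and N_MINI_UPDATES values instead. For reference, a dict like sweep_config is typically launched through the standard wandb sweep API; a self-contained sketch with placeholder project name, metric value, and trial count:

import wandb

# Trimmed stand-in for the sweep_config above (one parameter kept).
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'test_returns', 'goal': 'maximize'},
    'parameters': {'LR': {'values': [0.005, 0.001, 0.0005]}},
}

def train_fn():
    # Each trial receives its sampled hyperparameters via wandb.config.
    run = wandb.init()
    lr = wandb.config.LR  # would be forwarded into the training setup
    wandb.log({'test_returns': 0.0})  # placeholder for the real metric
    run.finish()

sweep_id = wandb.sweep(sweep=sweep_config, project="transf-qmix-sweep")
wandb.agent(sweep_id, function=train_fn, count=5)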
@@ -1051,7 +1048,7 @@ def wrapped_make_train():
 @hydra.main(version_base=None, config_path="./config", config_name="config")
 def main(config):
     #tune(config) # uncomment to run hypertuning
-    signle_run(config)
+    single_run(config)

 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()