Commit 7d46afa

MED commit to github
0 parents  commit 7d46afa

154 files changed: +12311 -0 lines changed

.pre-commit-config.yaml

+28
@@ -0,0 +1,28 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: check-added-large-files
      - id: check-symlinks
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-ast
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: detect-private-key
      # - id: debug-statements
      # - id: double-quote-string-fixer
  - repo: https://github.com/psf/black
    rev: 23.7.0
    hooks:
      - id: black
      - id: black-jupyter
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: ["--profile", "black", "--filter-files"]
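The hook list above names `check-added-large-files` twice. A small check along the following lines could flag such repeats. It is only a sketch: it scans the raw text with a regular expression instead of parsing YAML, it counts ids across the whole file rather than per repo, and the helper name `duplicate_hook_ids` is hypothetical.

```python
from __future__ import annotations

import re
from collections import Counter

# Matches active (uncommented) hook entries such as "- id: black".
HOOK_ID = re.compile(r"^\s*-\s+id:\s*(\S+)\s*$", re.MULTILINE)


def duplicate_hook_ids(config_text: str) -> list[str]:
    """Return hook ids listed more than once, in first-seen order."""
    counts = Counter(HOOK_ID.findall(config_text))
    return [hook for hook, n in counts.items() if n > 1]


# Shortened excerpt of the config, for illustration only.
sample = """
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: check-added-large-files
      - id: check-yaml
      - id: check-added-large-files
      # - id: debug-statements
"""

print(duplicate_hook_ids(sample))
```

Commented-out entries are ignored because the pattern requires the line to start with `-` after optional whitespace.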

README.md

+39
@@ -0,0 +1,39 @@
# Multi-Expert Distillation for Few-Shot Coordination (Student Abstract)

This repository contains the implementation of Multi-Expert Distillation (MED), based on PyTorch.

## 1. Getting started

Use the install script to set up the Python environment:

```shell
bash install.sh
conda activate med
```

## 2. Run an experiment

All experiments are launched through the unified entry point `examples/train.py` with customized arguments.

### LIPO

The repository includes a re-implementation of [LIPO](https://sites.google.com/view/iclr-lipo-2023).
To generate a population in Gridworld MoveBox or Overcooked, enter the `examples` folder and run one of the following commands:

```bash
python train.py --algo lipo --env gridworld --task MoveBox --map multi_exits --exp_name test --use_wandb True --pop_size 8 --horizon 50 --n_iter 500 --eval_interval 10 --n_sp_ts 5000 --n_xp_ts 5000
```

```bash
python train.py --algo lipo --env overcooked --map_name full_divider_salad_multi_ingred --exp_name test --use_wandb True --pop_size 8 --horizon 100 --n_iter 1000 --n_sp_ts 5000 --n_xp_ts 5000 --eval_interval 10
```

The results and models are written to the `examples/results` folder.

### MED

To run MED, place the population model files in the `harl/runners/generalist_runners/models` folder and make sure the files are named properly.
Then enter the `examples` folder and run one of the following commands:

```bash
python train.py --algo med --env matrix_game --exp_name performance --t_max 30000 --n_episodes 3 --use_wandb True
```

```bash
python train.py --algo med --env gridworld --task MoveBox --map multi_exits --exp_name performance --t_max 2000000 --horizon 50 --n_episodes 2 --use_wandb True
```

```bash
python train.py --algo med --env overcooked --map_name full_divider_salad_multi_ingred --exp_name performance --t_max 7500000 --horizon 100 --n_episodes 2 --use_wandb True
```

Training scripts are also provided in the `examples` folder.

examples/train.py

+76
@@ -0,0 +1,76 @@
"""Train an algorithm."""
import argparse
import json

from harl.utils.configs_tools import get_defaults_yaml_args, update_args


def main():
    """Main function."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--algo",
        type=str,
        default="med",
        choices=[
            "med",
            "lipo",
        ],
        help="Algorithm name. Choose from: med, lipo.",
    )
    parser.add_argument(
        "--env",
        type=str,
        default="matrix_game",
        choices=[
            "matrix_game",
            "gridworld",
            "overcooked",
        ],
        help="Environment name. Choose from: matrix_game, gridworld, overcooked.",
    )
    parser.add_argument(
        "--exp_name", type=str, default="installtest", help="Experiment name."
    )
    parser.add_argument(
        "--load_config",
        type=str,
        default="",
        help="If set, load existing experiment config file instead of reading from yaml config file.",
    )
    args, unparsed_args = parser.parse_known_args()

    def process(arg):
        try:
            return eval(arg)  # interpret values such as 8, True, or 1e-4
        except Exception:  # anything else (e.g. MoveBox) stays a string
            return arg

    keys = [k[2:] for k in unparsed_args[0::2]]  # remove -- from argument
    values = [process(v) for v in unparsed_args[1::2]]
    unparsed_dict = {k: v for k, v in zip(keys, values)}
    args = vars(args)  # convert to dict
    if args["load_config"] != "":  # load config from existing config file
        with open(args["load_config"], encoding="utf-8") as file:
            all_config = json.load(file)
        args["algo"] = all_config["main_args"]["algo"]
        args["env"] = all_config["main_args"]["env"]
        args["exp_name"] = all_config["main_args"]["exp_name"]
        algo_args = all_config["algo_args"]
        env_args = all_config["env_args"]
    else:  # load config from corresponding yaml file
        algo_args, env_args = get_defaults_yaml_args(args["algo"], args["env"])
    update_args(unparsed_dict, algo_args, env_args)  # update args from command line

    # start training
    from harl.runners import RUNNER_REGISTRY

    runner = RUNNER_REGISTRY[args["algo"]](args, algo_args, env_args)
    runner.run()
    runner.close()


if __name__ == "__main__":
    main()
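The script collects any option that `argparse` does not recognise as `--key value` pairs and evaluates each value with `eval`. As a hedged alternative, the same idea can be sketched with `ast.literal_eval`, which only accepts Python literals and so does not execute arbitrary expressions from the command line. The helper names `parse_value` and `parse_extra_args` are hypothetical and not part of the repository, and the strict pair validation is an added assumption.

```python
import ast


def parse_value(raw):
    """Interpret a CLI value as a Python literal, else keep it as a string."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def parse_extra_args(tokens):
    """Turn ["--key", "value", ...] into {"key": parsed_value, ...}."""
    if len(tokens) % 2 != 0:
        raise ValueError("expected '--key value' pairs")
    parsed = {}
    for key, raw in zip(tokens[0::2], tokens[1::2]):
        if not key.startswith("--"):
            raise ValueError(f"expected an option starting with '--', got {key!r}")
        parsed[key[2:]] = parse_value(raw)  # drop the leading "--"
    return parsed


extra = parse_extra_args(
    ["--task", "MoveBox", "--pop_size", "8", "--use_wandb", "True", "--horizon", "50"]
)
print(extra)
```

Unlike the slicing in the script, this version rejects an odd number of tokens instead of silently dropping the last one, which is a design choice of the sketch rather than behaviour of the original.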

examples/train_mg.sh

+4
@@ -0,0 +1,4 @@
for seed in 111 222 333 444 555
do
    python train.py --algo med --env matrix_game --exp_name performance --use_wandb True --seed $seed --t_max 30000
done

examples/train_movebox.sh

+4
@@ -0,0 +1,4 @@
for seed in 111 222 333 444 555
do
    python train.py --algo med --env gridworld --task MoveBox --map multi_exits --exp_name performance --t_max 2000000 --horizon 50 --n_episodes 2 --use_wandb True --seed $seed
done

examples/train_overcooked.sh

+4
@@ -0,0 +1,4 @@
for seed in 111 222 333 444 555
do
    python train.py --algo med --env overcooked --map_name full_divider_salad_multi_ingred --exp_name performance --t_max 7500000 --horizon 100 --n_episodes 2 --use_wandb True --seed $seed
done
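The three shell scripts differ only in their environment-specific arguments. For readers who prefer a single launcher, the sweep could be expressed in Python roughly as follows. This is a sketch that only builds and prints the commands; the names `ENV_ARGS` and `build_commands` are hypothetical, and the argument values are copied from the scripts above.

```python
import shlex

SEEDS = [111, 222, 333, 444, 555]

# Environment-specific arguments, copied from the three shell scripts.
ENV_ARGS = {
    "matrix_game": "--t_max 30000",
    "gridworld": (
        "--task MoveBox --map multi_exits --t_max 2000000 "
        "--horizon 50 --n_episodes 2"
    ),
    "overcooked": (
        "--map_name full_divider_salad_multi_ingred --t_max 7500000 "
        "--horizon 100 --n_episodes 2"
    ),
}


def build_commands(env, seeds=SEEDS):
    """Return one argv list per seed for the given environment."""
    base = [
        "python", "train.py", "--algo", "med", "--env", env,
        "--exp_name", "performance", "--use_wandb", "True",
    ]
    env_args = shlex.split(ENV_ARGS[env])
    return [base + env_args + ["--seed", str(seed)] for seed in seeds]


for cmd in build_commands("matrix_game"):
    print(shlex.join(cmd))
```

To actually launch the runs, each list could be passed to `subprocess.run(cmd, check=True)` from inside the `examples` folder.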

harl/__init__.py

Whitespace-only changes.

harl/algorithms/__init__.py

Whitespace-only changes.

harl/algorithms/actors/__init__.py

+11
@@ -0,0 +1,11 @@
"""Algorithm registry."""
# lipo
from harl.algorithms.actors.incompact_mappo_z import IncompatMAPPOZ
from harl.algorithms.actors.med_gpt import GPTAgent

ALGO_REGISTRY = {
    # population
    "lipo": IncompatMAPPOZ,
    # generalist
    "med": GPTAgent,
}
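The registry is a plain dictionary from the `--algo` name to a class, so adding an algorithm means adding one entry. A minimal, self-contained sketch of how such a registry is typically consulted is shown below. The two classes are empty stand-ins for the real actors, and `get_algo` is a hypothetical helper that is not part of this commit.

```python
class IncompatMAPPOZ:
    """Stand-in for the population (LIPO) actor class."""


class GPTAgent:
    """Stand-in for the generalist (MED) actor class."""


ALGO_REGISTRY = {
    "lipo": IncompatMAPPOZ,
    "med": GPTAgent,
}


def get_algo(name):
    """Look up an algorithm class, with a readable error for unknown names."""
    try:
        return ALGO_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(ALGO_REGISTRY))
        raise ValueError(
            f"unknown algorithm {name!r}; choose from: {known}"
        ) from None


agent_cls = get_algo("med")
print(agent_cls.__name__)
```

Wrapping the lookup in a helper turns a bare `KeyError` into a message that lists the valid names, which is helpful when the name comes from the command line.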
