Merge pull request #71 from FLAIROx/hanabi_obl_aligned
Corrected Hanabi, new Dockerfile, python 3.10 and other fixes
mttga authored Mar 22, 2024
2 parents 60b9fd2 + 8615058 commit 8f17f22
Showing 33 changed files with 11,849 additions and 555 deletions.
5 changes: 5 additions & 0 deletions .dockerignore
@@ -0,0 +1,5 @@
wandb/
tmp/
outputs/
results/
models/
5 changes: 2 additions & 3 deletions .github/workflows/tests.yml
@@ -11,7 +11,7 @@ jobs:
# os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge]
# For Apple Silicon: https://github.com/actions/runner-images/issues/8439
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.9']
python-version: ['3.10']
defaults:
run:
shell: bash
@@ -28,8 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -e '.[dev]'
pip install -e .
- name: Run pytest
run: pytest tests
4 changes: 3 additions & 1 deletion .gitignore
@@ -6,11 +6,13 @@ __pycache__/
*.ipynb
*.DS_Store
.vscode/
.ipynb_checkpoints/
docker/*
*.pickle
results/
docs/
tmp/
*-checkpoint.py
wandb/
outputs/
outputs/
models/
47 changes: 5 additions & 42 deletions Dockerfile
@@ -1,59 +1,22 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

# install python
ARG DEBIAN_FRONTEND=noninteractive
ARG PYTHON_VERSION=3.10
#setting language and locale
ENV LANG="C.UTF-8" LC_ALL="C.UTF-8"


RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
software-properties-common \
build-essential \
curl \
ffmpeg \
git \
htop \
vim \
nano \
rsync \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-get update && apt-get install -y -qq python${PYTHON_VERSION} \
python${PYTHON_VERSION}-dev \
python${PYTHON_VERSION}-distutils

# Set python aliases
RUN update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python get-pip.py
FROM nvcr.io/nvidia/jax:23.10-py3

# default workdir
WORKDIR /home/workdir
COPY . .

#jaxmarl from source if needed, all the requirements
RUN pip install --ignore-installed -e '.[qlearning, dev]'

# install jax from to enable cuda
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install -e .

#disabling preallocation
RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false
#safety measures
RUN export XLA_PYTHON_CLIENT_MEM_FRACTION=0.25
RUN export TF_FORCE_GPU_ALLOW_GROWTH=true

#for jupyter
EXPOSE 9999
# if you want jupyter
RUN pip install jupyterlab

#for secrets and debug
ENV WANDB_API_KEY=""
ENV WANDB_ENTITY=""
RUN git config --global --add safe.directory /home/workdir

CMD ["/bin/bash"]
RUN git config --global --add safe.directory /home/workdir
9 changes: 0 additions & 9 deletions README.md
@@ -87,15 +87,6 @@ pip install jaxmarl
pip install -e .
export PYTHONPATH=./JaxMARL:$PYTHONPATH
```
3. If you would also like to run the Q-learning algorithms, Python 3.9 is required along with additional dependencies:
```
pip install -e '.[qlearning]'
```
**Test Scripts** - To run our test scripts, some additional dependencies are required (for comparisons against existing implementations); these can be installed with:
```
pip install -r requirements/requirements-dev.txt
```
<h2 name="start" id="start">Quick Start 🚀 </h2>
2 changes: 1 addition & 1 deletion baselines/IPPO/config/ippo_ff_hanabi.yaml
@@ -17,4 +17,4 @@
# WandB Params
"WANDB_MODE": "disabled"
"ENTITY": ""
"PROJECT": "jaxmarl-hanabi"
"PROJECT": ""
1 change: 1 addition & 0 deletions baselines/IPPO/ippo_ff_hanabi.py
@@ -138,6 +138,7 @@ def _env_step(runner_state, unused):
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)
env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)
env_act = jax.tree_map(lambda x: x.squeeze(), env_act)

# STEP ENV
rng, _rng = jax.random.split(rng)
12 changes: 2 additions & 10 deletions baselines/QLearning/README.md
@@ -10,23 +10,15 @@ Pure JAX implementations of:

The first three follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning).

```
⚠️ The implementations were tested with Python 3.9 and Jax 0.4.11.
With Jax 0.4.13, you could experience a degradation of performance.
```

We use [`flashbax`](https://github.com/instadeepai/flashbax) to provide our replay buffers; this requires Python 3.9, and the dependency can be installed with:
```
pip install -r requirements/requirements-qlearning.txt
```

```
❗The implementations were tested in the following environments:
- MPE
- SMAX
- Hanabi
```

WIP for Hanabi and Overcooked.

## ⚙️ Implementation Details

General features:
2 changes: 1 addition & 1 deletion baselines/QLearning/config/alg/qmix_smax.yaml
@@ -18,7 +18,7 @@
"LR_LINEAR_DECAY": False
"EPS_ADAM": 0.00001
"WEIGHT_DECAY_ADAM": 0.000001
"TD_LAMBDA_LOSS": True
"TD_LAMBDA_LOSS": False
"TD_LAMBDA": 0.6
"GAMMA": 0.99
"VERBOSE": False
2 changes: 1 addition & 1 deletion baselines/QLearning/config/env/smax.yaml
@@ -1,5 +1,5 @@
"ENV_NAME": "HeuristicEnemySMAX"
"MAP_NAME": "5m_vs_6m"
"MAP_NAME": "2s3z"
"ENV_KWARGS":
"see_enemy_actions": True
"walls_cause_death": True
2 changes: 1 addition & 1 deletion baselines/QLearning/transf_qmix.py
@@ -1051,4 +1051,4 @@ def main(config):
single_run(config)

if __name__ == "__main__":
main()
main()
2 changes: 1 addition & 1 deletion jaxmarl/environments/__init__.py
@@ -18,7 +18,7 @@
from .switch_riddle import SwitchRiddle
from .overcooked import Overcooked, overcooked_layouts
from .mabrax import Ant, Humanoid, Hopper, Walker2d, HalfCheetah
from .hanabi import HanabiGame
from .hanabi import Hanabi
from .storm import InTheGrid, InTheGrid_2p
from .coin_game import CoinGame

182 changes: 165 additions & 17 deletions jaxmarl/environments/hanabi/README.md
@@ -2,25 +2,176 @@

This directory contains a MARL environment for the cooperative card game, Hanabi, implemented in JAX. It is inspired by the popular [Hanabi Learning Environment (HLE)](https://arxiv.org/pdf/1902.00506.pdf), but intended to be simpler to integrate and run with the growing ecosystem of JAX implemented RL research pipelines.

#### A note on tuning
The performance of IPPO on Hanabi, as implemented in this repo, is currently marginally lower than the reported [SoTA result for IPPO](https://arxiv.org/pdf/2103.01955.pdf). They ran a very extensive hyperparameter sweep, and conducting similarly comprehensive tuning of the JAX implementation is on the near-term agenda.

## Action Space
Hanabi is a turn-based game. The current player can choose to discard or play any of the cards in their hand, or hint a colour or rank to any one of their teammates.
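
For intuition, the size of this discrete action space in a standard two-player game can be counted by hand. The sketch below is illustrative only; the variable names are not the environment's API, and the environment may also reserve an extra no-op action:

```python
# Back-of-the-envelope count of Hanabi actions for a 2-player game.
# Names are for exposition only and are not taken from JaxMARL.
hand_size, num_colors, num_ranks, num_players = 5, 5, 5, 2

num_actions = (
    hand_size                          # discard one of your own cards
    + hand_size                        # play one of your own cards
    + (num_players - 1) * num_colors   # hint a colour to a teammate
    + (num_players - 1) * num_ranks    # hint a rank to a teammate
)
print(num_actions)  # 20
```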

## Observation Space
The observations closely follow the featurization in the HLE.

Each observation is comprised of:
- card knowledge (binary encoding of implicit card knowledge); size `(hand_size * num_colors * num_ranks)`
- color and rank hints (binary encoding of explicit hints made about player's hand); size `(hand_size * (num_colors + num_ranks))`
- fireworks (thermometer encoded); size `(num_colors * num_ranks)`
- info tokens (thermometer encoded); size `max_info_tokens`
- life tokens (thermometer encoded); size `max_life_tokens`
- last moves (one-hot encoding of most recent move of each player); size `(num_players * num_moves)`
- current player (one-hot encoding); size `num_players`
- discard pile (one-hot encodings of discarded cards); size `(num_cards_of_color * num_colors * num_colors * num_ranks)`
- remaining deck size (thermometer encoded); size `(num_cards_of_color * num_colors)`
The observations closely follow the featurization in the HLE. Each observation comprises 658 features, grouped as follows (a quick sanity check of the totals appears after the list):

* **Hands (127)**: information about the visible hands.
  * other player hand: 125
    * card 0: 25
    * card 1: 25
    * card 2: 25
    * card 3: 25
    * card 4: 25
  * Hands missing card: 2 (one-hot)

* **Board (76)**: encoding of the public information visible on the board.
  * Deck: 40, thermometer
  * Fireworks: 25, one-hot
  * Info Tokens: 8, thermometer
  * Life Tokens: 3, thermometer

* **Discards (50)**: encoding of the cards in the discard pile.
  * Colour R: 10 bits, one for each card of that colour in the deck
  * Colour Y: 10 bits, one for each card of that colour in the deck
  * Colour G: 10 bits, one for each card of that colour in the deck
  * Colour W: 10 bits, one for each card of that colour in the deck
  * Colour B: 10 bits, one for each card of that colour in the deck

* **Last Action (55)**: encoding of the last move of the previous player.
  * Acting player index, relative to yourself: 2, one-hot
  * MoveType: 4, one-hot
  * Target player index, relative to acting player: 2, one-hot
  * Color revealed: 5, one-hot
  * Rank revealed: 5, one-hot
  * Reveal outcome: 5 bits, each bit is 1 if the corresponding card was hinted at
  * Position played/discarded: 5, one-hot
  * Card played/discarded: 25, one-hot
  * Card played scored: 1
  * Card played added info token: 1

* **V0 belief (350)**: trivially computed probability of being a specific card (given the played and discarded cards and the hints received), for each card of each player.
  * Possible Card (for each card): 25 (* 10)
  * Colour hinted (for each card): 5 (* 10)
  * Rank hinted (for each card): 5 (* 10)
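
As a quick sanity check (illustrative only, not library code), the group sizes listed above add up to the 658 observation features:

```python
# Sanity check of the observation layout described above (sizes only).
groups = {
    "hands": 125 + 2,                        # other player's hand + missing-card bits
    "board": 40 + 25 + 8 + 3,                # deck, fireworks, info tokens, life tokens
    "discards": 5 * 10,                      # 10 bits per colour
    "last_action": 2 + 4 + 2 + 5 + 5 + 5 + 5 + 25 + 1 + 1,
    "v0_belief": 10 * (25 + 5 + 5),          # 10 cards, 35 belief features each
}
assert sum(groups.values()) == 658, sum(groups.values())
```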

## Pretrained Models

We make some pretrained models available. For example, you can use a JAX conversion of the original R2D2 OBL model as follows:

1. Download the models from Hugging Face: ```git clone https://huggingface.co/mttga/obl-r2d2-flax``` (ensure git lfs is installed). You can also use the script: ```bash jaxmarl/environments/hanabi/models/download_r2d2_obl.sh```
2. Load the parameters, import the agent wrapper, and use it with JaxMARL Hanabi:

```python
!git clone https://huggingface.co/mttga/obl-r2d2-flax
import jax
from jax import numpy as jnp
from jaxmarl import make
from jaxmarl.wrappers.baselines import load_params
from jaxmarl.environments.hanabi.pretrained import OBLAgentR2D2

weight_file = "jaxmarl/environments/hanabi/pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
params = load_params(weight_file)

agent = OBLAgentR2D2()
agent_carry = agent.initialize_carry(jax.random.PRNGKey(0), batch_dims=(2,))

rng = jax.random.PRNGKey(0)
env = make('hanabi')
obs, env_state = env.reset(rng)
env.render(env_state)

batchify = lambda x: jnp.stack([x[agent] for agent in env.agents])
unbatchify = lambda x: {agent:x[i] for i, agent in enumerate(env.agents)}

agent_input = (
batchify(obs),
batchify(env.get_legal_moves(env_state))
)
agent_carry, actions = agent.greedy_act(params, agent_carry, agent_input)
actions = unbatchify(actions)

obs, env_state, rewards, done, info = env.step(rng, env_state, actions)

print('actions:', {agent:env.action_encoding[int(a)] for agent, a in actions.items()})
env.render(env_state)
```

## Rendering

You can render the full environment state:

```python
obs, env_state = env.reset(rng)
env.render(env_state)

Turn: 0

Score: 0
Information: 8
Lives: 3
Deck: 40
Discards:
Fireworks:
Actor 0 Hand:<-- current player
0 W3 || XX|RYGWB12345
1 G5 || XX|RYGWB12345
2 G4 || XX|RYGWB12345
3 G1 || XX|RYGWB12345
4 Y2 || XX|RYGWB12345
Actor 1 Hand:
0 R3 || XX|RYGWB12345
1 B1 || XX|RYGWB12345
2 G1 || XX|RYGWB12345
3 R4 || XX|RYGWB12345
4 W4 || XX|RYGWB12345
```

Or you can render the partial observation of the current agent:

```python
obs, new_env_state, rewards, dones, infos = env.step_env(rng, env_state, actions)
obs_s = env.get_obs_str(new_env_state, env_state, a, include_belief=True, best_belief=5)
print(obs_s)

Turn: 1

Score: 0
Information available: 7
Lives available: 3
Deck remaining cards: 40
Discards:
Fireworks:
Other Hand:
0 Card: W3, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
1 Card: G5, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
2 Card: G4, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
3 Card: G1, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
4 Card: Y2, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
Your Hand:
0 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
1 Hints: 1, Possible: RYGWB1, Belief: [R1: 0.200 Y1: 0.200 G1: 0.200 W1: 0.200 B1: 0.200]
2 Hints: 1, Possible: RYGWB1, Belief: [R1: 0.200 Y1: 0.200 G1: 0.200 W1: 0.200 B1: 0.200]
3 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
4 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
Last action: H1
Cards afected: [1 2]
Legal Actions: ['D0', 'D1', 'D2', 'D3', 'D4', 'P0', 'P1', 'P2', 'P3', 'P4', 'HY', 'HG', 'HW', 'H1', 'H2', 'H3', 'H4', 'H5']
```
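
The belief values in this printout can be reproduced by hand. Below is a minimal sketch, assuming the standard 50-card deck (three 1s, two each of ranks 2-4, and one 5 per colour) and using only the hint information shown; it is an illustration, not the environment's belief code:

```python
# Illustrative reconstruction of the V0 belief values printed above.
copies = {1: 3, 2: 2, 3: 2, 4: 2, 5: 1}   # copies of each rank per colour
num_colors = 5

# A card with no hints: all 50 deck cards are candidates,
# so P(R1) = 3 / 50 = 0.060, as in the "Other Hand" rows.
total = num_colors * sum(copies.values())
print(copies[1] / total)                   # 0.06

# A card known not to be a 1 (the "1" hint touched other slots):
# 35 candidates remain, so P(R2) = 2 / 35 ≈ 0.057, as in "Your Hand".
total_non_ones = num_colors * sum(n for r, n in copies.items() if r != 1)
print(copies[2] / total_non_ones)          # 0.0571...
```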

## Manual Game

You can test the environment and your models with the ```manual_game.py``` script in this folder. It lets you control one or two agents with the keyboard and one or two agents with a pretrained model (an OBL model by default). For example, to play with a pretrained OBL model:

```
python manual_game.py \
--player0 "manual" \
--player1 "obl" \
--weight1 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
```

Or to watch an OBL model playing with itself:

```
python manual_game.py \
--player0 "obl" \
--player1 "obl" \
--weight0 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors" \
--weight1 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors" \
```

## Citation
The environment was originally described in the following work:
@@ -32,6 +32,3 @@ The environment was originally described in the following work:
year={2019}
}
```

## To Do
- [ ] Algorithm tuning