From 959ed6c8d664a62dcf02a88a7e1d699a962b0568 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 10 Oct 2024 20:41:46 +0100 Subject: [PATCH 1/8] reference jaxmarl spaces rather than gymnax --- jaxmarl/environments/hanabi/hanabi.py | 2 +- jaxmarl/environments/mabrax/mabrax_env.py | 2 +- jaxmarl/environments/mpe/simple_adversary.py | 2 +- jaxmarl/environments/mpe/simple_crypto.py | 2 +- jaxmarl/environments/mpe/simple_push.py | 2 +- jaxmarl/environments/mpe/simple_reference.py | 2 +- jaxmarl/environments/mpe/simple_speaker_listener.py | 2 +- jaxmarl/environments/mpe/simple_spread.py | 2 +- jaxmarl/environments/mpe/simple_tag.py | 2 +- jaxmarl/environments/mpe/simple_world_comm.py | 3 +-- jaxmarl/environments/spaces.py | 1 + 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/jaxmarl/environments/hanabi/hanabi.py b/jaxmarl/environments/hanabi/hanabi.py index aaf2e799..bc47d88c 100644 --- a/jaxmarl/environments/hanabi/hanabi.py +++ b/jaxmarl/environments/hanabi/hanabi.py @@ -9,7 +9,7 @@ import chex from typing import Tuple, Dict from functools import partial -from gymnax.environments.spaces import Discrete +from jaxmarl.environments.spaces import Discrete from .hanabi_game import HanabiGame, State diff --git a/jaxmarl/environments/mabrax/mabrax_env.py b/jaxmarl/environments/mabrax/mabrax_env.py index ce170927..f4edb783 100644 --- a/jaxmarl/environments/mabrax/mabrax_env.py +++ b/jaxmarl/environments/mabrax/mabrax_env.py @@ -1,7 +1,7 @@ from typing import Dict, Literal, Optional, Tuple import chex from jaxmarl.environments.multi_agent_env import MultiAgentEnv -from gymnax.environments import spaces +from jaxmarl.environments import spaces from brax import envs import jax import jax.numpy as jnp diff --git a/jaxmarl/environments/mpe/simple_adversary.py b/jaxmarl/environments/mpe/simple_adversary.py index 5273602b..f2706629 100644 --- a/jaxmarl/environments/mpe/simple_adversary.py +++ b/jaxmarl/environments/mpe/simple_adversary.py @@ -5,7 +5,7 @@ from functools 
import partial from jaxmarl.environments.mpe.simple import State, SimpleMPE from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box class SimpleAdversaryMPE(SimpleMPE): diff --git a/jaxmarl/environments/mpe/simple_crypto.py b/jaxmarl/environments/mpe/simple_crypto.py index 1ce4f09a..d1da2d5d 100644 --- a/jaxmarl/environments/mpe/simple_crypto.py +++ b/jaxmarl/environments/mpe/simple_crypto.py @@ -6,7 +6,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete SPEAKER = "alice_0" LISTENER = "bob_0" diff --git a/jaxmarl/environments/mpe/simple_push.py b/jaxmarl/environments/mpe/simple_push.py index bbfa37ce..72d23ea0 100644 --- a/jaxmarl/environments/mpe/simple_push.py +++ b/jaxmarl/environments/mpe/simple_push.py @@ -5,7 +5,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box # Obstacle Colours COLOUR_1 = jnp.array([0.1, 0.9, 0.1]) diff --git a/jaxmarl/environments/mpe/simple_reference.py b/jaxmarl/environments/mpe/simple_reference.py index b86314bb..ae70e482 100644 --- a/jaxmarl/environments/mpe/simple_reference.py +++ b/jaxmarl/environments/mpe/simple_reference.py @@ -5,7 +5,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete # Obstacle Colours OBS_COLOUR = [(191, 64, 64), (64, 191, 64), (64, 64, 191)] diff --git a/jaxmarl/environments/mpe/simple_speaker_listener.py 
b/jaxmarl/environments/mpe/simple_speaker_listener.py index 8d1ee78e..c3ddeacb 100644 --- a/jaxmarl/environments/mpe/simple_speaker_listener.py +++ b/jaxmarl/environments/mpe/simple_speaker_listener.py @@ -4,7 +4,7 @@ from typing import Tuple, Dict from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete SPEAKER = "speaker_0" LISTENER = "listener_0" diff --git a/jaxmarl/environments/mpe/simple_spread.py b/jaxmarl/environments/mpe/simple_spread.py index ebabe61d..222c5818 100644 --- a/jaxmarl/environments/mpe/simple_spread.py +++ b/jaxmarl/environments/mpe/simple_spread.py @@ -5,7 +5,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box class SimpleSpreadMPE(SimpleMPE): diff --git a/jaxmarl/environments/mpe/simple_tag.py b/jaxmarl/environments/mpe/simple_tag.py index bf9c0869..813b032e 100644 --- a/jaxmarl/environments/mpe/simple_tag.py +++ b/jaxmarl/environments/mpe/simple_tag.py @@ -4,7 +4,7 @@ from typing import Tuple, Dict from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box from jaxmarl.environments.mpe.default_params import * diff --git a/jaxmarl/environments/mpe/simple_world_comm.py b/jaxmarl/environments/mpe/simple_world_comm.py index 4343e212..e22a2643 100644 --- a/jaxmarl/environments/mpe/simple_world_comm.py +++ b/jaxmarl/environments/mpe/simple_world_comm.py @@ -11,8 +11,7 @@ OBS_COLOUR, ) from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete - +from jaxmarl.environments.spaces import Box, Discrete # NOTE food and forests are part 
of world.landmarks diff --git a/jaxmarl/environments/spaces.py b/jaxmarl/environments/spaces.py index 4320bd46..8f97d6c9 100644 --- a/jaxmarl/environments/spaces.py +++ b/jaxmarl/environments/spaces.py @@ -1,3 +1,4 @@ +""" Built off Gymnax spaces.py, this module contains jittable classes for action and observation spaces. """ from typing import Tuple, Union, Sequence from collections import OrderedDict import chex From 3d246d1170aeceb9f9f19d44d0bf15187b0fe205 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 10 Oct 2024 20:42:11 +0100 Subject: [PATCH 2/8] seperate requirements --- pyproject.toml | 32 ++++++++++++++++++++++++++++++-- requirements/requirements.txt | 26 -------------------------- 2 files changed, 30 insertions(+), 28 deletions(-) delete mode 100644 requirements/requirements.txt diff --git a/pyproject.toml b/pyproject.toml index 3f5bfbba..34c9ba4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ include = ['jaxmarl*'] [tool.setuptools.dynamic] version = {attr = "jaxmarl.__version__"} -dependencies = {file = ["requirements/requirements.txt"]} [project] name = "jaxmarl" @@ -17,7 +16,7 @@ description = "Multi-Agent Reinforcement Learning with JAX" authors = [ {name = "Foerster Lab for AI Research", email = "arutherford@robots.ox.ac.uk"}, ] -dynamic = ["version", "dependencies"] +dynamic = ["version"] license = {file = "LICENSE"} requires-python = ">=3.10" classifiers = [ @@ -31,6 +30,35 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: Apache Software License", ] +dependencies = [ + "jax>=0.4.16.0,<=0.4.25", + "jaxlib>=0.4.16.0,<=0.4.25", + "flax", + "chex", + "brax==0.10.3", + "mujoco==3.1.3", + "matplotlib", + "pillow", + "scipy<=1.12", + "gymnax", +] + +[project.optional-dependencies] +alg = [ + "optax", + "distrax", + "safetensors", + "flashbax==0.1.0", + "wandb", + "hydra-core>=1.3.2", + "omegaconf>=2.3.0", + "pettingzoo>=1.24.3", + "tqdm>=4.66.0", +] +dev = [ + "pytest", + 
"pygame", +] [project.urls] "Homepage" = "https://github.com/FLAIROx/JaxMARL" diff --git a/requirements/requirements.txt b/requirements/requirements.txt deleted file mode 100644 index 2e0fd4b8..00000000 --- a/requirements/requirements.txt +++ /dev/null @@ -1,26 +0,0 @@ -# requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image -jax>=0.4.16.0,<=0.4.25 -jaxlib>=0.4.16.0,<=0.4.25 -flax==0.7.4 -chex==0.1.84 -optax==0.1.7 -dotmap==1.3.30 -evosax==0.1.5 -distrax==0.1.5 -brax==0.10.3 -mujoco==3.1.3 -gymnax==0.0.6 -safetensors==0.4.2 -flashbax==0.1.0 -# less sensitive libs -wandb -pytest -pygame -numpy>=1.26.1 -hydra-core>=1.3.2 -omegaconf>=2.3.0 -matplotlib>=3.8.3 -pillow>=10.2.0 -pettingzoo>=1.24.3 -tqdm>=4.66.0 -scipy<=1.12 From cba255a95c379229426ea5bdc91e8d3e53300cdc Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 10 Oct 2024 21:06:15 +0100 Subject: [PATCH 3/8] update readme, dockerfile --- Dockerfile | 2 +- README.md | 12 +++++++++--- pyproject.toml | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index b3299296..c4cb6440 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,7 +17,7 @@ RUN apt-get update && \ apt-get install -y tmux #jaxmarl from source if needed, all the requirements -RUN pip install -e . +RUN pip install -e .[algs,dev] USER ${MYUSER} diff --git a/README.md b/README.md index 22f14690..908332ba 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,8 @@ ## Multi-Agent Reinforcement Learning in JAX +🎉 Update: JaxMARL was accepted at NeurIPS 2024 on Datasets and Benchmarks Track. See you in Vacouver! + JaxMARL combines ease-of-use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is for one library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. 
We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. For more details, take a look at our [blog post](https://blog.foersterlab.com/jaxmarl/) or our [Colab notebook](https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb), which walks through the basic usage. @@ -72,7 +74,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca

Installation 🧗

-**Environments** - Before installing, ensure you have the correct [JAX version](https://github.com/google/jax#installation) for your hardware accelerator. The JaxMARL environments can be installed directly from PyPi: +**Environments** - Before installing, ensure you have the correct [JAX installation](https://github.com/google/jax#installation) for your hardware accelerator. We have tested up to JAX version 0.4.25. The JaxMARL environments can be installed directly from PyPi: ``` pip install jaxmarl @@ -84,12 +86,14 @@ pip install jaxmarl ``` git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL ``` -2. The requirements for IPPO & MAPPO can be installed with: +2. Requirements can be installed with: ``` - pip install -e . + pip install -e .[algs] export PYTHONPATH=./JaxMARL:$PYTHONPATH ``` +**Development** - If you would like to run our test suite, install the additonal dependencies with `pip install -e .[dev]`. +

Quick Start 🚀

We take inspiration from the [PettingZoo](https://github.com/Farama-Foundation/PettingZoo) and [Gymnax](https://github.com/RobertTLange/gymnax) interfaces. You can try out training an agent in our [Colab notebook](https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb). Further introduction scripts can be found [here](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/tutorials). @@ -151,6 +155,7 @@ JAX-native algorithms: - [Mava](https://github.com/instadeepai/Mava): JAX implementations of IPPO and MAPPO, two popular MARL algorithms. - [PureJaxRL](https://github.com/luchris429/purejaxrl): JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training. - [Minimax](https://github.com/facebookresearch/minimax/): JAX implementations of autocurricula baselines for RL. +- [JaxIRL](https://github.com/FLAIROx/jaxirl?tab=readme-ov-file): JAX implementation of algorithms for inverse reinforcement learning. JAX-native environments: - [Gymnax](https://github.com/RobertTLange/gymnax): Implementations of classic RL tasks including classic control, bsuite and MinAtar. @@ -158,3 +163,4 @@ JAX-native environments: - [Pgx](https://github.com/sotetsuk/pgx): JAX implementations of classic board games, such as Chess, Go and Shogi. - [Brax](https://github.com/google/brax): A fully differentiable physics engine written in JAX, features continuous control tasks. - [XLand-MiniGrid](https://github.com/corl-team/xland-minigrid): Meta-RL gridworld environments inspired by XLand and MiniGrid. +- [Craftax](https://github.com/MichaelTMatthews/Craftax): (Crafter + NetHack) in JAX. 
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 34c9ba4d..37468a2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ ] [project.optional-dependencies] -alg = [ +algs = [ "optax", "distrax", "safetensors", From c802dccdd4afd15476c0c13fab03b6067d68731b Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 10 Oct 2024 21:09:06 +0100 Subject: [PATCH 4/8] increment version number --- jaxmarl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxmarl/__init__.py b/jaxmarl/__init__.py index de0658fe..962e720b 100644 --- a/jaxmarl/__init__.py +++ b/jaxmarl/__init__.py @@ -1,4 +1,4 @@ from .registration import make, registered_envs __all__ = ["make", "registered_envs"] -__version__ = "0.0.5" +__version__ = "0.0.6" From 84a4f3dbafa1a119814a946884082f6865c5d37c Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 10 Oct 2024 21:09:57 +0100 Subject: [PATCH 5/8] bold update --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 908332ba..1dc4d0e7 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ ## Multi-Agent Reinforcement Learning in JAX -🎉 Update: JaxMARL was accepted at NeurIPS 2024 on Datasets and Benchmarks Track. See you in Vacouver! +🎉 **Update: JaxMARL was accepted at NeurIPS 2024 on Datasets and Benchmarks Track. See you in Vancouver!** JaxMARL combines ease-of-use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is for one library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. 
From c900f936ab1c2f509ed4a078f14f444e14841b19 Mon Sep 17 00:00:00 2001 From: alex Date: Fri, 11 Oct 2024 17:25:12 +0100 Subject: [PATCH 6/8] small updates --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1dc4d0e7..aa1268c9 100644 --- a/README.md +++ b/README.md @@ -86,13 +86,15 @@ pip install jaxmarl ``` git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL ``` -2. Requirements can be installed with: +2. Install requirements: ``` pip install -e .[algs] export PYTHONPATH=./JaxMARL:$PYTHONPATH ``` +3. For the fastest start, we recommend using our Dockerfile, the usage of which is outlined below. -**Development** - If you would like to run our test suite, install the additonal dependencies with: + `pip install -e .[dev]`, after cloning the repository.

Quick Start 🚀

From 9c310f4d6e28e96106caa3cf6a5871c3d246f0b5 Mon Sep 17 00:00:00 2001 From: alex Date: Fri, 11 Oct 2024 17:28:16 +0100 Subject: [PATCH 7/8] final spaces correction --- jaxmarl/environments/mpe/simple.py | 2 +- jaxmarl/environments/mpe/simple_facmac.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxmarl/environments/mpe/simple.py b/jaxmarl/environments/mpe/simple.py index 59a32c7d..426e3659 100644 --- a/jaxmarl/environments/mpe/simple.py +++ b/jaxmarl/environments/mpe/simple.py @@ -10,7 +10,7 @@ from jaxmarl.environments.multi_agent_env import MultiAgentEnv from jaxmarl.environments.mpe.default_params import * import chex -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete from flax import struct from typing import Tuple, Optional, Dict from functools import partial diff --git a/jaxmarl/environments/mpe/simple_facmac.py b/jaxmarl/environments/mpe/simple_facmac.py index e7b18970..1c655ecf 100644 --- a/jaxmarl/environments/mpe/simple_facmac.py +++ b/jaxmarl/environments/mpe/simple_facmac.py @@ -4,7 +4,7 @@ from typing import Tuple, Dict from functools import partial from jaxmarl.environments.mpe.simple import State, SimpleMPE -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box from jaxmarl.environments.mpe.default_params import * From 1c81f4163e0ed608ac6609b6a466f05cf069b0b4 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 14 Oct 2024 12:37:49 +0100 Subject: [PATCH 8/8] add safetensors as dep for wrappers file --- pyproject.toml | 2 +- tests/hanabi/test_hanabi.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 37468a2d..a7c5a865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "jax>=0.4.16.0,<=0.4.25", "jaxlib>=0.4.16.0,<=0.4.25", "flax", + "safetensors", "chex", "brax==0.10.3", "mujoco==3.1.3", @@ -47,7 +48,6 @@ dependencies = [ algs = [ "optax", "distrax", - 
"safetensors", "flashbax==0.1.0", "wandb", "hydra-core>=1.3.2", diff --git a/tests/hanabi/test_hanabi.py b/tests/hanabi/test_hanabi.py index 4c76e1a5..b6fcfb62 100644 --- a/tests/hanabi/test_hanabi.py +++ b/tests/hanabi/test_hanabi.py @@ -4,7 +4,6 @@ import jax from jax import numpy as jnp from jaxmarl import make -from jaxmarl.wrappers.baselines import LogWrapper env = make("hanabi") dir_path = os.path.dirname(os.path.realpath(__file__))