diff --git a/Dockerfile b/Dockerfile
index b3299296..c4cb6440 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -17,7 +17,7 @@ RUN apt-get update && \
     apt-get install -y tmux
 
 #jaxmarl from source if needed, all the requirements
-RUN pip install -e .
+RUN pip install -e .[algs,dev]
 
 USER ${MYUSER}
 
diff --git a/README.md b/README.md
index 22f14690..908332ba 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,8 @@
 ## Multi-Agent Reinforcement Learning in JAX
 
+🎉 Update: JaxMARL was accepted at NeurIPS 2024 in the Datasets and Benchmarks Track. See you in Vancouver!
+
 JaxMARL combines ease-of-use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is for one library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine.
 
 For more details, take a look at our [blog post](https://blog.foersterlab.com/jaxmarl/) or our [Colab notebook](https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb), which walks through the basic usage.
@@ -72,7 +74,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
 
 Installation 🧗
 
-**Environments** - Before installing, ensure you have the correct [JAX version](https://github.com/google/jax#installation) for your hardware accelerator. The JaxMARL environments can be installed directly from PyPi:
+**Environments** - Before installing, ensure you have the correct [JAX installation](https://github.com/google/jax#installation) for your hardware accelerator. We have tested up to JAX version 0.4.25. The JaxMARL environments can be installed directly from PyPI:
 
 ```
 pip install jaxmarl
 ```
@@ -84,12 +86,14 @@ pip install jaxmarl
 ```
 git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
 ```
-2. The requirements for IPPO & MAPPO can be installed with:
+2. Requirements can be installed with:
 ```
-pip install -e .
+pip install -e .[algs]
 export PYTHONPATH=./JaxMARL:$PYTHONPATH
 ```
+**Development** - If you would like to run our test suite, install the additional dependencies with `pip install -e .[dev]`.
+
 
 Quick Start 🚀
 
 We take inspiration from the [PettingZoo](https://github.com/Farama-Foundation/PettingZoo) and [Gymnax](https://github.com/RobertTLange/gymnax) interfaces. You can try out training an agent in our [Colab notebook](https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb). Further introduction scripts can be found [here](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/tutorials).
@@ -151,6 +155,7 @@ JAX-native algorithms:
 - [Mava](https://github.com/instadeepai/Mava): JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
 - [PureJaxRL](https://github.com/luchris429/purejaxrl): JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
 - [Minimax](https://github.com/facebookresearch/minimax/): JAX implementations of autocurricula baselines for RL.
+- [JaxIRL](https://github.com/FLAIROx/jaxirl?tab=readme-ov-file): JAX implementation of algorithms for inverse reinforcement learning.
 
 JAX-native environments:
 - [Gymnax](https://github.com/RobertTLange/gymnax): Implementations of classic RL tasks including classic control, bsuite and MinAtar.
@@ -158,3 +163,4 @@ JAX-native environments:
 - [Pgx](https://github.com/sotetsuk/pgx): JAX implementations of classic board games, such as Chess, Go and Shogi.
 - [Brax](https://github.com/google/brax): A fully differentiable physics engine written in JAX, features continuous control tasks.
 - [XLand-MiniGrid](https://github.com/corl-team/xland-minigrid): Meta-RL gridworld environments inspired by XLand and MiniGrid.
+- [Craftax](https://github.com/MichaelTMatthews/Craftax): (Crafter + NetHack) in JAX.
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 34c9ba4d..37468a2d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-alg = [
+algs = [
     "optax",
     "distrax",
     "safetensors",
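For reviewers who want to exercise the interface that the Quick Start context above points to, here is a minimal usage sketch of the PettingZoo/Gymnax-style API. It follows the pattern in the walkthrough notebook; the registered environment name `MPE_simple_world_comm_v3` and the exact return signature are assumptions taken from the JaxMARL docs, not part of this diff.

```python
import jax
from jaxmarl import make

# Keep every source of randomness explicit, as JAX requires: split one PRNG
# key into separate keys for reset, per-agent action sampling, and stepping.
key = jax.random.PRNGKey(0)
key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Instantiate a registered environment (name assumed from the MPE suite).
env = make("MPE_simple_world_comm_v3")

# reset returns a dict of per-agent observations plus the environment state.
obs, state = env.reset(key_reset)

# Sample one random action per agent from that agent's own action space.
key_act = jax.random.split(key_act, env.num_agents)
actions = {
    agent: env.action_space(agent).sample(key_act[i])
    for i, agent in enumerate(env.agents)
}

# step takes (key, state, actions); rewards, dones, and infos come back as
# dicts keyed by agent name, mirroring the observation dict.
obs, state, reward, done, infos = env.step(key_step, state, actions)
```

Because `reset` and `step` are pure functions of their inputs, rollouts written this way can be wrapped in `jax.jit` or batched with `jax.vmap`, which is where the GPU efficiency claimed in the README intro comes from.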