Skip to content

Commit

Permalink
moving to nvidia dockerfile and python 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
mttga committed Mar 22, 2024
1 parent 8059e1b commit b70f597
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 72 deletions.
5 changes: 5 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
wandb/
tmp/
outputs/
results/
models/
5 changes: 2 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
# os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge]
# For Apple Silicon: https://github.com/actions/runner-images/issues/8439
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.9']
python-version: ['3.10']
defaults:
run:
shell: bash
Expand All @@ -28,8 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -e '.[dev]'
pip install -e
- name: Run pytest
run: pytest tests
50 changes: 5 additions & 45 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,62 +1,22 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

# install python
ARG DEBIAN_FRONTEND=noninteractive
ARG PYTHON_VERSION=3.10
#setting language and locale
ENV LANG="C.UTF-8" LC_ALL="C.UTF-8"


RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
software-properties-common \
build-essential \
curl \
ffmpeg \
git \
htop \
vim \
nano \
rsync \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-get update && apt-get install -y -qq python${PYTHON_VERSION} \
python${PYTHON_VERSION}-dev \
python${PYTHON_VERSION}-distutils

# Set python aliases
RUN update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python get-pip.py
FROM nvcr.io/nvidia/jax:23.10-py3

# default workdir
WORKDIR /home/workdir
COPY . .

#jaxmarl from source if needed, all the requirements
RUN pip install --ignore-installed -e '.[qlearning, dev]'

# install jax from to enable cuda
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

RUN pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118

RUN pip install -e .

#disabling preallocation
RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false
#safety measures
RUN export XLA_PYTHON_CLIENT_MEM_FRACTION=0.25
RUN export TF_FORCE_GPU_ALLOW_GROWTH=true

#for jupyter
EXPOSE 9999
# if you want jupyter
RUN pip install pip install jupyterlab

#for secrets and debug
ENV WANDB_API_KEY=""
ENV WANDB_ENTITY=""
RUN git config --global --add safe.directory /home/workdir

CMD ["/bin/bash"]
RUN git config --global --add safe.directory /home/workdir
9 changes: 0 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,6 @@ pip install jaxmarl
pip install -e .
export PYTHONPATH=./JaxMARL:$PYTHONPATH
```
3. If you would also like to run the Q-learning algorithms, Python 3.9 is required along with additional dependencies:
```
pip install -e '.[qlearning]'
```
**Test Scripts** - To run our test scripts, some additional dependencies are required (for comparisons against existing implementations), these can be installed with:
```
pip install -r requirements/requirements-dev.txt
```
<h2 name="start" id="start">Quick Start 🚀 </h2>
Expand Down
12 changes: 2 additions & 10 deletions baselines/QLearning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,15 @@ Pure JAX implementations of:

The first three are follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning)

```
⚠️ The implementations were tested with Python 3.9 and Jax 0.4.11.
With Jax 0.4.13, you could experience a degradation of performance.
```

We use [`flashbax`](https://github.com/instadeepai/flashbax) to provide our replay buffers, this requires Python 3.9 and the dependency can be installed with:
```
pip install -r requirements/requirements-qlearning.txt
```

```
❗The implementations were tested in the following environments:
- MPE
- SMAX
- Hanabi
```

WIP for Hanabi and Overcooked.

## ⚙️ Implementation Details

General features:
Expand Down
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ include = ['jaxmarl*']
[tool.setuptools.dynamic]
version = {attr = "jaxmarl.__version__"}
dependencies = {file = ["requirements/requirements.txt"]}
optional-dependencies = {dev = { file = ["requirements/requirements-dev.txt"] }, qlearning = { file = ["requirements/requirements-qlearning.txt"] }}

[project]
name = "jaxmarl"
Expand All @@ -18,14 +17,13 @@ description = "Multi-Agent Reinforcement Learning with JAX"
authors = [
{name = "Foerster Lab for AI Research", email = "[email protected]"},
]
dynamic = ["version", "dependencies", "optional-dependencies"]
dynamic = ["version", "dependencies"]
license = {file = "LICENSE"}
requires-python = ">=3.8"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Operating System :: OS Independent",
Expand Down

0 comments on commit b70f597

Please sign in to comment.