Goal-Conditioned Reinforcement Learning (Jax/Flax/Optax)

This repository contains a collection of goal-conditioned reinforcement learning algorithms. It is compatible with the latest Gymnasium API and uses recent versions of jax, flax and optax. Multiprocessing is supported via mpi4jax, in the style of the deprecated OpenAI baselines.

Supported Algorithms

All algorithms make use of Hindsight Experience Replay (HER).
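For readers new to hindsight relabeling, here is a minimal sketch of the "future" strategy from the HER paper. It is an illustration under assumptions, not this repository's implementation: the transition layout, the `sparse_reward` function, its tolerance, and the `her_relabel` helper are all hypothetical names chosen for the example.

```python
import random


def sparse_reward(achieved_goal, goal, tol=0.05):
    """Hypothetical sparse reward: 0.0 when the goal is reached, -1.0 otherwise."""
    dist = sum((a - g) ** 2 for a, g in zip(achieved_goal, goal)) ** 0.5
    return 0.0 if dist <= tol else -1.0


def her_relabel(episode, k=4, rng=None):
    """Return every original transition plus k relabeled copies of it.

    Each copy replaces the desired goal with a goal that was actually
    achieved at the same or a later step of the same episode ("future"
    strategy), and recomputes the reward for that substituted goal.
    """
    rng = rng or random.Random(0)
    out = []
    for t, transition in enumerate(episode):
        out.append(dict(transition))
        for _ in range(k):
            future = rng.randint(t, len(episode) - 1)
            new_goal = episode[future]["achieved_goal_next"]
            relabeled = dict(transition)
            relabeled["goal"] = new_goal
            relabeled["reward"] = sparse_reward(
                transition["achieved_goal_next"], new_goal
            )
            out.append(relabeled)
    return out


# A toy episode that never reaches its desired goal (1.0, 1.0).
goal = (1.0, 1.0)
pos = (0.0, 0.0)
episode = []
for _ in range(5):
    nxt = (pos[0] + 0.1, pos[1])
    episode.append({
        "obs": pos,
        "action": (0.1, 0.0),
        "achieved_goal_next": nxt,
        "goal": goal,
        "reward": sparse_reward(nxt, goal),
    })
    pos = nxt

batch = her_relabel(episode, k=4)
```

Every original transition in the toy episode carries the failure reward, while some relabeled copies carry the success reward. That is the point of the technique: a failed episode still yields informative learning signal for the goals it did reach.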

Installation

  • git clone https://github.com/frankroeder/goal_conditioned_rl.git
  • pip users: pip install -r requirements.txt
  • conda users: conda env create --file conda_env.yaml
  • libraries: apt install libopenmpi-dev

Jax CUDA Support

To install on a machine with an Nvidia GPU, run the commands below. See https://github.com/google/jax#installation for details.

# install packages
pip install -r requirements.txt
# remove jaxlib and install the CUDA version if necessary
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Training

Single process

# SAC
python train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her agent.critic.dropout=0.0
# DDPG
python train.py n_epochs=10 agent=ddpg env_name=FetchPush-v2 hindsight=her
# DroQ
python train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her agent.critic.dropout=0.01
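The `key=value` arguments above resemble Hydra-style overrides, where `agent=sac` selects a config group and `agent.critic.dropout=0.01` sets a nested field inside it. The README does not name the config library, so treat that as an assumption. The sketch below only illustrates how such overrides could compose into a nested config; `GROUPS`, `parse_value`, and `compose` are hypothetical names and the group contents are made up for the example.

```python
import copy

# Hypothetical config groups; the real defaults live in the repository.
GROUPS = {
    "agent": {
        "sac": {"name": "sac", "critic": {"dropout": 0.0}},
        "ddpg": {"name": "ddpg", "critic": {"dropout": 0.0}},
    },
}


def parse_value(raw):
    """Convert a string to int or float when possible, else keep the string."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def compose(overrides):
    """Build a nested config dict from 'dotted.key=value' strings."""
    pairs = [item.partition("=")[::2] for item in overrides]
    config = {}
    # First pass: group selections such as agent=sac load a whole sub-config.
    for key, raw in pairs:
        if key in GROUPS:
            config[key] = copy.deepcopy(GROUPS[key][raw])
    # Second pass: remaining keys set individual (possibly nested) values.
    for key, raw in pairs:
        if key in GROUPS:
            continue
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw)
    return config


droq_cfg = compose([
    "n_epochs=10",
    "agent=sac",
    "env_name=FetchPush-v2",
    "hindsight=her",
    "agent.critic.dropout=0.01",
])
```

Under this reading, the SAC and DroQ commands differ only in the critic dropout value: both select the `sac` agent, and DroQ sets a non-zero `agent.critic.dropout`.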

Multiple processes

mpirun -np 4 python -u train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her

Enjoy your trained agent

python demo.py --demo_path <path to the trial folder>
# or
python demo.py --wandb_url <wandb trial url>

Results

... more results will follow

References