Support custom rlpd data transform (#48)
* try drq with mani env

Signed-off-by: youliang <[email protected]>

* clean up and remove redundant drq example

Signed-off-by: youliang <[email protected]>

---------

Signed-off-by: youliang <[email protected]>
youliangtan authored May 8, 2024
1 parent 6c320f6 commit a6d6527
Showing 14 changed files with 172 additions and 429 deletions.
14 changes: 10 additions & 4 deletions docs/sim_quick_start.md
@@ -90,15 +90,15 @@ bash run_actor.sh

**✨ One-liner launcher (requires `tmux`) ✨**
```bash
bash examples/async_rlpd_drq_sim/tmux_launch.sh
bash examples/async_drq_sim/tmux_rlpd_launch.sh
```

### Without using one-liner tmux launcher

You can opt for running the commands individually in 2 different terminals.

```bash
cd examples/async_rlpd_drq_sim
cd examples/async_drq_sim

# to use pre-trained ResNet weights, please download
# note: manual download is only needed for now; once the repo is public, auto-download will work
@@ -109,9 +109,9 @@ wget \
https://github.com/rail-berkeley/serl/releases/download/franka_sim_lift_cube_demos/franka_lift_cube_image_20_trajs.pkl
```

Run learner node:
Run the learner node, providing the path to the demo trajectories via the `--demo_path` argument.
```bash
bash run_learner.sh
bash run_learner.sh --demo_path franka_lift_cube_image_20_trajs.pkl
```

Run actor node with rendering window:
@@ -163,3 +163,9 @@ With the example above, we can load the data from the replay buffer by providing
```

This is similar to the `examples/async_rlpd_drq_sim/run_learner.sh` script, which uses the `--demo_path` argument to load `.pkl` offline demo trajectories.
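
For reference, below is a minimal sketch of what the learner does with the `--demo_path` file, mirroring the code added to `examples/async_drq_sim/async_drq_sim.py` in this commit (`demo_buffer` is assumed to be an already-constructed `MemoryEfficientReplayBufferDataStore`):

```python
import pickle as pkl

# Load the pickled demo trajectories and insert them into a separate demo
# replay buffer; the learner then samples half of each batch from it.
with open("franka_lift_cube_image_20_trajs.pkl", "rb") as f:
    trajs = pkl.load(f)
for traj in trajs:
    demo_buffer.insert(traj)
print(f"demo buffer size: {len(demo_buffer)}")
```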


### Troubleshooting

1. If you receive an Out of Memory error, try reducing the batch size in the `run_learner.sh` script by adding the `--batch_size` argument, e.g. `bash run_learner.sh --batch_size 64`.
2. If the provided offline RLDS data throws an error, this usually means the data is not compatible with the current SERL format. You can provide a custom data transform via the `data_transform(data, metadata) -> data` function in the `examples/async_drq_sim/async_drq_sim.py` script; a minimal sketch of such a transform is shown below.
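
A minimal sketch of such a transform, assuming the RLDS sample is a dict whose image observation key needs to be remapped to the name the environment expects (the key names `wrist_cam` and `wrist_1` are placeholders for illustration, not part of the SERL API):

```python
from typing import Any, Dict, Optional


def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]:
    # Hypothetical remap: rename a dataset-specific camera key to the key
    # used by the environment's observation space. Both key names are
    # assumptions for illustration only.
    for obs_key in ("observations", "next_observations"):
        if obs_key in data and "wrist_cam" in data[obs_key]:
            data[obs_key]["wrist_1"] = data[obs_key].pop("wrist_cam")
    # The default transform added in this commit simply returns `data` unchanged.
    return data
```

The transform is then passed to `make_replay_buffer` via the new `preload_data_transform` argument, as shown in the `examples/async_drq_sim/async_drq_sim.py` diff below.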
19 changes: 12 additions & 7 deletions examples/async_cable_route_drq/async_drq_randomized.py
@@ -9,11 +9,12 @@
from absl import app, flags
from flax.training import checkpoints

import pickle as pkl
import os
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
from serl_launcher.utils.timer_utils import Timer
from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.utils.train_utils import concat_batches
@@ -389,13 +390,17 @@ def create_replay_buffer_and_wandb_logger():
capacity=10000,
image_keys=image_keys,
)
import pickle as pkl

with open(FLAGS.demo_path, "rb") as f:
trajs = pkl.load(f)
for traj in trajs:
demo_buffer.insert(traj)
print(f"demo buffer size: {len(demo_buffer)}")
if FLAGS.demo_path:
# Check if the file exists
if not os.path.exists(FLAGS.demo_path):
raise FileNotFoundError(f"File {FLAGS.demo_path} not found")

with open(FLAGS.demo_path, "rb") as f:
trajs = pkl.load(f)
for traj in trajs:
demo_buffer.insert(traj)
print(f"demo buffer size: {len(demo_buffer)}")

# learner loop
print_green("starting learner loop")
82 changes: 79 additions & 3 deletions examples/async_drq_sim/async_drq_sim.py
@@ -8,18 +8,24 @@
import tqdm
from absl import app, flags
from flax.training import checkpoints
import cv2
import os

from typing import Any, Dict, Optional
import pickle as pkl
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
from serl_launcher.utils.timer_utils import Timer
from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.utils.train_utils import concat_batches

from agentlace.trainer import TrainerServer, TrainerClient
from agentlace.data.data_store import QueuedDataStore

from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore
from serl_launcher.utils.launcher import (
make_drq_agent,
make_trainer_config,
@@ -32,7 +38,7 @@

FLAGS = flags.FLAGS

flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.")
flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.")
flags.DEFINE_string("agent", "drq", "Name of agent.")
flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.")
flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.")
@@ -59,6 +65,7 @@
flags.DEFINE_string("ip", "localhost", "IP address of the learner.")
# "small" is a 4 layer convnet, "resnet" and "mobilenet" are frozen with pretrained weights
flags.DEFINE_string("encoder_type", "resnet-pretrained", "Encoder type.")
flags.DEFINE_string("demo_path", None, "Path to the demo data.")
flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.")
flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.")

@@ -173,7 +180,13 @@ def update_params(params):
##############################################################################


def learner(rng, agent: DrQAgent, replay_buffer, wandb_logger=None):
def learner(
rng,
agent: DrQAgent,
replay_buffer: MemoryEfficientReplayBufferDataStore,
demo_buffer: Optional[MemoryEfficientReplayBufferDataStore] = None,
wandb_logger=None,
):
"""
The learner loop, which runs when "--learner" is set to True.
"""
@@ -210,15 +223,33 @@ def stats_callback(type: str, payload: dict) -> dict:
server.publish_network(agent.state.params)
print_green("sent initial network to actor")

# 50/50 sampling from RLPD, half from demo and half from online experience if
# demo_buffer is provided
demo_iterator = None
if demo_buffer is None:
single_buffer_batch_size = FLAGS.batch_size
else:
single_buffer_batch_size = FLAGS.batch_size // 2

# create replay buffer iterator
replay_iterator = replay_buffer.get_iterator(
sample_args={
"batch_size": FLAGS.batch_size,
"batch_size": single_buffer_batch_size,
"pack_obs_and_next_obs": True,
},
device=sharding.replicate(),
)

# if demo_buffer is provided, create demo buffer iterator
if demo_buffer is not None:
demo_iterator = demo_buffer.get_iterator(
sample_args={
"batch_size": single_buffer_batch_size,
"pack_obs_and_next_obs": True,
},
device=sharding.replicate(),
)

# wait till the replay buffer is filled with enough data
timer = Timer()
for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True, desc="learner"):
@@ -228,13 +259,25 @@ def stats_callback(type: str, payload: dict) -> dict:
with timer.context("sample_replay_buffer"):
batch = next(replay_iterator)

# we will concatenate the demo data with the online data
# if demo_buffer is provided
if demo_iterator is not None:
demo_batch = next(demo_iterator)
batch = concat_batches(batch, demo_batch, axis=0)

with timer.context("train_critics"):
agent, critics_info = agent.update_critics(
batch,
)

with timer.context("train"):
batch = next(replay_iterator)

# we will concatenate the demo data with the online data
# if demo_buffer is provided
if demo_iterator is not None:
demo_batch = next(demo_iterator)
batch = concat_batches(batch, demo_batch, axis=0)
agent, update_info = agent.update_high_utd(batch, utd_ratio=1)

# publish the updated network
@@ -260,6 +303,7 @@ def stats_callback(type: str, payload: dict) -> dict:

def main(_):
assert FLAGS.batch_size % num_devices == 0

# seed
rng = jax.random.PRNGKey(FLAGS.seed)

@@ -292,6 +336,12 @@ def main(_):
jax.tree_map(jnp.array, agent), sharding.replicate()
)

def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]:
# NOTE: Create your own custom data transform function here if you
# are loading data via --preload_rlds_path with TF RLDS data.
# This default does nothing.
return data

def create_replay_buffer_and_wandb_logger():
replay_buffer = make_replay_buffer(
env,
@@ -300,6 +350,7 @@ def create_replay_buffer_and_wandb_logger():
type="memory_efficient_replay_buffer",
image_keys=image_keys,
preload_rlds_path=FLAGS.preload_rlds_path,
preload_data_transform=preload_data_transform,
)

# set up wandb and logging
@@ -314,12 +365,37 @@ def create_replay_buffer_and_wandb_logger():
sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger()

print_green("replay buffer created")
print_green(f"replay_buffer size: {len(replay_buffer)}")

# if demo data is provided, load it into the demo buffer
# in the learner node
if FLAGS.demo_path:
# Check if the file exists
if not os.path.exists(FLAGS.demo_path):
raise FileNotFoundError(f"File {FLAGS.demo_path} not found")

demo_buffer = MemoryEfficientReplayBufferDataStore(
env.observation_space,
env.action_space,
capacity=10000,
image_keys=image_keys,
)
with open(FLAGS.demo_path, "rb") as f:
trajs = pkl.load(f)
for traj in trajs:
demo_buffer.insert(traj)
print(f"demo buffer size: {len(demo_buffer)}")
else:
demo_buffer = None

# learner loop
print_green("starting learner loop")
learner(
sampling_rng,
agent,
replay_buffer,
demo_buffer=demo_buffer, # None if no demo data is provided
wandb_logger=wandb_logger,
)

1 change: 0 additions & 1 deletion examples/async_drq_sim/run_actor.sh
@@ -3,7 +3,6 @@ export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \
python async_drq_sim.py "$@" \
--actor \
--render \
--env PandaPickCubeVision-v0 \
--exp_name=serl_dev_drq_sim_test_resnet \
--seed 0 \
--random_steps 1000 \
3 changes: 1 addition & 2 deletions examples/async_drq_sim/run_learner.sh
@@ -2,13 +2,12 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \
export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \
python async_drq_sim.py "$@" \
--learner \
--env PandaPickCubeVision-v0 \
--exp_name=serl_dev_drq_sim_test_resnet \
--seed 0 \
--random_steps 1000 \
--training_starts 1000 \
--utd_ratio 4 \
--batch_size 256 \
--eval_period 2000 \
--encoder_type resnet-pretrained \
# --demo_path franka_lift_cube_image_20_trajs.pkl \
--debug # wandb is disabled when debug
@@ -1,7 +1,7 @@
#!/bin/bash

# use the default values if the env variables are not set
EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_rlpd_drq_sim"}
EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_drq_sim"}
CONDA_ENV=${CONDA_ENV:-"serl"}

cd $EXAMPLE_DIR
@@ -35,7 +35,7 @@ tmux split-window -v
tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m

# Navigate to the activate the conda environment in the second pane
tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m
tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh --demo_path franka_lift_cube_image_20_trajs.pkl" C-m

# Attach to the tmux session
tmux attach-session -t serl_session