Refactor serl pipeline (#52)
* update example pipelines for reward classifier

Signed-off-by: youliangtan <[email protected]>

* reward classifier fix checkpoint save load

Signed-off-by: youliang <[email protected]>

* clean with nit

Signed-off-by: youliangtan <[email protected]>

* prevent fast network update from learner, configurable discount, update agentlace version

Signed-off-by: youliangtan <[email protected]>

---------

Signed-off-by: youliangtan <[email protected]>
Signed-off-by: youliang <[email protected]>
youliangtan authored Jun 6, 2024
1 parent cdb787d commit 7d000b2
Showing 18 changed files with 198 additions and 211 deletions.
23 changes: 19 additions & 4 deletions docs/real_franka.md
@@ -115,8 +115,12 @@ bash run_bc.sh
In this cable routing task, we provide an example of an image-based reward classifier. It replaces the hardcoded reward classifier, which depends on the known `TARGET_POSE` defined in `config.py`. The image-based reward classifier is a pretrained ResNet10 that is then trained to classify whether the cable has been routed successfully. The classifier is trained on demo trajectories containing both successful and failed samples.
```bash
# NOTE: custom paths are used in this script
python train_reward_classifier.py
# NOTE: populate the custom paths to train a reward classifier
python train_reward_classifier.py \
--classifier_ckpt_path CHECKPOINT_OUTPUT_DIR \
--positive_demo_paths PATH_TO_POSITIVE_DEMO1.pkl \
--positive_demo_paths PATH_TO_POSITIVE_DEMO2.pkl \
    --negative_demo_paths PATH_TO_NEGATIVE_DEMO1.pkl
```
The reward classifier is used via a gym wrapper, `franka_env.envs.wrappers.BinaryRewardClassifierWrapper`. The wrapper classifies the current observation and returns a reward of 1 if the observation is classified as successful and 0 otherwise.
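For reference, a minimal sketch of that wiring; the checkpoint path is a placeholder and the import locations are assumptions based on the snippets in this diff:
```python
import jax

from franka_env.envs.wrappers import BinaryRewardClassifierWrapper
from serl_launcher.networks.reward_classifier import load_classifier_func

# `env` is assumed to be the cable routing env, already wrapped with
# SERLObsWrapper / ChunkingWrapper as in the example scripts
image_keys = [k for k in env.observation_space.keys() if "state" not in k]

rng = jax.random.PRNGKey(0)
rng, key = jax.random.split(rng)
classifier_func = load_classifier_func(
    key=key,
    sample=env.observation_space.sample(),
    image_keys=image_keys,
    checkpoint_path="CHECKPOINT_OUTPUT_DIR",  # placeholder: trained classifier ckpt
)

# reward becomes 1 when the classifier marks the observation as successful, else 0
env = BinaryRewardClassifierWrapper(env, classifier_func)
```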
@@ -138,11 +142,22 @@ This bin relocation example demonstrates the usage of forward and backward policies
1. Record demo trajectories
Multiple utility scripts have been provided to record demo trajectories. (e.g. `record_demo.py`: for RLPD, `record_transitions.py`: for reward classifier, `reward_bc_demos.py`: for bc policy). Note that both forward and backward trajectories require different demo trajectories.
Multiple utility scripts are provided to record demo trajectories (e.g. `record_demo.py` for RLPD, `record_transitions.py` for training the reward classifier, `reward_bc_demos.py` for the BC policy). Note that the forward and backward policies require separate demo trajectories.
2. Reward Classifier
Similar to the cable routing example, we need to train two reward classifiers for both forward and backward policies, shown in `train_fwd_reward_classifier.sh` and `train_bwd_reward_classifier.sh`. The reward classifiers are then used in the BC and DRQ policy for the actor node, checkpoint path is provided as `--reward_classifier_ckpt_path` argument in `run_bc.sh` and `run_actor.sh`.
Similar to the cable routing example, we need to train two reward classifiers, one for the forward policy and one for the backward policy. Since the observations contain both wrist camera and front camera images, we use `FrontCameraWrapper(env)` to provide only the front camera image to the reward classifier.
```bash
# NOTE: populate the custom paths to train reward classifiers for both forward and backward policies
python train_reward_classifier.py \
--classifier_ckpt_path CHECKPOINT_OUTPUT_DIR \
--positive_demo_paths PATH_TO_POSITIVE_DEMO1.pkl \
--positive_demo_paths PATH_TO_POSITIVE_DEMO2.pkl \
    --negative_demo_paths PATH_TO_NEGATIVE_DEMO1.pkl
```
The reward classifiers are then used by the BC and DRQ policies for the actor node; the checkpoint paths are provided via the `--fw_reward_classifier_ckpt_path` and `--bw_reward_classifier_ckpt_path` arguments in `run_actor.sh`, as sketched below. To compare against the BC baseline, provide the classifier via `--reward_classifier_ckpt_path` for the `run_bc.sh` script.
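For illustration, a minimal sketch of loading both classifiers on the actor side, assuming the same `load_classifier_func` pattern as the cable routing example; the checkpoint paths are placeholders, and the exact wiring lives in the example's actor script behind the flags above:
```python
import jax

from serl_launcher.networks.reward_classifier import load_classifier_func

# `env` is assumed to be the bin relocation env wrapped with FrontCameraWrapper
image_keys = [k for k in env.front_observation_space.keys() if "state" not in k]

rng = jax.random.PRNGKey(0)
rng, fw_key, bw_key = jax.random.split(rng, 3)

# one classifier per policy; paths correspond to the two command-line flags
fw_classifier_func = load_classifier_func(
    key=fw_key,
    sample=env.front_observation_space.sample(),
    image_keys=image_keys,
    checkpoint_path="FW_CHECKPOINT_OUTPUT_DIR",  # --fw_reward_classifier_ckpt_path
)
bw_classifier_func = load_classifier_func(
    key=bw_key,
    sample=env.front_observation_space.sample(),
    image_keys=image_keys,
    checkpoint_path="BW_CHECKPOINT_OUTPUT_DIR",  # --bw_reward_classifier_ckpt_path
)
# the actor scores each observation with the classifier of the currently
# active task (forward or backward) to produce the binary reward
```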
3. Run 2 learners and 1 actor with 2 policies
@@ -56,7 +56,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
@@ -142,7 +142,6 @@ def check_all_done():
obs = next_obs

if done:
print(rew)
env.set_task_id(env.task_graph())
print(f"current task id: {env.task_id}")
obs, _ = env.reset()
9 changes: 0 additions & 9 deletions examples/async_bin_relocation_fwbw_drq/train_bw_classifier.sh

This file was deleted.

9 changes: 0 additions & 9 deletions examples/async_bin_relocation_fwbw_drq/train_fw_classifier.sh

This file was deleted.

73 changes: 45 additions & 28 deletions examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py
@@ -1,18 +1,15 @@
import pickle as pkl
import jax
from jax import numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
from flax.core import frozen_dict
from flax.training import checkpoints
import optax
from tqdm import tqdm
import gym
import os
from absl import app, flags

from serl_launcher.vision.resnet_v1 import resnetv1_configs, PreTrainedResNetEncoder
from serl_launcher.common.encoding import EncodingWrapper
from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.utils.train_utils import concat_batches
from serl_launcher.vision.data_augmentations import batched_random_crop
@@ -25,62 +22,76 @@
)
from serl_launcher.networks.reward_classifier import create_classifier

from franka_env.envs.wrappers import (
Quat2EulerWrapper,
)
import franka_env
from franka_env.envs.wrappers import Quat2EulerWrapper
from franka_env.envs.relative_env import RelativeFrame

# Set these env vars to prevent OOM errors from XLA memory preallocation
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".2"

FLAGS = flags.FLAGS
flags.DEFINE_multi_string("positive_demo_paths", None, "paths to positive demos")
flags.DEFINE_multi_string("negative_demo_paths", None, "paths to negative demos")
flags.DEFINE_string("classifier_ckpt_path", None, "Path to classifier checkpoint")
flags.DEFINE_string("classifier_ckpt_path", ".", "Path to classifier checkpoint")
flags.DEFINE_integer("batch_size", 256, "Batch size for training")
flags.DEFINE_integer("num_epochs", 100, "Number of epochs for training")


def main(_):
num_epochs = 100
batch_size = 256

devices = jax.local_devices()
num_devices = len(devices)
sharding = jax.sharding.PositionalSharding(devices)

env = gym.make("FrankaBinRelocation-Vision-v0", save_video=False)
env = RelativeFrame(env)
env = Quat2EulerWrapper(env)
env = SERLObsWrapper(env)
env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None)
env = FrontCameraWrapper(env)
image_keys = [k for k in env.front_observation_space.keys() if "state" not in k]

# we will only use the front camera view for training the reward classifier
train_reward_classifier(env.front_observation_space, env.action_space)


def train_reward_classifier(observation_space, action_space):
"""
The user can provide a custom observation space to be used as the
input to the classifier. This function trains a reward
classifier using the provided positive and negative demonstrations.
NOTE: this function is duplicated and used in both
async_bin_relocation_fwbw_drq and async_cable_route_drq examples
"""
devices = jax.local_devices()
sharding = jax.sharding.PositionalSharding(devices)

image_keys = [k for k in observation_space.keys() if "state" not in k]

pos_buffer = MemoryEfficientReplayBufferDataStore(
env.front_observation_space,
env.action_space,
capacity=2000,
observation_space,
action_space,
capacity=10000,
image_keys=image_keys,
)
pos_buffer = populate_data_store(pos_buffer, FLAGS.positive_demo_paths)

neg_buffer = MemoryEfficientReplayBufferDataStore(
env.front_observation_space,
env.action_space,
capacity=5000,
observation_space,
action_space,
capacity=10000,
image_keys=image_keys,
)
neg_buffer = populate_data_store(neg_buffer, FLAGS.negative_demo_paths)

print(f"failed buffer size: {len(neg_buffer)}")
print(f"success buffer size: {len(pos_buffer)}")

pos_iterator = pos_buffer.get_iterator(
sample_args={
"batch_size": batch_size // 2,
"batch_size": FLAGS.batch_size // 2,
"pack_obs_and_next_obs": False,
},
device=sharding.replicate(),
)
neg_iterator = neg_buffer.get_iterator(
sample_args={
"batch_size": batch_size // 2,
"batch_size": FLAGS.batch_size // 2,
"pack_obs_and_next_obs": False,
},
device=sharding.replicate(),
@@ -125,7 +136,7 @@ def loss_fn(params):
return state.apply_gradients(grads=grads), loss, train_accuracy

# Training Loop
for epoch in tqdm(range(num_epochs)):
for epoch in tqdm(range(FLAGS.num_epochs)):
# Sample equal number of positive and negative examples
pos_sample = next(pos_iterator)
neg_sample = next(neg_iterator)
@@ -136,7 +147,11 @@ def loss_fn(params):
rng, key = jax.random.split(rng)
sample = data_augmentation_fn(key, sample)
labels = jnp.concatenate(
[jnp.ones((batch_size // 2, 1)), jnp.zeros((batch_size // 2, 1))], axis=0
[
jnp.ones((FLAGS.batch_size // 2, 1)),
jnp.zeros((FLAGS.batch_size // 2, 1)),
],
axis=0,
)
batch = {"data": sample, "labels": labels}

@@ -147,10 +162,12 @@ def loss_fn(params):
f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}"
)

# save the checkpoint without using orbax checkpointing
flax.config.update("flax_use_orbax_checkpointing", False)
checkpoints.save_checkpoint(
FLAGS.classifier_ckpt_path,
classifier,
step=num_epochs,
step=FLAGS.num_epochs,
overwrite=True,
)

4 changes: 2 additions & 2 deletions examples/async_cable_route_drq/async_drq_randomized.py
@@ -29,6 +29,7 @@
make_wandb_logger,
)
from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper
from serl_launcher.networks.reward_classifier import load_classifier_func
from franka_env.envs.relative_env import RelativeFrame
from franka_env.envs.wrappers import (
GripperCloseEnv,
@@ -55,7 +56,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
@@ -337,7 +338,6 @@ def main(_):
image_keys = [key for key in env.observation_space.keys() if key != "state"]
if FLAGS.actor:
# initialize the classifier and wrap the env
from serl_launcher.networks.reward_classifier import load_classifier_func

if FLAGS.reward_classifier_ckpt_path is None:
raise ValueError("reward_classifier_ckpt_path must be specified for actor")
7 changes: 2 additions & 5 deletions examples/async_cable_route_drq/record_demo.py
@@ -17,8 +17,9 @@
)

from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper

from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.networks.reward_classifier import load_classifier_func
import jax

if __name__ == "__main__":
env = gym.make("FrankaCableRoute-Vision-v0", save_video=False)
@@ -30,9 +31,6 @@
env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None)
image_keys = [k for k in env.observation_space.keys() if "state" not in k]

from serl_launcher.networks.reward_classifier import load_classifier_func
import jax

rng = jax.random.PRNGKey(0)
rng, key = jax.random.split(rng)
classifier_func = load_classifier_func(
@@ -80,7 +78,6 @@
)
)
transitions.append(transition)

obs = next_obs

if done:
29 changes: 19 additions & 10 deletions examples/async_cable_route_drq/test_classifier.py
@@ -1,8 +1,6 @@
import gym
from tqdm import tqdm
import numpy as np
import copy

import franka_env

from franka_env.envs.relative_env import RelativeFrame
@@ -14,10 +12,20 @@
)

from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper

from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.networks.reward_classifier import load_classifier_func

if __name__ == "__main__":
import jax
from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_string(
"reward_classifier_ckpt_path", None, "Path to reward classifier ckpt."
)


def main(_):
env = gym.make("FrankaCableRoute-Vision-v0", save_video=False)
env = GripperCloseEnv(env)
env = SpacemouseIntervention(env)
@@ -27,29 +35,30 @@
env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None)
image_keys = [k for k in env.observation_space.keys() if "state" not in k]

from serl_launcher.networks.reward_classifier import load_classifier_func
import jax

rng = jax.random.PRNGKey(0)
rng, key = jax.random.split(rng)
classifier_func = load_classifier_func(
key=key,
sample=env.observation_space.sample(),
image_keys=image_keys,
checkpoint_path="/home/undergrad/code/serl_dev/examples/async_cable_route_drq/classifier_ckpt/",
checkpoint_path=FLAGS.reward_classifier_ckpt_path,
)
env = BinaryRewardClassifierWrapper(env, classifier_func)

obs, _ = env.reset()

for i in tqdm(range(1000)):
actions = np.zeros((6,))
next_obs, rew, done, truncated, info = env.step(action=actions)

if "intervene_action" in info:
actions = info["intervene_action"]

obs = next_obs

if done:
print(rew)
print("Reward: ", rew)
obs, _ = env.reset()


if __name__ == "__main__":
app.run(main)