Skip to content

Commit

Permalink
Add first draft of randomized environment
Browse files Browse the repository at this point in the history
Implement the basic feature of a wrapped environment, which chooses new
randomized physics parameters in MuJoCo on every reset().
  • Loading branch information
Chang authored and Angel Gonzalez committed May 31, 2018
1 parent 71e6bc9 commit 60ea59f
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 0 deletions.
101 changes: 101 additions & 0 deletions rllab/dynamics_randomization/RandomizeEnv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from lxml import etree
from rllab.envs import Env
import os.path as osp
import numpy as np
from rllab.dynamics_randomization import VariationMethods
from rllab.dynamics_randomization.variation import VariationDistributions
from mujoco_py import load_model_from_xml
from mujoco_py import MjSim
from rllab.core import Serializable

MODEL_DIR = osp.abspath(
osp.join(osp.dirname(__file__), '../../vendor/mujoco_models'))


class RandomizedEnv(Env, Serializable):
    """Environment wrapper that samples new randomized MuJoCo physics
    parameters on every ``reset()``.

    The wrapped environment's XML model is parsed once at construction
    time; on each reset, the attributes described by *variations* are
    re-sampled, the model is re-serialized, and a fresh ``MjSim`` is
    installed on the wrapped environment before delegating the reset.

    Parameters
    ----------
    mujoco_env : rllab MuJoCo environment to wrap. Its ``FILE`` attribute
        names the XML model file under ``MODEL_DIR``.
    variations : Variations
        Describes which XML attributes to randomize, with which
        distribution, range, and method (coefficient or absolute).
    """

    def __init__(self, mujoco_env, variations):
        Serializable.quick_init(self, locals())
        self._wrapped_env = mujoco_env
        self._variations = variations
        self._file_path = osp.join(MODEL_DIR, mujoco_env.FILE)
        self._model = etree.parse(self._file_path)

        # Resolve each variation's XML element once and cache the default
        # attribute value, so reset() can scale (COEFFICIENT) or replace
        # (ABSOLUTE) it without re-querying the tree.
        for v in variations.get_list():
            e = self._model.find(v.xpath)
            if e is None:
                # Fail fast with a clear message instead of an opaque
                # AttributeError on e.attrib below.
                raise ValueError(
                    "Could not find node in the XML model: %s" % v.xpath)
            v.elem = e

            if v.attrib not in e.attrib:
                raise ValueError("Node %s has no attribute %s" %
                                 (v.xpath, v.attrib))
            val = e.attrib[v.attrib].split(' ')
            if len(val) == 1:
                # Scalar attribute: reuse the already-split token rather
                # than re-reading and re-parsing the raw attribute string.
                v.default = float(val[0])
            else:
                # Vector-valued attribute (e.g. "1 0 0").
                v.default = np.array(list(map(float, val)))

    def reset(self):
        """Sample new physics parameters, rebuild the simulation on the
        wrapped environment, and delegate to its ``reset()``.

        Raises
        ------
        NotImplementedError
            If a variation carries an unknown distribution or method.
        """
        for v in self._variations.get_list():
            e = v.elem
            # todo: handle size (vector defaults currently get one scalar
            # sample broadcast over all components by c * v.default)
            if v.distribution == VariationDistributions.GAUSSIAN:
                c = np.random.normal(loc=v.var_range[0], scale=v.var_range[1])
            elif v.distribution == VariationDistributions.UNIFORM:
                c = np.random.uniform(low=v.var_range[0], high=v.var_range[1])
            else:
                # Without this branch, `c` would be unbound below and the
                # failure surfaced as an UnboundLocalError.
                raise NotImplementedError("Unknown distribution")
            if v.method == VariationMethods.COEFFICIENT:
                e.attrib[v.attrib] = str(c * v.default)
            elif v.method == VariationMethods.ABSOLUTE:
                e.attrib[v.attrib] = str(c)
            else:
                raise NotImplementedError("Unknown method")

        # Serialize the mutated XML tree and rebuild the wrapped env's
        # MuJoCo state from scratch, mirroring MujocoEnv's own setup.
        model_xml = etree.tostring(self._model.getroot()).decode("ascii")
        self._wrapped_env.model = load_model_from_xml(model_xml)
        self._wrapped_env.sim = MjSim(self._wrapped_env.model)
        self._wrapped_env.data = self._wrapped_env.sim.data
        self._wrapped_env.viewer = None
        self._wrapped_env.init_qpos = self._wrapped_env.sim.data.qpos
        self._wrapped_env.init_qvel = self._wrapped_env.sim.data.qvel
        self._wrapped_env.init_qacc = self._wrapped_env.sim.data.qacc
        self._wrapped_env.init_ctrl = self._wrapped_env.sim.data.ctrl
        self._wrapped_env.qpos_dim = self._wrapped_env.init_qpos.size
        self._wrapped_env.qvel_dim = self._wrapped_env.init_qvel.size
        self._wrapped_env.ctrl_dim = self._wrapped_env.init_ctrl.size
        self._wrapped_env.frame_skip = 1
        self._wrapped_env.dcom = None
        self._wrapped_env.current_com = None
        return self._wrapped_env.reset()

    def step(self, action):
        """Delegate to the wrapped environment."""
        return self._wrapped_env.step(action)

    def render(self, *args, **kwargs):
        """Delegate to the wrapped environment."""
        return self._wrapped_env.render(*args, **kwargs)

    def log_diagnostics(self, paths, *args, **kwargs):
        """Delegate to the wrapped environment."""
        self._wrapped_env.log_diagnostics(paths, *args, **kwargs)

    def terminate(self):
        """Delegate to the wrapped environment."""
        self._wrapped_env.terminate()

    def get_param_values(self):
        """Delegate to the wrapped environment."""
        return self._wrapped_env.get_param_values()

    def set_param_values(self, params):
        """Delegate to the wrapped environment."""
        self._wrapped_env.set_param_values(params)

    @property
    def wrapped_env(self):
        # The underlying (un-randomized) environment.
        return self._wrapped_env

    @property
    def action_space(self):
        return self._wrapped_env.action_space

    @property
    def observation_space(self):
        return self._wrapped_env.observation_space

    @property
    def horizon(self):
        return self._wrapped_env.horizon


# Lowercase alias mirroring rllab's `normalize`-style wrapper naming.
randomize = RandomizedEnv
1 change: 1 addition & 0 deletions rllab/dynamics_randomization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from rllab.dynamics_randomization.variation import Variations
from rllab.dynamics_randomization.variation import VariationMethods
from rllab.dynamics_randomization.variation import VariationDistributions
from rllab.dynamics_randomization.RandomizeEnv import RandomizedEnv
39 changes: 39 additions & 0 deletions rllab/dynamics_randomization/trpo_swimmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from rllab.algos import TRPO
from rllab.baselines import LinearFeatureBaseline
from rllab.envs.mujoco import SwimmerEnv
from rllab.envs import normalize
from rllab.policies import GaussianMLPPolicy
from rllab.dynamics_randomization import RandomizedEnv
from rllab.dynamics_randomization import Variations
from rllab.dynamics_randomization import VariationMethods
from rllab.dynamics_randomization import VariationDistributions

# Declare a single variation: on every reset, scale the torso geom's
# density by a coefficient drawn uniformly from [0.5, 1.5].
variations = Variations()
(variations.randomize()
 .at_xpath(".//geom[@name='torso']")
 .attribute("density")
 .with_method(VariationMethods.COEFFICIENT)
 .sampled_from(VariationDistributions.UNIFORM)
 .with_range(0.5, 1.5))

env = normalize(RandomizedEnv(SwimmerEnv(), variations))

# The neural network policy should have two hidden layers, each with 32 hidden units.
policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
    env=env,
    policy=policy,
    baseline=baseline,
    batch_size=4000,
    max_path_length=500,
    n_itr=40,
    discount=0.99,
    step_size=0.01,
    # plot=True
)
algo.train()

0 comments on commit 60ea59f

Please sign in to comment.