forked from rll/rllab
-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add fisrt draft of randomize environment
Implement basic feature of a wrappered environment, which choose new randomized physics params in mujoco on every reset().
- Loading branch information
Chang
authored and
Angel Gonzalez
committed
May 31, 2018
1 parent
71e6bc9
commit 60ea59f
Showing
3 changed files
with
141 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from lxml import etree | ||
from rllab.envs import Env | ||
import os.path as osp | ||
import numpy as np | ||
from rllab.dynamics_randomization import VariationMethods | ||
from rllab.dynamics_randomization.variation import VariationDistributions | ||
from mujoco_py import load_model_from_xml | ||
from mujoco_py import MjSim | ||
from rllab.core import Serializable | ||
|
||
MODEL_DIR = osp.abspath( | ||
osp.join(osp.dirname(__file__), '../../vendor/mujoco_models')) | ||
|
||
|
||
class RandomizedEnv(Env, Serializable): | ||
def __init__(self, mujoco_env, variations): | ||
Serializable.quick_init(self, locals()) | ||
self._wrapped_env = mujoco_env | ||
self._variations = variations | ||
self._file_path = osp.join(MODEL_DIR, mujoco_env.FILE) | ||
self._model = etree.parse(self._file_path) | ||
|
||
for v in variations.get_list(): | ||
e = self._model.find(v.xpath) | ||
v.elem = e | ||
|
||
# todo: handle AttributeError | ||
val = e.attrib[v.attrib].split(' ') | ||
if len(val) == 1: | ||
v.default = float(e.attrib[v.attrib]) | ||
else: | ||
v.default = np.array(list(map(float, val))) | ||
|
||
def reset(self): | ||
for v in self._variations.get_list(): | ||
e = v.elem | ||
# todo: handle size | ||
if v.distribution == VariationDistributions.GAUSSIAN: | ||
c = np.random.normal(loc=v.var_range[0], scale=v.var_range[1]) | ||
elif v.distribution == VariationDistributions.UNIFORM: | ||
c = np.random.uniform(low=v.var_range[0], high=v.var_range[1]) | ||
if v.method == VariationMethods.COEFFICIENT: | ||
e.attrib[v.attrib] = str(c * v.default) | ||
elif v.method == VariationMethods.ABSOLUTE: | ||
e.attrib[v.attrib] = str(c) | ||
else: | ||
raise NotImplementedError("Unknown method") | ||
|
||
model_xml = etree.tostring(self._model.getroot()).decode("ascii") | ||
self._wrapped_env.model = load_model_from_xml(model_xml) | ||
self._wrapped_env.sim = MjSim(self._wrapped_env.model) | ||
self._wrapped_env.data = self._wrapped_env.sim.data | ||
self._wrapped_env.viewer = None | ||
self._wrapped_env.init_qpos = self._wrapped_env.sim.data.qpos | ||
self._wrapped_env.init_qvel = self._wrapped_env.sim.data.qvel | ||
self._wrapped_env.init_qacc = self._wrapped_env.sim.data.qacc | ||
self._wrapped_env.init_ctrl = self._wrapped_env.sim.data.ctrl | ||
self._wrapped_env.qpos_dim = self._wrapped_env.init_qpos.size | ||
self._wrapped_env.qvel_dim = self._wrapped_env.init_qvel.size | ||
self._wrapped_env.ctrl_dim = self._wrapped_env.init_ctrl.size | ||
self._wrapped_env.frame_skip = 1 | ||
self._wrapped_env.dcom = None | ||
self._wrapped_env.current_com = None | ||
return self._wrapped_env.reset() | ||
|
||
def step(self, action): | ||
return self._wrapped_env.step(action) | ||
|
||
def render(self, *args, **kwargs): | ||
return self._wrapped_env.render(*args, **kwargs) | ||
|
||
def log_diagnostics(self, paths, *args, **kwargs): | ||
self._wrapped_env.log_diagnostics(paths, *args, **kwargs) | ||
|
||
def terminate(self): | ||
self._wrapped_env.terminate() | ||
|
||
def get_param_values(self): | ||
return self._wrapped_env.get_param_values() | ||
|
||
def set_param_values(self, params): | ||
self._wrapped_env.set_param_values(params) | ||
|
||
@property | ||
def wrapped_env(self): | ||
return self._wrapped_env | ||
|
||
@property | ||
def action_space(self): | ||
return self._wrapped_env.action_space | ||
|
||
@property | ||
def observation_space(self): | ||
return self._wrapped_env.observation_space | ||
|
||
@property | ||
def horizon(self): | ||
return self._wrapped_env.horizon | ||
|
||
|
||
randomize = RandomizedEnv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from rllab.dynamics_randomization.variation import Variations | ||
from rllab.dynamics_randomization.variation import VariationMethods | ||
from rllab.dynamics_randomization.variation import VariationDistributions | ||
from rllab.dynamics_randomization.RandomizeEnv import RandomizedEnv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from rllab.algos import TRPO | ||
from rllab.baselines import LinearFeatureBaseline | ||
from rllab.envs.mujoco import SwimmerEnv | ||
from rllab.envs import normalize | ||
from rllab.policies import GaussianMLPPolicy | ||
from rllab.dynamics_randomization import RandomizedEnv | ||
from rllab.dynamics_randomization import Variations | ||
from rllab.dynamics_randomization import VariationMethods | ||
from rllab.dynamics_randomization import VariationDistributions | ||
|
||
variations = Variations() | ||
variations.randomize().\ | ||
at_xpath(".//geom[@name='torso']").\ | ||
attribute("density").\ | ||
with_method(VariationMethods.COEFFICIENT).\ | ||
sampled_from(VariationDistributions.UNIFORM).\ | ||
with_range(0.5, 1.5) | ||
|
||
env = normalize(RandomizedEnv(SwimmerEnv(), variations)) | ||
|
||
policy = GaussianMLPPolicy( | ||
env_spec=env.spec, | ||
# The neural network policy should have two hidden layers, each with 32 hidden units. | ||
hidden_sizes=(32, 32)) | ||
|
||
baseline = LinearFeatureBaseline(env_spec=env.spec) | ||
|
||
algo = TRPO( | ||
env=env, | ||
policy=policy, | ||
baseline=baseline, | ||
batch_size=4000, | ||
max_path_length=500, | ||
n_itr=40, | ||
discount=0.99, | ||
step_size=0.01, | ||
# plot=True | ||
) | ||
algo.train() |