
Commit

Merge pull request #586 from deepsense-ai/eval_implemantation
Further improvements to rl
lukaszkaiser authored Feb 15, 2018
2 parents 0878242 + 08b15e1 commit 4624641
Showing 7 changed files with 209 additions and 82 deletions.
30 changes: 28 additions & 2 deletions tensor2tensor/models/research/rl.py
@@ -46,18 +46,21 @@ def ppo_base_v1():
hparams.add_hparam("optimization_epochs", 15)
hparams.add_hparam("epoch_length", 200)
hparams.add_hparam("epochs_num", 2000)
hparams.add_hparam("eval_every_epochs", 10)
hparams.add_hparam("num_eval_agents", 3)
hparams.add_hparam("video_during_eval", True)
return hparams


@registry.register_hparams
def pendulum_base():
def continuous_action_base():
hparams = ppo_base_v1()
hparams.add_hparam("network", feed_forward_gaussian_fun)
return hparams


@registry.register_hparams
def cartpole_base():
def discrete_action_base():
hparams = ppo_base_v1()
hparams.add_hparam("network", feed_forward_categorical_fun)
return hparams
@@ -129,3 +132,26 @@ def feed_forward_categorical_fun(action_space, config, observations):
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
policy = tf.contrib.distributions.Categorical(logits=logits)
return NetworkOutput(policy, value, lambda a: a)


def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
"""Small cnn network with categorical output."""
obs_shape = observations.shape.as_list()
x = tf.reshape(observations, [-1] + obs_shape[2:])

with tf.variable_scope('policy'):
x = tf.to_float(x) / 255.0
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2], activation_fn=tf.nn.relu, padding="SAME")
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2], activation_fn=tf.nn.relu, padding="SAME")

flat_x = tf.reshape(x, [
tf.shape(observations)[0], tf.shape(observations)[1],
functools.reduce(operator.mul, x.shape.as_list()[1:], 1)])

x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)
logits = tf.contrib.layers.fully_connected(x, action_space.n, activation_fn=None)

value = tf.contrib.layers.fully_connected(x, 1, activation_fn=None)[..., 0]
policy = tf.contrib.distributions.Categorical(logits=logits)

return NetworkOutput(policy, value, lambda a: a)
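
As an aside (not part of this commit), a minimal sketch of how the eval hparams introduced above might be tuned in a derived hparams set; the set name below is hypothetical and assumes the imports and definitions from rl.py:

```python
# Hypothetical example only: a derived hparams set overriding the new eval
# settings (eval_every_epochs, num_eval_agents, video_during_eval).
@registry.register_hparams
def discrete_action_base_quick_eval():
  hparams = discrete_action_base()
  hparams.eval_every_epochs = 5      # evaluate more often than the default 10
  hparams.num_eval_agents = 1        # collect eval stats from a single agent
  hparams.video_during_eval = False  # skip Xvfb/Monitor video recording
  return hparams
```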
2 changes: 1 addition & 1 deletion tensor2tensor/rl/README.md
@@ -9,4 +9,4 @@ Currently the only supported algorithm is Proximal Policy Optimization - PPO.

## Sample usage - training in Pendulum-v0 environment.

```python rl/t2t_rl_trainer.py --problems=Pendulum-v0 --hparams_set pendulum_base [--output_dir dir_location]```
```python rl/t2t_rl_trainer.py --problems=Pendulum-v0 --hparams_set continuous_action_base [--output_dir dir_location]```
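
As an illustrative aside (not part of the diff), the analogous call for a discrete-action environment would look roughly like the following, assuming the Gym environment name is passed the same way:

```python rl/t2t_rl_trainer.py --problems=CartPole-v0 --hparams_set discrete_action_base [--output_dir dir_location]```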
37 changes: 26 additions & 11 deletions tensor2tensor/rl/collect.py
@@ -18,12 +18,13 @@
import tensorflow as tf


def define_collect(policy_factory, batch_env, hparams):
def define_collect(policy_factory, batch_env, hparams, eval_phase):
"""Collect trajectories."""
eval_phase = tf.convert_to_tensor(eval_phase)
memory_shape = [hparams.epoch_length] + [batch_env.observ.shape.as_list()[0]]
memories_shapes_and_types = [
# observation
(memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
(memory_shape + batch_env.observ.shape.as_list()[1:], tf.float32),
(memory_shape, tf.float32), # reward
(memory_shape, tf.bool), # done
# action
@@ -33,33 +34,38 @@ def define_collect(policy_factory, batch_env, hparams):
]
memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
for (shape, dtype) in memories_shapes_and_types]
cumulative_rewards = tf.Variable(
tf.zeros(hparams.num_agents, tf.float32), trainable=False)
cumulative_rewards = tf.get_variable("cumulative_rewards", len(batch_env),
trainable=False)

should_reset_var = tf.Variable(True, trainable=False)
reset_op = tf.cond(should_reset_var,
lambda: batch_env.reset(tf.range(hparams.num_agents)),
lambda: 0.0)
reset_op = tf.cond(
tf.logical_or(should_reset_var, eval_phase),
lambda: tf.group(batch_env.reset(tf.range(len(batch_env))),
tf.assign(cumulative_rewards, tf.zeros(len(batch_env)))),
lambda: tf.no_op())
with tf.control_dependencies([reset_op]):
reset_once_op = tf.assign(should_reset_var, False)

with tf.control_dependencies([reset_once_op]):

def step(index, scores_sum, scores_num):
"""Single step."""
index = index % hparams.epoch_length # Only needed in eval runs.
# Note: the only way to ensure we get a copy of a tensor is to run a
# simple operation on it. We are waiting for tf.copy:
# https://github.com/tensorflow/tensorflow/issues/11186
obs_copy = batch_env.observ + 0
actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
policy = actor_critic.policy
action = policy.sample()
action = tf.cond(eval_phase,
policy.mode,
policy.sample)
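# In the eval phase use the policy's mode (the greedy action) instead of sampling.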
postprocessed_action = actor_critic.action_postprocessing(action)
simulate_output = batch_env.simulate(postprocessed_action[0, ...])
pdf = policy.prob(action)[0]
with tf.control_dependencies(simulate_output):
reward, done = simulate_output
done = tf.reshape(done, (hparams.num_agents,))
done = tf.reshape(done, (len(batch_env),))
to_save = [obs_copy, reward, done, action[0, ...], pdf,
actor_critic.value[0]]
save_ops = [tf.scatter_update(memory_slot, index, value)
@@ -81,9 +87,14 @@ def step(index, scores_sum, scores_num):
return [index + 1, scores_sum + scores_sum_delta,
scores_num + scores_num_delta]

def stop_condition(i, _, resets):
return tf.cond(eval_phase,
lambda: resets < hparams.num_eval_agents,
lambda: i < hparams.epoch_length)

init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
index, scores_sum, scores_num = tf.while_loop(
lambda c, _1, _2: c < hparams.epoch_length,
stop_condition,
step,
init,
parallel_iterations=1,
@@ -94,7 +105,11 @@ def step(index, scores_sum, scores_num):
printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
with tf.control_dependencies([index, printing]):
memory = [tf.identity(mem) for mem in memory]
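# Only emit the mean-score summary when at least one episode has finished;
# str() returns an empty string and acts as the "no summary" branch of tf.cond.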
mean_score_summary = tf.cond(
tf.greater(scores_num, 0),
lambda: tf.summary.scalar("mean_score_this_iter", mean_score),
str)
summaries = tf.summary.merge(
[tf.summary.scalar("mean_score_this_iter", mean_score),
[mean_score_summary,
tf.summary.scalar("episodes_finished_this_iter", scores_num)])
return memory, summaries
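
For illustration only (not part of this diff), the collection graph can now be built for both phases by flipping the new eval_phase flag; the variable names below are hypothetical:

```python
# Hypothetical call sites: training collects for hparams.epoch_length steps,
# while eval runs until hparams.num_eval_agents trajectories have finished.
train_memory, train_summaries = define_collect(
    policy_factory, train_batch_env, hparams, eval_phase=False)
eval_memory, eval_summaries = define_collect(
    policy_factory, eval_batch_env, hparams, eval_phase=True)
```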
78 changes: 74 additions & 4 deletions tensor2tensor/rl/envs/utils.py
@@ -20,7 +20,12 @@
# https://github.com/tensorflow/agents/blob/master/agents/scripts/utility.py

import atexit
import gym
import multiprocessing
import os
import random
import signal
import subprocess
import sys
import traceback

@@ -31,6 +36,49 @@
import tensorflow as tf


class EvalVideoWrapper(gym.Wrapper):
"""Wrapper for recording videos during the eval phase.
This wrapper records videos via gym.wrappers.Monitor and simplifies its
usage in the T2T collect phase. It alleviates a limitation of Monitor,
which doesn't allow resetting an active environment.
EvalVideoWrapper assumes that only every second trajectory (after every
second reset) will be used by the caller:
- on the "active" runs it behaves as gym.wrappers.Monitor,
- on the "inactive" runs it doesn't call the underlying environment and
only returns the last seen observation.
Videos are only generated during the active runs.
"""
def __init__(self, env):
super(EvalVideoWrapper, self).__init__(env)
self._reset_counter = 0
self._active = False
self._last_returned = None

def _step(self, action):
if self._active:
self._last_returned = self.env.step(action)
if self._last_returned is None:
raise Exception("Environment stepped before proper reset.")
return self._last_returned

def _reset(self, **kwargs):
self._reset_counter += 1
if self._reset_counter % 2 == 1:
self._active = True
return self.env.reset(**kwargs)
else:
self._active = False
self._last_returned = (self._last_returned[0],
self._last_returned[1],
False, # done = False
self._last_returned[3])
return self._last_returned[0]
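
A minimal usage sketch (illustrative, not part of this diff): gym.wrappers.Monitor is applied first and EvalVideoWrapper wraps it, so only every second trajectory is actually recorded:

```python
# Hypothetical usage: only the "active" (odd-numbered) resets are recorded.
import gym
from gym import wrappers

env = wrappers.Monitor(gym.make("Pendulum-v0"), "/tmp/eval_videos", force=True)
env = EvalVideoWrapper(env)
```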


class ExternalProcessEnv(object):
"""Step environment in a separate process for lock-free parallelism."""

@@ -41,7 +89,7 @@ class ExternalProcessEnv(object):
_EXCEPTION = 4
_CLOSE = 5

def __init__(self, constructor):
def __init__(self, constructor, xvfb):
"""Step environment in a separate process for lock-free parallelism.
The environment will be created in the external process by calling the
@@ -57,8 +105,30 @@ def __init__(self, constructor):
action_space: The cached action space of the environment.
"""
self._conn, conn = multiprocessing.Pipe()
self._process = multiprocessing.Process(
if xvfb:
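# Start a private Xvfb server so the child environment can render
# off-screen; DISPLAY and XAUTHORITY are pointed at it below.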
server_id = random.randint(10000, 99999)
auth_file_id = random.randint(10000, 99999999999)

xauthority_path = '/tmp/Xauthority_{}'.format(auth_file_id)

command = 'Xvfb :{} -screen 0 1400x900x24 -nolisten tcp -auth {}'.format(
server_id, xauthority_path)
with open(os.devnull, 'w') as devnull:
proc = subprocess.Popen(command.split(), shell=False, stdout=devnull,
stderr=devnull)
atexit.register(lambda: os.kill(proc.pid, signal.SIGKILL))

def constructor_using_xvfb():
os.environ["DISPLAY"] = ":{}".format(server_id)
os.environ["XAUTHORITY"] = xauthority_path
return constructor()

self._process = multiprocessing.Process(
target=self._worker, args=(constructor_using_xvfb, conn))
else:
self._process = multiprocessing.Process(
target=self._worker, args=(constructor, conn))

atexit.register(self.close)
self._process.start()
self._observ_space = None
@@ -206,7 +276,7 @@ def _worker(self, constructor, conn):
conn.close()


def define_batch_env(constructor, num_agents, env_processes=True):
def define_batch_env(constructor, num_agents, xvfb=False, env_processes=True):
"""Create environments and apply all desired wrappers.
Args:
@@ -220,7 +290,7 @@ def define_batch_env(constructor, num_agents, env_processes=True):
with tf.variable_scope('environments'):
if env_processes:
envs = [
ExternalProcessEnv(constructor)
ExternalProcessEnv(constructor, xvfb)
for _ in range(num_agents)]
else:
envs = [constructor() for _ in range(num_agents)]
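
A usage sketch for the new xvfb flag (hypothetical, not part of this diff), assuming a Gym environment constructor:

```python
# Hypothetical usage: a batch of external-process environments rendering
# through Xvfb, e.g. so video_during_eval works on a headless machine.
import gym

batch_env = define_batch_env(
    lambda: gym.make("Pendulum-v0"), num_agents=3, xvfb=True)
```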