This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Merge pull request #568 from deepsense-ai/gae_support
RL Improvements
lukaszkaiser authored Feb 10, 2018
2 parents 28d1841 + ad30518 commit eaefc32
Showing 9 changed files with 224 additions and 137 deletions.
1 change: 1 addition & 0 deletions tensor2tensor/models/__init__.py
@@ -35,6 +35,7 @@
from tensor2tensor.models import neural_gpu
from tensor2tensor.models import resnet
from tensor2tensor.models import revnet
from tensor2tensor.models import rl
from tensor2tensor.models import shake_shake
from tensor2tensor.models import slicenet
from tensor2tensor.models import super_lm
127 changes: 127 additions & 0 deletions tensor2tensor/models/rl.py
@@ -0,0 +1,127 @@
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Reinforcement learning models and parameters."""

# Dependency imports

import collections
import functools
import gym
import operator
import tensorflow as tf

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry


@registry.register_hparams
def ppo_base_v1():
"""Set of hyperparameters."""
hparams = common_hparams.basic_params1()
hparams.learning_rate = 1e-4
hparams.add_hparam("init_mean_factor", 0.1)
hparams.add_hparam("init_logstd", 0.1)
hparams.add_hparam("policy_layers", (100, 100))
hparams.add_hparam("value_layers", (100, 100))
hparams.add_hparam("num_agents", 30)
hparams.add_hparam("clipping_coef", 0.2)
hparams.add_hparam("gae_gamma", 0.99)
hparams.add_hparam("gae_lambda", 0.95)
hparams.add_hparam("entropy_loss_coef", 0.01)
hparams.add_hparam("value_loss_coef", 1)
hparams.add_hparam("optimization_epochs", 15)
hparams.add_hparam("epoch_length", 200)
hparams.add_hparam("epochs_num", 2000)
return hparams

@registry.register_hparams
def pendulum():
hparams = ppo_base_v1()
hparams.add_hparam("environment", "Pendulum-v0")
hparams.add_hparam("network", feed_forward_gaussian_fun)
return hparams

@registry.register_hparams
def cartpole():
hparams = ppo_base_v1()
hparams.add_hparam("environment", "CartPole-v0")
hparams.add_hparam("network", feed_forward_categorical_fun)
return hparams


# Neural networks for actor-critic algorithms

NetworkOutput = collections.namedtuple(
'NetworkOutput', 'policy, value, action_postprocessing')


def feed_forward_gaussian_fun(action_space, config, observations):
assert isinstance(action_space, gym.spaces.box.Box), \
'Expecting continuous action space.'
mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
factor=config.init_mean_factor)
logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)

flat_observations = tf.reshape(observations, [
tf.shape(observations)[0], tf.shape(observations)[1],
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])

with tf.variable_scope('policy'):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
mean = tf.contrib.layers.fully_connected(
x, action_space.shape[0], tf.tanh,
weights_initializer=mean_weights_initializer)
logstd = tf.get_variable(
'logstd', mean.shape[2:], tf.float32, logstd_initializer)
logstd = tf.tile(
logstd[None, None],
[tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
with tf.variable_scope('value'):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
mean = tf.check_numerics(mean, 'mean')
logstd = tf.check_numerics(logstd, 'logstd')
value = tf.check_numerics(value, 'value')

policy = tf.contrib.distributions.MultivariateNormalDiag(mean,
tf.exp(logstd))

return NetworkOutput(policy, value, lambda a: tf.clip_by_value(a, -2., 2))


def feed_forward_categorical_fun(action_space, config, observations):
assert isinstance(action_space, gym.spaces.Discrete), \
'Expecting discrete action space.'
flat_observations = tf.reshape(observations, [
tf.shape(observations)[0], tf.shape(observations)[1],
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
with tf.variable_scope('policy'):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
logits = tf.contrib.layers.fully_connected(x, action_space.n,
activation_fn=None)
with tf.variable_scope('value'):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
policy = tf.contrib.distributions.Categorical(logits=logits)
return NetworkOutput(policy, value, lambda a: a)
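
For context, a minimal sketch (not part of this commit) of how the hparams set and network constructor registered above might be wired together. The placeholder shape and the direct call to `rl.pendulum()` are assumptions for illustration.

```python
# Hypothetical usage sketch (not part of this commit).
import gym
import tensorflow as tf

from tensor2tensor.models import rl

hparams = rl.pendulum()              # ppo_base_v1 plus the Pendulum overrides
env = gym.make(hparams.environment)  # "Pendulum-v0"

# Observations are batched as [time, num_agents, *observation_shape].
observations = tf.placeholder(
    tf.float32,
    [None, hparams.num_agents] + list(env.observation_space.shape))

# hparams.network is feed_forward_gaussian_fun for the Pendulum config.
network = hparams.network(env.action_space, hparams, observations)
actions = network.action_postprocessing(network.policy.sample())
values = network.value
```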
2 changes: 1 addition & 1 deletion tensor2tensor/rl/README.md
@@ -8,4 +8,4 @@ Currently the only supported algorithm is Proximal Policy Optimization (PPO).

## Sample usage - training in Pendulum-v0 environment.

```python rl/t2t_rl_trainer.py```
```python rl/t2t_rl_trainer.py --hparams_set pendulum [--output_dir dir_location]```
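
This commit also registers a `cartpole` hparams set, so presumably the same trainer can be pointed at CartPole-v0 (hypothetical invocation, not shown in the diff):

```python rl/t2t_rl_trainer.py --hparams_set cartpole [--output_dir dir_location]```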
23 changes: 14 additions & 9 deletions tensor2tensor/rl/collect.py
@@ -18,26 +18,27 @@
import tensorflow as tf


def define_collect(policy_factory, batch_env, config):
def define_collect(policy_factory, batch_env, hparams):
"""Collect trajectories."""
memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]]
memory_shape = [hparams.epoch_length] + [batch_env.observ.shape.as_list()[0]]
memories_shapes_and_types = [
# observation
(memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
(memory_shape, tf.float32), # reward
(memory_shape, tf.bool), # done
(memory_shape + batch_env.action_shape, tf.float32), # action
# action
(memory_shape + batch_env.action_shape, batch_env.action_dtype),
(memory_shape, tf.float32), # pdf
(memory_shape, tf.float32), # value function
]
memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
for (shape, dtype) in memories_shapes_and_types]
cumulative_rewards = tf.Variable(
tf.zeros(config.num_agents, tf.float32), trainable=False)
tf.zeros(hparams.num_agents, tf.float32), trainable=False)

should_reset_var = tf.Variable(True, trainable=False)
reset_op = tf.cond(should_reset_var,
lambda: batch_env.reset(tf.range(config.num_agents)),
lambda: batch_env.reset(tf.range(hparams.num_agents)),
lambda: 0.0)
with tf.control_dependencies([reset_op]):
reset_once_op = tf.assign(should_reset_var, False)
@@ -58,7 +59,7 @@ def step(index, scores_sum, scores_num):
pdf = policy.prob(action)[0]
with tf.control_dependencies(simulate_output):
reward, done = simulate_output
done = tf.reshape(done, (config.num_agents,))
done = tf.reshape(done, (hparams.num_agents,))
to_save = [obs_copy, reward, done, action[0, ...], pdf,
actor_critic.value[0]]
save_ops = [tf.scatter_update(memory_slot, index, value)
@@ -82,7 +83,7 @@ def step(index, scores_sum, scores_num):

init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
index, scores_sum, scores_num = tf.while_loop(
lambda c, _1, _2: c < config.epoch_length,
lambda c, _1, _2: c < hparams.epoch_length,
step,
init,
parallel_iterations=1,
@@ -91,5 +92,9 @@ def step(index, scores_sum, scores_num):
lambda: scores_sum / tf.cast(scores_num, tf.float32),
lambda: 0.)
printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
with tf.control_dependencies([printing]):
return tf.identity(index), memory
with tf.control_dependencies([index, printing]):
memory = [tf.identity(mem) for mem in memory]
summaries = tf.summary.merge(
[tf.summary.scalar("mean_score_this_iter", mean_score),
tf.summary.scalar("episodes_finished_this_iter", scores_num)])
return memory, summaries
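
A hypothetical consumption sketch (not from this commit): `define_collect` now returns `(memory, summaries)` rather than `(index, memory)`. The `policy_factory`, `batch_env`, `hparams`, and `output_dir` names below are placeholders assumed to come from the surrounding trainer.

```python
# Hypothetical consumption sketch (not part of this commit); placeholder names:
# policy_factory, batch_env, hparams, output_dir.
import tensorflow as tf

from tensor2tensor.rl.collect import define_collect

memory, collect_summary = define_collect(policy_factory, batch_env, hparams)
# Memory slots, in order: observation, reward, done, action, pdf, value.
observation, reward, done, action, pdf, value = memory

writer = tf.summary.FileWriter(output_dir)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(hparams.epochs_num):
    # Evaluating the merged summaries drives the collect while_loop.
    summary_proto = sess.run(collect_summary)
    writer.add_summary(summary_proto, global_step=epoch)
```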
69 changes: 0 additions & 69 deletions tensor2tensor/rl/networks.py

This file was deleted.

43 changes: 35 additions & 8 deletions tensor2tensor/rl/ppo.py
@@ -32,8 +32,8 @@ def define_ppo_step(observation, action, reward, done, value, old_pdf,
clipped_ratio = tf.clip_by_value(ratio, 1 - config.clipping_coef,
1 + config.clipping_coef)

advantage = calculate_discounted_return(
reward, value, done, config.gae_gamma, config.gae_lambda) - value
advantage = calculate_generalized_advantage_estimator(
reward, value, done, config.gae_gamma, config.gae_lambda)

advantage_mean, advantage_variance = tf.nn.moments(advantage, axes=[0, 1],
keep_dims=True)
@@ -53,7 +53,11 @@

total_loss = policy_loss + value_loss + entropy_loss

optimization_op = config.optimizer(config.learning_rate).minimize(total_loss)
optimization_op = tf.contrib.layers.optimize_loss(
loss=total_loss,
global_step=tf.train.get_or_create_global_step(),
optimizer=config.optimizer,
learning_rate=config.learning_rate)

with tf.control_dependencies([optimization_op]):
return [tf.identity(x) for x in (policy_loss, value_loss, entropy_loss)]
@@ -79,12 +83,17 @@ def define_ppo_epoch(memory, policy_factory, config):
[0., 0., 0.],
parallel_iterations=1)

print_losses = tf.group(
tf.Print(0, [tf.reduce_mean(policy_loss)], 'policy loss: '),
tf.Print(0, [tf.reduce_mean(value_loss)], 'value loss: '),
tf.Print(0, [tf.reduce_mean(entropy_loss)], 'entropy loss: '))
summaries = [tf.summary.scalar("policy loss", tf.reduce_mean(policy_loss)),
tf.summary.scalar("value loss", tf.reduce_mean(value_loss)),
tf.summary.scalar("entropy loss", tf.reduce_mean(entropy_loss))]

return print_losses
losses_summary = tf.summary.merge(summaries)

losses_summary = tf.Print(losses_summary, [tf.reduce_mean(policy_loss)], 'policy loss: ')
losses_summary = tf.Print(losses_summary, [tf.reduce_mean(value_loss)], 'value loss: ')
losses_summary = tf.Print(losses_summary, [tf.reduce_mean(entropy_loss)], 'entropy loss: ')

return losses_summary


def calculate_discounted_return(reward, value, done, discount, unused_lambda):
@@ -100,3 +109,21 @@ def calculate_discounted_return(reward, value, done, discount, unused_lambda):
1,
False), [0])
return tf.check_numerics(return_, 'return')


def calculate_generalized_advantage_estimator(
    reward, value, done, gae_gamma, gae_lambda):
  """Generalized advantage estimator."""

  # Slight weirdness below: the last reward is replaced by the value estimate,
  # which makes the advantage zero at the last timestep.
  reward = tf.concat([reward[:-1, :], value[-1:, :]], axis=0)
  next_value = tf.concat([value[1:, :], tf.zeros_like(value[-1:, :])], axis=0)
  next_not_done = 1 - tf.cast(
      tf.concat([done[1:, :], tf.zeros_like(done[-1:, :])], axis=0), tf.float32)
delta = reward + gae_gamma * next_value * next_not_done - value

return_ = tf.reverse(tf.scan(
lambda agg, cur: cur[0] + cur[1] * gae_gamma * gae_lambda * agg,
[tf.reverse(delta, [0]), tf.reverse(next_not_done, [0])],
tf.zeros_like(delta[0, :]),
1, False), [0])
return tf.check_numerics(tf.stop_gradient(return_), 'return')
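
For readers new to GAE, a small self-contained NumPy mirror of the recursion above may help: delta(t) = r(t) + gamma * V(t+1) * notdone(t+1) - V(t), then A(t) = delta(t) + gamma * lambda * notdone(t+1) * A(t+1), scanned backwards in time. This is an illustrative sketch, not part of the commit; the function name is hypothetical.

```python
# Illustrative NumPy mirror of calculate_generalized_advantage_estimator
# (not part of this commit); arrays are shaped [time, num_agents].
import numpy as np


def gae_reference(reward, value, done, gamma, lam):
  """Backward scan: A(t) = delta(t) + gamma * lam * notdone(t+1) * A(t+1)."""
  # Replace the last reward with the value estimate so the final delta is zero.
  reward = np.concatenate([reward[:-1], value[-1:]], axis=0)
  next_value = np.concatenate([value[1:], np.zeros_like(value[-1:])], axis=0)
  next_not_done = 1.0 - np.concatenate(
      [done[1:], np.zeros_like(done[-1:])], axis=0).astype(np.float32)
  delta = reward + gamma * next_value * next_not_done - value

  advantage = np.zeros_like(delta)
  running = np.zeros_like(delta[0])
  for t in reversed(range(delta.shape[0])):
    running = delta[t] + gamma * lam * next_not_done[t] * running
    advantage[t] = running
  return advantage


# Example: 3 timesteps, 2 agents, no terminal states.
adv = gae_reference(np.zeros((3, 2), np.float32), np.ones((3, 2), np.float32),
                    np.zeros((3, 2), bool), gamma=0.99, lam=0.95)
```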
