This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Merge pull request #770 from deepsense-ai/more_games
Add more Atari games.
lukaszkaiser authored May 7, 2018
2 parents 6cea4c4 + 032b595 commit 4ebb860
Showing 4 changed files with 80 additions and 49 deletions.
71 changes: 60 additions & 11 deletions tensor2tensor/data_generators/gym.py
@@ -42,7 +42,7 @@
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
flags.DEFINE_string("agent_policy_path", "", "File with model for agent")


class GymDiscreteProblem(video_utils.VideoProblem):
@@ -99,6 +99,14 @@ def env(self):
def num_actions(self):
return self.env.action_space.n

@property
def frame_height(self):
return self.env.observation_space.shape[0]

@property
def frame_width(self):
return self.env.observation_space.shape[1]

@property
def num_rewards(self):
raise NotImplementedError()
@@ -150,14 +158,6 @@ class GymPongRandom5k(GymDiscreteProblem):
def env_name(self):
return "PongDeterministic-v4"

@property
def frame_height(self):
return 210

@property
def frame_width(self):
return 160

@property
def min_reward(self):
return -1
@@ -179,9 +179,38 @@ class GymPongRandom50k(GymPongRandom5k):
def num_steps(self):
return 50000

@registry.register_problem
class GymFreewayRandom5k(GymDiscreteProblem):
"""Freeway game, random actions."""

@property
def env_name(self):
return "FreewayDeterministic-v4"

@property
def min_reward(self):
return 0

@property
def num_rewards(self):
return 2

@property
def num_steps(self):
return 5000


@registry.register_problem
class GymFreewayRandom50k(GymFreewayRandom5k):
"""Freeway game, random actions."""

@property
def num_steps(self):
return 50000


@registry.register_problem
class GymDiscreteProblemWithAgent(GymPongRandom5k):
class GymDiscreteProblemWithAgent(GymDiscreteProblem):
"""Gym environment with discrete actions and rewards and an agent."""

def __init__(self, *args, **kwargs):
@@ -190,7 +219,7 @@ def __init__(self, *args, **kwargs):
self.debug_dump_frames_path = "debug_frames_env"

# defaults
self.environment_spec = lambda: gym.make("PongDeterministic-v4")
self.environment_spec = lambda: gym.make(self.env_name)
self.in_graph_wrappers = []
self.collect_hparams = rl.atari_base()
self.settable_num_steps = 20000
@@ -286,3 +315,23 @@ def restore_networks(self, sess):
ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
ckpt = ckpts.model_checkpoint_path
env_model_loader.restore(sess, ckpt)


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnPong(GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnPong(GymDiscreteProblemWithAgent, GymPongRandom5k):
pass


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnFreeway(GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent, GymFreewayRandom5k):
pass
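
A note on naming (an illustrative sketch, not part of this commit): the T2T registry derives snake_case problem names from the CamelCase class names above, which is what model_rl_experiment.py below relies on when it formats the name from hparams.game.

# Sketch only; assumes the standard T2T registry CamelCase -> snake_case naming.
from tensor2tensor.utils import registry

pong_problem = registry.problem("gym_discrete_problem_with_agent_on_pong")
# -> instance of GymDiscreteProblemWithAgentOnPong
freeway_sim_problem = registry.problem("gym_simulated_discrete_problem_with_agent_on_freeway")
# -> instance of GymSimulatedDiscreteProblemWithAgentOnFreeway
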
41 changes: 14 additions & 27 deletions tensor2tensor/rl/envs/simulated_batch_env.py
@@ -32,15 +32,13 @@
from tensor2tensor.utils import trainer_lib

import tensorflow as tf
import numpy as np


flags = tf.flags
FLAGS = flags.FLAGS


flags.DEFINE_string("frames_path", "", "Path to the first frames.")


class SimulatedBatchEnv(InGraphBatchEnv):
"""Batch of environments inside the TensorFlow graph.
@@ -49,42 +47,31 @@ class SimulatedBatchEnv(InGraphBatchEnv):
flags are held in according variables.
"""

def __init__(self, length, observ_shape, observ_dtype, action_shape,
action_dtype):
def __init__(self, environment_lambda, length):
"""Batch of environments inside the TensorFlow graph."""
self.length = length
initalization_env = environment_lambda()
hparams = trainer_lib.create_hparams(
FLAGS.hparams_set, problem_name=FLAGS.problem, data_dir="UNUSED")
hparams.force_full_predict = True
self._model = registry.model(FLAGS.model)(
hparams, tf.estimator.ModeKeys.PREDICT)

self.action_shape = action_shape
self.action_dtype = action_dtype

with open(os.path.join(FLAGS.frames_path, "frame1.png"), "rb") as f:
png_frame_1_raw = f.read()
self.action_space = initalization_env.action_space
self.action_shape = list(initalization_env.action_space.shape)
self.action_dtype = tf.int32

with open(os.path.join(FLAGS.frames_path, "frame2.png"), "rb") as f:
png_frame_2_raw = f.read()
obs_1 = initalization_env.reset()
obs_2 = initalization_env.step(0)[0]

self.frame_1 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_1_raw),
tf.float32), 0)
self.frame_2 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_2_raw),
tf.float32), 0)
self.frame_1 = tf.expand_dims(tf.cast(obs_1, tf.float32), 0)
self.frame_2 = tf.expand_dims(tf.cast(obs_2, tf.float32), 0)

shape = (self.length,) + observ_shape
self._observ = tf.Variable(tf.zeros(shape, observ_dtype), trainable=False)
self._prev_observ = tf.Variable(tf.zeros(shape, observ_dtype),
shape = (self.length,) + initalization_env.observation_space.shape
# TODO(blazej0) - make more generic - make higher number of previous observations possible.
self._observ = tf.Variable(tf.zeros(shape, tf.float32), trainable=False)
self._prev_observ = tf.Variable(tf.zeros(shape, tf.float32),
trainable=False)
self._starting_observ = tf.Variable(tf.zeros(shape, observ_dtype),
trainable=False)

observ_dtype = tf.int64

@property
def action_space(self):
return gym.make("PongNoFrameskip-v4").action_space

def __len__(self):
"""Number of combined environments."""
12 changes: 3 additions & 9 deletions tensor2tensor/rl/envs/utils.py
@@ -287,7 +287,7 @@ def batch_env_factory(environment_lambda, hparams, num_agents, xvfb=False):
hparams, "in_graph_wrappers") else []

if hparams.simulated_environment:
cur_batch_env = define_simulated_batch_env(num_agents)
cur_batch_env = define_simulated_batch_env(environment_lambda, num_agents)
else:
cur_batch_env = define_batch_env(environment_lambda, num_agents, xvfb=xvfb)
for w in wrappers:
@@ -306,12 +306,6 @@ def define_batch_env(constructor, num_agents, xvfb=False):
return env


def define_simulated_batch_env(num_agents):
# TODO(blazej0): the parameters should be infered.
observ_shape = (210, 160, 3)
observ_dtype = tf.float32
action_shape = []
action_dtype = tf.int32
cur_batch_env = simulated_batch_env.SimulatedBatchEnv(
num_agents, observ_shape, observ_dtype, action_shape, action_dtype)
def define_simulated_batch_env(environment_lambda, num_agents):
cur_batch_env = simulated_batch_env.SimulatedBatchEnv(environment_lambda, num_agents)
return cur_batch_env
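
To see the new wiring end to end, a minimal sketch (assumptions noted in the comments; not code from this commit): the environment lambda set on the problem in gym.py above is passed through batch_env_factory into SimulatedBatchEnv, which now infers action and observation properties from it.

# Sketch of the call now made by define_simulated_batch_env. Assumes
# FLAGS.hparams_set, FLAGS.problem and FLAGS.model are already set, as
# model_rl_experiment.py arranges before building the simulated env.
import gym
from tensor2tensor.rl.envs import simulated_batch_env

environment_lambda = lambda: gym.make("FreewayDeterministic-v4")
num_agents = 2  # hypothetical batch size
batch_env = simulated_batch_env.SimulatedBatchEnv(environment_lambda, num_agents)
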
5 changes: 3 additions & 2 deletions tensor2tensor/rl/model_rl_experiment.py
@@ -51,7 +51,7 @@ def train(hparams, output_dir):
time_delta = time.time() - start_time
print(line+"Step {}.1. - generate data from policy. "
"Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
FLAGS.problem = "gym_discrete_problem_with_agent"
FLAGS.problem = "gym_discrete_problem_with_agent_on_{}".format(hparams.game)
FLAGS.agent_policy_path = last_model
gym_problem = registry.problem(FLAGS.problem)
gym_problem.settable_num_steps = hparams.true_env_generator_num_steps
@@ -76,7 +76,7 @@ def train(hparams, output_dir):
print(line+"Step {}.3. - evalue env model. "
"Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
gym_simulated_problem = registry.problem(
"gym_simulated_discrete_problem_with_agent")
"gym_simulated_discrete_problem_with_agent_on_{}".format(hparams.game))
sim_steps = hparams.simulated_env_generator_num_steps
gym_simulated_problem.settable_num_steps = sim_steps
gym_simulated_problem.generate_data(iter_data_dir, tmp_dir)
@@ -115,6 +115,7 @@ def main(_):
simulated_env_generator_num_steps=300,
ppo_epochs_num=200,
ppo_epoch_length=300,
game="pong",
)
train(hparams, FLAGS.output_dir)
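
With the new game hparam, pointing the whole loop at a different Atari game is a one-line change, provided the corresponding ...On<Game> problems above are registered. A small illustration (values are examples only):

# Illustration: how hparams.game selects the problems registered in gym.py.
game = "freeway"  # the default above is "pong"
real_env_problem = "gym_discrete_problem_with_agent_on_{}".format(game)
simulated_problem = "gym_simulated_discrete_problem_with_agent_on_{}".format(game)
# -> "gym_discrete_problem_with_agent_on_freeway"
# -> "gym_simulated_discrete_problem_with_agent_on_freeway"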

