diff --git a/.gitignore b/.gitignore index 92468f5..91a0743 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +log/ tmp/ notebooks/ __pycache__/ diff --git a/conda_env.yml b/conda_env.yml index fffc1a5..7744897 100644 --- a/conda_env.yml +++ b/conda_env.yml @@ -9,6 +9,8 @@ dependencies: - absl-py - pyparsing - pillow=6.1 + - pandas + - pip - pip: - termcolor - git+git://github.com/deepmind/dm_control.git diff --git a/curl_sac.py b/curl_sac.py index 5836371..a964dda 100644 --- a/curl_sac.py +++ b/curl_sac.py @@ -8,6 +8,8 @@ import utils from encoder import make_encoder +logger = utils.get_logger(__name__) + LOG_FREQ = 10000 @@ -46,9 +48,10 @@ def weight_init(m): class Actor(nn.Module): """MLP actor network.""" + def __init__( - self, obs_shape, action_shape, hidden_dim, encoder_type, - encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters + self, obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters ): super().__init__() @@ -70,7 +73,7 @@ def __init__( self.apply(weight_init) def forward( - self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False + self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False ): obs = self.encoder(obs, detach=detach_encoder) @@ -79,7 +82,7 @@ def forward( # constrain log_std inside [log_std_min, log_std_max] log_std = torch.tanh(log_std) log_std = self.log_std_min + 0.5 * ( - self.log_std_max - self.log_std_min + self.log_std_max - self.log_std_min ) * (log_std + 1) self.outputs['mu'] = mu @@ -116,6 +119,7 @@ def log(self, L, step, log_freq=LOG_FREQ): class QFunction(nn.Module): """MLP for q-function.""" + def __init__(self, obs_dim, action_dim, hidden_dim): super().__init__() @@ -134,13 +138,13 @@ def forward(self, obs, action): class Critic(nn.Module): """Critic network, employes two q-functions.""" + def __init__( - self, obs_shape, action_shape, hidden_dim, encoder_type, - encoder_feature_dim, num_layers, num_filters + self, obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, num_layers, num_filters ): super().__init__() - self.encoder = make_encoder( encoder_type, obs_shape, encoder_feature_dim, num_layers, num_filters, output_logits=True @@ -193,7 +197,7 @@ def __init__(self, obs_shape, z_dim, batch_size, critic, critic_target, output_t self.encoder = critic.encoder - self.encoder_target = critic_target.encoder + self.encoder_target = critic_target.encoder self.W = nn.Parameter(torch.rand(z_dim, z_dim)) self.output_type = output_type @@ -227,37 +231,39 @@ def compute_logits(self, z_a, z_pos): logits = logits - torch.max(logits, 1)[0][:, None] return logits + class CurlSacAgent(object): """CURL representation learning with SAC.""" + def __init__( - self, - obs_shape, - action_shape, - device, - hidden_dim=256, - discount=0.99, - init_temperature=0.01, - alpha_lr=1e-3, - alpha_beta=0.9, - actor_lr=1e-3, - actor_beta=0.9, - actor_log_std_min=-10, - actor_log_std_max=2, - actor_update_freq=2, - critic_lr=1e-3, - critic_beta=0.9, - critic_tau=0.005, - critic_target_update_freq=2, - encoder_type='pixel', - encoder_feature_dim=50, - encoder_lr=1e-3, - encoder_tau=0.005, - num_layers=4, - num_filters=32, - cpc_update_freq=1, - log_interval=100, - detach_encoder=False, - curl_latent_dim=128 + self, + obs_shape, + action_shape, + device, + hidden_dim=256, + discount=0.99, + init_temperature=0.01, + alpha_lr=1e-3, + alpha_beta=0.9, + actor_lr=1e-3, + actor_beta=0.9, + actor_log_std_min=-10, + actor_log_std_max=2, + 
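# editor's note: the *_update_freq defaults below are tested against the global training step passed into update() (e.g. `step % self.actor_update_freq == 0`), not a per-component counter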
actor_update_freq=2, + critic_lr=1e-3, + critic_beta=0.9, + critic_tau=0.005, + critic_target_update_freq=2, + encoder_type='pixel', + encoder_feature_dim=50, + encoder_lr=1e-3, + encoder_tau=0.005, + num_layers=4, + num_filters=32, + cpc_update_freq=1, + log_interval=100, + detach_encoder=False, + curl_latent_dim=128 ): self.device = device self.discount = discount @@ -297,7 +303,7 @@ def __init__( self.log_alpha.requires_grad = True # set target entropy to -|A| self.target_entropy = -np.prod(action_shape) - + # optimizers self.actor_optimizer = torch.optim.Adam( self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999) @@ -314,7 +320,8 @@ def __init__( if self.encoder_type == 'pixel': # create CURL encoder (the 128 batch size is probably unnecessary) self.CURL = CURL(obs_shape, encoder_feature_dim, - self.curl_latent_dim, self.critic,self.critic_target, output_type='continuous').to(self.device) + self.curl_latent_dim, self.critic, self.critic_target, output_type='continuous').to( + self.device) # optimizer for critic encoder for reconstruction loss self.encoder_optimizer = torch.optim.Adam( @@ -341,9 +348,11 @@ def alpha(self): return self.log_alpha.exp() def select_action(self, obs): + # logger.info(obs.shape) with torch.no_grad(): obs = torch.FloatTensor(obs).to(self.device) obs = obs.unsqueeze(0) + # logger.info(obs.shape) mu, _, _, _ = self.actor( obs, compute_pi=False, compute_log_pi=False ) @@ -352,7 +361,7 @@ def select_action(self, obs): def sample_action(self, obs): if obs.shape[-1] != self.image_size: obs = utils.center_crop_image(obs, self.image_size) - + with torch.no_grad(): obs = torch.FloatTensor(obs).to(self.device) obs = obs.unsqueeze(0) @@ -362,6 +371,7 @@ def sample_action(self, obs): def update_critic(self, obs, action, reward, next_obs, not_done, L, step): with torch.no_grad(): _, policy_action, log_pi, _ = self.actor(next_obs) + # logger.info((obs.shape, next_obs.shape)) target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi @@ -375,7 +385,6 @@ def update_critic(self, obs, action, reward, next_obs, not_done, L, step): if step % self.log_interval == 0: L.log('train_critic/loss', critic_loss, step) - # Optimize the critic self.critic_optimizer.zero_grad() critic_loss.backward() @@ -395,8 +404,8 @@ def update_actor_and_alpha(self, obs, L, step): L.log('train_actor/loss', actor_loss, step) L.log('train_actor/target_entropy', self.target_entropy, step) entropy = 0.5 * log_std.shape[1] * \ - (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1) - if step % self.log_interval == 0: + (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1) + if step % self.log_interval == 0: L.log('train_actor/entropy', entropy.mean(), step) # optimize the actor @@ -416,14 +425,14 @@ def update_actor_and_alpha(self, obs, L, step): self.log_alpha_optimizer.step() def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step): - + z_a = self.CURL.encode(obs_anchor) z_pos = self.CURL.encode(obs_pos, ema=True) - + logits = self.CURL.compute_logits(z_a, z_pos) labels = torch.arange(logits.shape[0]).long().to(self.device) loss = self.cross_entropy_loss(logits, labels) - + self.encoder_optimizer.zero_grad() self.cpc_optimizer.zero_grad() loss.backward() @@ -433,16 +442,16 @@ def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step): if step % self.log_interval == 0: L.log('train/curl_loss', loss, step) - def update(self, replay_buffer, L, step): if self.encoder_type == 'pixel': obs, action, reward, next_obs, 
not_done, cpc_kwargs = replay_buffer.sample_cpc() else: obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio() - + if step % self.log_interval == 0: L.log('train/batch_reward', reward.mean(), step) + # logger.info((obs.shape, next_obs.shape)) self.update_critic(obs, action, reward, next_obs, not_done, L, step) if step % self.actor_update_freq == 0: @@ -459,10 +468,10 @@ def update(self, replay_buffer, L, step): self.critic.encoder, self.critic_target.encoder, self.encoder_tau ) - + if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel': obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs["obs_pos"] - self.update_cpc(obs_anchor, obs_pos,cpc_kwargs, L, step) + self.update_cpc(obs_anchor, obs_pos, cpc_kwargs, L, step) def save(self, model_dir, step): torch.save( @@ -484,4 +493,3 @@ def load(self, model_dir, step): self.critic.load_state_dict( torch.load('%s/critic_%s.pt' % (model_dir, step)) ) - \ No newline at end of file diff --git a/dmc2gym/__init__.py b/dmc2gym/__init__.py new file mode 100644 index 0000000..1262625 --- /dev/null +++ b/dmc2gym/__init__.py @@ -0,0 +1,52 @@ +import gym +from gym.envs.registration import register + + +def make( + domain_name, + task_name, + seed=1, + visualize_reward=True, + from_pixels=False, + height=84, + width=84, + camera_ids=(0, ), + frame_skip=1, + episode_length=1000, + environment_kwargs=None, + time_limit=None, + channels_first=True +): + env_id = 'dmc_%s_%s-v1' % (domain_name, task_name) + + if from_pixels: + assert not visualize_reward, 'cannot use visualize reward when learning from pixels' + + # shorten episode length + max_episode_steps = (episode_length + frame_skip - 1) // frame_skip + + if not env_id in gym.envs.registry.env_specs: + task_kwargs = {} + if seed is not None: + task_kwargs['random'] = seed + if time_limit is not None: + task_kwargs['time_limit'] = time_limit + register( + id=env_id, + entry_point='dmc2gym.wrappers:DMCWrapper', + kwargs=dict( + domain_name=domain_name, + task_name=task_name, + task_kwargs=task_kwargs, + environment_kwargs=environment_kwargs, + visualize_reward=visualize_reward, + from_pixels=from_pixels, + height=height, + width=width, + camera_ids=camera_ids, + frame_skip=frame_skip, + channels_first=channels_first, + ), + max_episode_steps=max_episode_steps, + ) + return gym.make(env_id) diff --git a/dmc2gym/wrappers.py b/dmc2gym/wrappers.py new file mode 100644 index 0000000..d37f8e6 --- /dev/null +++ b/dmc2gym/wrappers.py @@ -0,0 +1,178 @@ +from gym import core, spaces +from dm_control import suite +from dm_env import specs +import numpy as np + + +def _spec_to_box(spec): + def extract_min_max(s): + assert s.dtype == np.float64 or s.dtype == np.float32 + dim = np.int(np.prod(s.shape)) + if type(s) == specs.Array: + bound = np.inf * np.ones(dim, dtype=np.float32) + return -bound, bound + elif type(s) == specs.BoundedArray: + zeros = np.zeros(dim, dtype=np.float32) + return s.minimum + zeros, s.maximum + zeros + + mins, maxs = [], [] + for s in spec: + mn, mx = extract_min_max(s) + mins.append(mn) + maxs.append(mx) + low = np.concatenate(mins, axis=0) + high = np.concatenate(maxs, axis=0) + assert low.shape == high.shape + return spaces.Box(low, high, dtype=np.float32) + + +def _flatten_obs(obs): + obs_pieces = [] + for v in obs.values(): + flat = np.array([v]) if np.isscalar(v) else v.ravel() + obs_pieces.append(flat) + return np.concatenate(obs_pieces, axis=0) + + +class DMCWrapper(core.Env): + def __init__( + self, + domain_name, + task_name, + task_kwargs=None, + 
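# editor's note: camera_ids is new in this diff — one rendered view per MuJoCo camera id; _get_obs stacks them into an (n, C, H, W) observation when channels_first is set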
visualize_reward={}, + from_pixels=False, + height=84, + width=84, + camera_ids=(0, ), + frame_skip=1, + environment_kwargs=None, + channels_first=True + ): + assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' + self._from_pixels = from_pixels + self._height = height + self._width = width + self._camera_ids = camera_ids + self._frame_skip = frame_skip + self._channels_first = channels_first + + # create task + self._env = suite.load( + domain_name=domain_name, + task_name=task_name, + task_kwargs=task_kwargs, + visualize_reward=visualize_reward, + environment_kwargs=environment_kwargs + ) + + # true and normalized action spaces + self._true_action_space = _spec_to_box([self._env.action_spec()]) + self._norm_action_space = spaces.Box( + low=-1.0, + high=1.0, + shape=self._true_action_space.shape, + dtype=np.float32 + ) + + # create observation space + if from_pixels: + shape = [3, height, width] if channels_first else [height, width, 3] + self._observation_space = spaces.Box( + low=0, high=255, shape=shape, dtype=np.uint8 + ) + else: + self._observation_space = _spec_to_box( + self._env.observation_spec().values() + ) + + self._state_space = _spec_to_box( + self._env.observation_spec().values() + ) + + self.current_state = None + + # set seed + self.seed(seed=task_kwargs.get('random', 1)) + + @property + def num_camera(self): + return len(self._camera_ids) + + def __getattr__(self, name): + return getattr(self._env, name) + + def _get_obs(self, time_step): + if self._from_pixels: + obs = [] + for camera_id in self._camera_ids: + obs.append(self.render( + height=self._height, + width=self._width, + camera_id=camera_id + )) + obs = np.stack(obs, axis=0) # H, W, C > n, H, W, C + if self._channels_first: + obs = obs.transpose(0, 3, 1, 2).copy() # C, H, W > n, C, H, W + else: + obs = _flatten_obs(time_step.observation) + return obs + + def _convert_action(self, action): + action = action.astype(np.float64) + true_delta = self._true_action_space.high - self._true_action_space.low + norm_delta = self._norm_action_space.high - self._norm_action_space.low + action = (action - self._norm_action_space.low) / norm_delta + action = action * true_delta + self._true_action_space.low + action = action.astype(np.float32) + return action + + @property + def observation_space(self): + return self._observation_space + + @property + def state_space(self): + return self._state_space + + @property + def action_space(self): + return self._norm_action_space + + def seed(self, seed): + self._true_action_space.seed(seed) + self._norm_action_space.seed(seed) + self._observation_space.seed(seed) + + def step(self, action): + assert self._norm_action_space.contains(action) + action = self._convert_action(action) + assert self._true_action_space.contains(action) + reward = 0 + extra = {'internal_state': self._env.physics.get_state().copy()} + + for _ in range(self._frame_skip): + time_step = self._env.step(action) + reward += time_step.reward or 0 + done = time_step.last() + if done: + break + obs = self._get_obs(time_step) + self.current_state = _flatten_obs(time_step.observation) + extra['discount'] = time_step.discount + return obs, reward, done, extra + + def reset(self): + time_step = self._env.reset() + self.current_state = _flatten_obs(time_step.observation) + obs = self._get_obs(time_step) + return obs + + def render(self, mode='rgb_array', height=None, width=None, camera_id=0): + assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode + height = height or 
self._height + width = width or self._width + camera_id = camera_id or self._camera_ids[0] + return self._env.physics.render( + height=height, width=width, camera_id=camera_id + ) diff --git a/encoder.py b/encoder.py index 9da499f..b1d8b3d 100644 --- a/encoder.py +++ b/encoder.py @@ -1,6 +1,12 @@ import torch import torch.nn as nn +import torch.nn.functional as F + +from utils import get_logger + +logger = get_logger(__name__) + def tie_weights(src, trg): assert type(src) == type(trg) @@ -16,21 +22,24 @@ def tie_weights(src, trg): class PixelEncoder(nn.Module): """Convolutional encoder of pixels observations.""" - def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32,output_logits=False): + + def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32, output_logits=False): super().__init__() - assert len(obs_shape) == 3 + # logger.info((obs_shape, feature_dim, num_layers, num_filters, output_logits)) + + assert len(obs_shape) == 4 # n, C, H, W self.obs_shape = obs_shape self.feature_dim = feature_dim self.num_layers = num_layers self.convs = nn.ModuleList( - [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] + [nn.Conv2d(obs_shape[1], num_filters, 3, stride=2)] ) for i in range(num_layers - 1): self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) - out_dim = OUT_DIM_64[num_layers] if obs_shape[-1] == 64 else OUT_DIM[num_layers] + out_dim = OUT_DIM_64[num_layers] if obs_shape[-1] == 64 else OUT_DIM[num_layers] self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) self.ln = nn.LayerNorm(self.feature_dim) @@ -45,19 +54,32 @@ def reparameterize(self, mu, logstd): def forward_conv(self, obs): obs = obs / 255. self.outputs['obs'] = obs + # logger.info(obs.shape) conv = torch.relu(self.convs[0](obs)) self.outputs['conv1'] = conv + # logger.info(conv.shape) for i in range(1, self.num_layers): conv = torch.relu(self.convs[i](conv)) self.outputs['conv%s' % (i + 1)] = conv + # logger.info(conv.shape) h = conv.view(conv.size(0), -1) + # logger.info(h.shape) return h def forward(self, obs, detach=False): + assert obs.ndim == 5 # B,n,C,H,W + B, n, C, H, W = obs.shape + obs = obs.view(B * n, C, H, W) + # logger.info((obs.ndim, obs.shape)) + # b = obs.shape[0] + # obs = obs.view(b, -1, *self.obs_shape[-2:]) + h = self.forward_conv(obs) + h = h.view(B, n, -1).permute(0, 2, 1) # B,f,n + h = F.max_pool1d(input=h, kernel_size=n).squeeze() # B,f if detach: h = h.detach() @@ -98,7 +120,7 @@ def log(self, L, step, log_freq): class IdentityEncoder(nn.Module): - def __init__(self, obs_shape, feature_dim, num_layers, num_filters,*args): + def __init__(self, obs_shape, feature_dim, num_layers, num_filters, *args): super().__init__() assert len(obs_shape) == 1 @@ -118,7 +140,7 @@ def log(self, L, step, log_freq): def make_encoder( - encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False + encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False ): assert encoder_type in _AVAILABLE_ENCODERS return _AVAILABLE_ENCODERS[encoder_type]( diff --git a/note.txt b/note.txt new file mode 100644 index 0000000..274f13d --- /dev/null +++ b/note.txt @@ -0,0 +1,17 @@ +### dmc2gym/wrappers.py +def _get_obs(self, time_step): + if self._from_pixels: + obs = [] + for camera_id in self._camera_ids: + obs.append(self.render( + height=self._height, + width=self._width, + camera_id=camera_id + )) + obs = np.concatenate(obs, axis=-1) # H, W, C > n, H, W, C + if self._channels_first: + obs = obs.transpose(2, 0, 1).copy() # C, 
H, W > n, C, H, W + else: + obs = _flatten_obs(time_step.observation) + # print('wrappers:119 {}'.format(obs.shape)) + return obs \ No newline at end of file diff --git a/playground.py b/playground.py new file mode 100644 index 0000000..0c5fba1 --- /dev/null +++ b/playground.py @@ -0,0 +1,183 @@ +# # from dm_control import suite +# # import numpy as np +# # +# # # Load one task: +# # env = suite.load(domain_name="cartpole", task_name="swingup") +# # +# # # Iterate over a task set: +# # for domain_name, task_name in suite.BENCHMARKING: +# # env = suite.load(domain_name, task_name) +# # +# # # Step through an episode and print out reward, discount and observation. +# # action_spec = env.action_spec() +# # time_step = env.reset() +# # while not time_step.last(): +# # action = np.random.uniform(action_spec.minimum, +# # action_spec.maximum, +# # size=action_spec.shape) +# # time_step = env.step(action) +# # print(time_step.reward, time_step.discount, time_step.observation) +# # +# +# +# +# # from dm_control import suite +# # from dm_control import viewer +# # import numpy as np +# # +# # env = suite.load(domain_name="humanoid", task_name="stand") +# # action_spec = env.action_spec() +# # +# # # Define a uniform random policy. +# # def random_policy(time_step): +# # del time_step # Unused. +# # return np.random.uniform(low=action_spec.minimum, +# # high=action_spec.maximum, +# # size=action_spec.shape) +# # +# # # Launch the viewer application. +# # viewer.launch(env, policy=random_policy) +# +# +# from dm_control import suite +# from dm_control.suite.wrappers import pixels +# +# env = suite.load('hopper', 'hop') +# +# wrapped_env = pixels.Wrapper(env, render_kwargs={'camera_id': 'cam0'}) + +import cv2 +from dm_control import suite +from dm_control.suite.wrappers import pixels + +import numpy as np + +# Load one task: +env = suite.load(domain_name='walker', task_name='walk') +# wrapped_env = pixels.Wrapper(env, render_kwargs={'camera_id': 0}) + +# Iterate over a task set: +# for domain_name, task_name in suite.BENCHMARKING: +# env = suite.load(domain_name, task_name) + +# Step through an episode and print out reward, discount and observation. + +height, width = 480, 480 + +action_spec = env.action_spec() +time_step = env.reset() +image = env.physics.render(height, width, camera_id=0) +images = [] +for i in range(100): +# while not time_step.last(): + action = np.random.uniform(action_spec.minimum, + action_spec.maximum, + size=action_spec.shape) + time_step = env.step(action) + image1 = env.physics.render(height, width, camera_id=0) + image2 = env.physics.render(height, width, camera_id=1) + image = np.zeros((image1.shape[0], image1.shape[1] * 2, image1.shape[2]), dtype=image1.dtype) + image[:, :image1.shape[1], :] = image1.copy() + image[:, image1.shape[1]:, :] = image2.copy() + images.append(image) + # image3 = env.physics.render(height, width, camera_id=2) + + # time_step = wrapped_env.step(action) + # cv2.imshow('image', image[..., ::-1]) + # cv2.imshow('image2', image2) + # cv2.imshow('image3', image3) + # cv2.waitKey() + # print(time_step.reward, time_step.discount, time_step.observation['pixels'].shape) + + +import imageio +from pathlib import Path +video_path = Path.home() / '.curl/walker-walk.mp4' +imageio.mimsave(str(video_path), images, fps=24) + + + +# # Copyright 2017 The dm_control Authors. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. 
+# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# # ============================================================================ +# +# """Demonstration of amc parsing for CMU mocap database. +# +# To run the demo, supply a path to a `.amc` file: +# +# python mocap_demo --filename='path/to/mocap.amc' +# +# CMU motion capture clips are available at mocap.cs.cmu.edu +# """ +# +# from __future__ import absolute_import +# from __future__ import division +# from __future__ import print_function +# +# import time +# # Internal dependencies. +# +# from absl import app +# from absl import flags +# +# from dm_control.suite import humanoid_CMU +# from dm_control.suite.utils import parse_amc +# +# import matplotlib.pyplot as plt +# import numpy as np +# +# FLAGS = flags.FLAGS +# flags.DEFINE_string('filename', None, 'amc file to be converted.') +# flags.DEFINE_integer('max_num_frames', 90, +# 'Maximum number of frames for plotting/playback') +# +# +# def main(unused_argv): +# env = humanoid_CMU.stand() +# +# # Parse and convert specified clip. +# converted = parse_amc.convert(FLAGS.filename, +# env.physics, env.control_timestep()) +# +# max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1) +# +# width = 480 +# height = 480 +# video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8) +# +# for i in range(max_frame): +# p_i = converted.qpos[:, i] +# with env.physics.reset_context(): +# env.physics.data.qpos[:] = p_i +# video[i] = np.hstack([env.physics.render(height, width, camera_id=0), +# env.physics.render(height, width, camera_id=1)]) +# +# tic = time.time() +# for i in range(max_frame): +# if i == 0: +# img = plt.imshow(video[i]) +# else: +# img.set_data(video[i]) +# toc = time.time() +# clock_dt = toc - tic +# tic = time.time() +# # Real-time playback not always possible as clock_dt > .03 +# plt.pause(max(0.01, 0.03 - clock_dt)) # Need min display time > 0.0. 
+# plt.draw() +# plt.waitforbuttonpress() +# +# +# if __name__ == '__main__': +# flags.mark_flag_as_required('filename') +# app.run(main) \ No newline at end of file diff --git a/train.py b/train.py index 3768b31..5f33a52 100644 --- a/train.py +++ b/train.py @@ -18,6 +18,8 @@ from curl_sac import CurlSacAgent from torchvision import transforms +logger = utils.get_logger(__name__) + def parse_args(): parser = argparse.ArgumentParser() @@ -31,6 +33,8 @@ def parse_args(): parser.add_argument('--frame_stack', default=3, type=int) # replay buffer parser.add_argument('--replay_buffer_capacity', default=100000, type=int) + # camera + parser.add_argument('--camera_ids', nargs='+', default=[0], type=int) # train parser.add_argument('--agent', default='curl_sac', type=str) parser.add_argument('--init_steps', default=1000, type=int) @@ -38,13 +42,14 @@ def parse_args(): parser.add_argument('--batch_size', default=32, type=int) parser.add_argument('--hidden_dim', default=1024, type=int) # eval - parser.add_argument('--eval_freq', default=1000, type=int) + parser.add_argument('--eval_freq', default=10000, type=int) parser.add_argument('--num_eval_episodes', default=10, type=int) # critic parser.add_argument('--critic_lr', default=1e-3, type=float) parser.add_argument('--critic_beta', default=0.9, type=float) - parser.add_argument('--critic_tau', default=0.01, type=float) # try 0.05 or 0.1 - parser.add_argument('--critic_target_update_freq', default=2, type=int) # try to change it to 1 and retain 0.01 above + parser.add_argument('--critic_tau', default=0.01, type=float) # try 0.05 or 0.1 + parser.add_argument('--critic_target_update_freq', default=2, + type=int) # try to change it to 1 and retain 0.01 above # actor parser.add_argument('--actor_lr', default=1e-3, type=float) parser.add_argument('--actor_beta', default=0.9, type=float) @@ -74,6 +79,7 @@ def parse_args(): parser.add_argument('--detach_encoder', default=False, action='store_true') parser.add_argument('--log_interval', default=100, type=int) + parser.add_argument('--custom_name', default='', type=str) args = parser.parse_args() return args @@ -92,7 +98,7 @@ def run_eval_loop(sample_stochastically=True): while not done: # center crop image if args.encoder_type == 'pixel': - obs = utils.center_crop_image(obs,args.image_size) + obs = utils.center_crop_image(obs, args.image_size) with utils.eval_mode(agent): if sample_stochastically: action = agent.sample_action(obs) @@ -105,8 +111,8 @@ def run_eval_loop(sample_stochastically=True): video.save('%d.mp4' % step) L.log('eval/' + prefix + 'episode_reward', episode_reward, step) all_ep_rewards.append(episode_reward) - - L.log('eval/' + prefix + 'eval_time', time.time()-start_time , step) + + L.log('eval/' + prefix + 'eval_time', time.time() - start_time, step) mean_ep_reward = np.mean(all_ep_rewards) best_ep_reward = np.max(all_ep_rewards) L.log('eval/' + prefix + 'mean_episode_reward', mean_ep_reward, step) @@ -145,15 +151,15 @@ def make_agent(obs_shape, action_shape, args, device): log_interval=args.log_interval, detach_encoder=args.detach_encoder, curl_latent_dim=args.curl_latent_dim - ) else: assert 'agent is not supported: %s' % args.agent + def main(): args = parse_args() - if args.seed == -1: - args.__dict__["seed"] = np.random.randint(1,1000000) + if args.seed == -1: + args.__dict__["seed"] = np.random.randint(1, 1000000) utils.set_seed_everywhere(args.seed) env = dmc2gym.make( domain_name=args.domain_name, @@ -163,29 +169,38 @@ def main(): from_pixels=(args.encoder_type == 'pixel'), 
height=args.pre_transform_image_size, width=args.pre_transform_image_size, - frame_skip=args.action_repeat + frame_skip=args.action_repeat, + camera_ids=args.camera_ids, ) - env.seed(args.seed) # stack several consecutive frames together if args.encoder_type == 'pixel': env = utils.FrameStack(env, k=args.frame_stack) - + # make directory - ts = time.gmtime() - ts = time.strftime("%m-%d", ts) + ts = time.gmtime() + ts = time.strftime("%m-%d", ts) env_name = args.domain_name + '-' + args.task_name - exp_name = env_name + '-' + ts + '-im' + str(args.image_size) +'-b' \ - + str(args.batch_size) + '-s' + str(args.seed) + '-' + args.encoder_type - args.work_dir = args.work_dir + '/' + exp_name + exp_strs = [env_name, ts, + 'im{}'.format(args.image_size), + 'b{}'.format(args.batch_size), + 's{}'.format(args.seed), + 'c{}'.format('_'.join([str(v) for v in args.camera_ids])), + args.encoder_type] + if args.custom_name: + exp_strs.append(args.custom_name) + exp_name = '-'.join(exp_strs) + logger.info(exp_name) + + args.work_dir = args.work_dir + '/' + exp_name utils.make_dir(args.work_dir) video_dir = utils.make_dir(os.path.join(args.work_dir, 'video')) model_dir = utils.make_dir(os.path.join(args.work_dir, 'model')) buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer')) - video = VideoRecorder(video_dir if args.save_video else None) + video = VideoRecorder(video_dir if args.save_video else None, camera_ids=args.camera_ids) with open(os.path.join(args.work_dir, 'args.json'), 'w') as f: json.dump(vars(args), f, sort_keys=True, indent=4) @@ -195,8 +210,9 @@ def main(): action_shape = env.action_space.shape if args.encoder_type == 'pixel': - obs_shape = (3*args.frame_stack, args.image_size, args.image_size) - pre_aug_obs_shape = (3*args.frame_stack,args.pre_transform_image_size,args.pre_transform_image_size) + obs_shape = (env.num_camera, 3 * args.frame_stack, args.image_size, args.image_size) + pre_aug_obs_shape = (env.num_camera, 3 * args.frame_stack, args.pre_transform_image_size, + args.pre_transform_image_size) else: obs_shape = env.observation_space.shape pre_aug_obs_shape = obs_shape @@ -227,7 +243,7 @@ def main(): if step % args.eval_freq == 0: L.log('eval/episode', episode, step) - evaluate(env, agent, video, args.num_eval_episodes, L, step,args) + evaluate(env, agent, video, args.num_eval_episodes, L, step, args) if args.save_model: agent.save_curl(model_dir, step) if args.save_buffer: @@ -243,6 +259,7 @@ def main(): L.log('train/episode_reward', episode_reward, step) obs = env.reset() + # logger.info(obs.shape) done = False episode_reward = 0 episode_step = 0 @@ -255,11 +272,12 @@ def main(): action = env.action_space.sample() else: with utils.eval_mode(agent): + # logger.info(obs.shape) action = agent.sample_action(obs) # run training update if step >= args.init_steps: - num_updates = 1 + num_updates = 1 for _ in range(num_updates): agent.update(replay_buffer, L, step) diff --git a/utils.py b/utils.py index 6603caf..4ddfd7c 100644 --- a/utils.py +++ b/utils.py @@ -8,6 +8,49 @@ from torch.utils.data import Dataset, DataLoader import time from skimage.util.shape import view_as_windows +from logging import Formatter, DEBUG, getLogger, StreamHandler + +logger = None + + +class MyFormatter(Formatter): + width = 50 + + def format(self, record): + width = 50 + datefmt = '%H:%M:%S' + cpath = '%s:%s:%s' % (record.module, record.funcName, record.lineno) + cpath = cpath[-width:].ljust(width) + record.message = record.getMessage() + s = "[%s - %s] %s" % (self.formatTime(record, datefmt), 
cpath, record.getMessage()) + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if s[-1:] != "\n": + s = s + "\n" + s = s + record.exc_text + return s + + +def get_logger(name): + global logger + if logger is None: + LEVEL = DEBUG + logger = getLogger(name) + logger.setLevel(LEVEL) + ch = StreamHandler() + ch.setLevel(LEVEL) + formatter = MyFormatter() + ch.setFormatter(formatter) + logger.addHandler(ch) + return logger + + +logger = get_logger(__name__) + class eval_mode(object): def __init__(self, *models): @@ -57,10 +100,10 @@ def make_dir(dir_path): def preprocess_obs(obs, bits=5): """Preprocessing image, see https://arxiv.org/abs/1807.03039.""" - bins = 2**bits + bins = 2 ** bits assert obs.dtype == torch.float32 if bits < 8: - obs = torch.floor(obs / 2**(8 - bits)) + obs = torch.floor(obs / 2 ** (8 - bits)) obs = obs / bins obs = obs + torch.rand_like(obs) / bins obs = obs - 0.5 @@ -69,7 +112,8 @@ def preprocess_obs(obs, bits=5): class ReplayBuffer(Dataset): """Buffer to store environment transitions.""" - def __init__(self, obs_shape, action_shape, capacity, batch_size, device,image_size=84,transform=None): + + def __init__(self, obs_shape, action_shape, capacity, batch_size, device, image_size=84, transform=None): self.capacity = capacity self.batch_size = batch_size self.device = device @@ -77,7 +121,7 @@ def __init__(self, obs_shape, action_shape, capacity, batch_size, device,image_s self.transform = transform # the proprioceptive obs is stored as float32, pixels obs as uint8 obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8 - + self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) self.actions = np.empty((capacity, *action_shape), dtype=np.float32) @@ -88,11 +132,8 @@ def __init__(self, obs_shape, action_shape, capacity, batch_size, device,image_s self.last_save = 0 self.full = False - - - def add(self, obs, action, reward, next_obs, done): - + np.copyto(self.obses[self.idx], obs) np.copyto(self.actions[self.idx], action) np.copyto(self.rewards[self.idx], reward) @@ -103,11 +144,11 @@ def add(self, obs, action, reward, next_obs, done): self.full = self.full or self.idx == 0 def sample_proprio(self): - + idxs = np.random.randint( 0, self.capacity if self.full else self.idx, size=self.batch_size ) - + obses = self.obses[idxs] next_obses = self.next_obses[idxs] @@ -126,15 +167,16 @@ def sample_cpc(self): idxs = np.random.randint( 0, self.capacity if self.full else self.idx, size=self.batch_size ) - + obses = self.obses[idxs] next_obses = self.next_obses[idxs] pos = obses.copy() + # logger.info((obses.shape, next_obses.shape)) obses = random_crop(obses, self.image_size) next_obses = random_crop(next_obses, self.image_size) pos = random_crop(pos, self.image_size) - + obses = torch.as_tensor(obses, device=self.device).float() next_obses = torch.as_tensor( next_obses, device=self.device @@ -196,7 +238,8 @@ def __getitem__(self, idx): return obs, action, reward, next_obs, not_done def __len__(self): - return self.capacity + return self.capacity + class FrameStack(gym.Wrapper): def __init__(self, env, k): @@ -225,7 +268,8 @@ def step(self, action): def _get_obs(self): assert len(self._frames) == self._k - return np.concatenate(list(self._frames), axis=0) + obs = np.concatenate(list(self._frames), axis=1) + return obs def 
random_crop(imgs, output_size): @@ -237,28 +281,40 @@ def random_crop(imgs, output_size): imgs, batch images with shape (B,C,H,W) """ # batch size - n = imgs.shape[0] - img_size = imgs.shape[-1] + # logger.info(imgs.shape) + assert imgs.ndim == 5 + B, n, C, H, W = imgs.shape + assert H == W + img_size = W crop_max = img_size - output_size - imgs = np.transpose(imgs, (0, 2, 3, 1)) - w1 = np.random.randint(0, crop_max, n) - h1 = np.random.randint(0, crop_max, n) + # logger.info(imgs.shape) + imgs = np.transpose(imgs, (0, 1, 3, 4, 2)) # (B,n,H,W,C) + # logger.info(imgs.shape) + imgs = imgs.reshape((B * n, H, W, C)) + # logger.info(imgs.shape) + w1 = np.random.randint(0, crop_max, B * n) + h1 = np.random.randint(0, crop_max, B * n) # creates all sliding windows combinations of size (output_size) windows = view_as_windows( - imgs, (1, output_size, output_size, 1))[..., 0,:,:, 0] + imgs, (1, output_size, output_size, 1))[..., 0, :, :, 0] # selects a random window for each batch element - cropped_imgs = windows[np.arange(n), w1, h1] + # logger.info(windows.shape) + cropped_imgs = windows[np.arange(B * n), w1, h1] + # logger.info(cropped_imgs.shape) + cropped_imgs = cropped_imgs.reshape((B, n, C, *cropped_imgs.shape[-2:])) + # logger.info(cropped_imgs.shape) return cropped_imgs + def center_crop_image(image, output_size): - h, w = image.shape[1:] + assert image.ndim == 4 + # logger.info(image.shape) + n, c, h, w = image.shape new_h, new_w = output_size, output_size - top = (h - new_h)//2 - left = (w - new_w)//2 + top = (h - new_h) // 2 + left = (w - new_w) // 2 - image = image[:, top:top + new_h, left:left + new_w] + image = image[..., top:top + new_h, left:left + new_w] + # logger.info(image.shape) return image - - - diff --git a/video.py b/video.py index 0e319f8..4335e50 100644 --- a/video.py +++ b/video.py @@ -4,11 +4,11 @@ class VideoRecorder(object): - def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30): + def __init__(self, dir_name, height=256, width=256, camera_ids=(0, ), fps=30): self.dir_name = dir_name self.height = height self.width = width - self.camera_id = camera_id + self.camera_ids = camera_ids self.fps = fps self.frames = [] @@ -18,18 +18,21 @@ def init(self, enabled=True): def record(self, env): if self.enabled: - try: - frame = env.render( - mode='rgb_array', - height=self.height, - width=self.width, - camera_id=self.camera_id - ) - except: - frame = env.render( - mode='rgb_array', - ) - + views = [] + for camera_id in self.camera_ids: + try: + frame = env.render( + mode='rgb_array', + height=self.height, + width=self.width, + camera_id=camera_id + ) + except: + frame = env.render( + mode='rgb_array', + ) + views.append(frame) + frame = np.concatenate(views, axis=1) self.frames.append(frame) def save(self, file_name): diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..c9d5dd8 --- /dev/null +++ b/visualize.py @@ -0,0 +1,141 @@ +import re +from collections import namedtuple, defaultdict +from itertools import product +from pathlib import Path +from typing import List, Dict, Any + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + + +def mkdir_if_not_exists(path: Path) -> Path: + if not path.exists(): + path.mkdir(parents=True) + return path + + +def filter_single_item(line: str) -> List[str]: + return re.sub(r'(:|,|\"|\{|\})', '', line).split() + + +class DirectoryManager: + def __init__(self, root_dir: Path = Path.home() / '.curl'): + self.root_dir = root_dir + + @property + def figure_root_dir(self): + 
return mkdir_if_not_exists(self.root_dir / 'figures') + + def figure_dir(self, exp_name: str): + return mkdir_if_not_exists(self.figure_root_dir / exp_name) + + @property + def log_root_dir(self): + return mkdir_if_not_exists(self.root_dir / 'logs') + + def log_dir(self, exp_name: str): + return mkdir_if_not_exists(self.log_root_dir / exp_name) + + +class ModeInformation: + def __init__(self, mode: str, query: str, key_list: List[str]): + self.mode = mode + self.query = query + self.key_list = key_list + + @property + def log_filename(self): + return '{}.log'.format(self.mode) + + @property + def item_cls(self): + return namedtuple('{}item'.format(self.mode), self.key_list) + + +dm = DirectoryManager() + + +class Visualizer(ModeInformation): + def __init__(self, exp_name: str, mode: str, query: str, key_list: List[str]): + ModeInformation.__init__(self, mode, query, key_list) + self.exp_name = exp_name + + def visualize(self): + log_files = self.collect_log_files() + data_dict = {f: self.read_log_file(f) for f in log_files} + self.visualize_data(data_dict) + + @property + def log_dir(self) -> Path: + return dm.log_dir(self.exp_name) + + def collect_log_files(self): + return self._collect_log_files(self.log_dir) + + def _collect_log_files(self, p): + paths = [] + candidates = p.glob('*') + for p in candidates: + if p.is_dir(): + paths += self._collect_log_files(p) + elif p.name == self.log_filename: + paths.append(p) + return paths + + def read_log_file(self, in_path: Path): + with open(str(in_path), 'r') as file: + lines = file.read().splitlines() + wgl = list(map(filter_single_item, lines)) + items = [] + for words in wgl: + items.append({key: float(value) for key, value in zip(words[::2], words[1::2])}) + items = list(filter(lambda x: all(k in x for k in self.key_list), items)) + items = list(filter(lambda x: x['step'] <= 100000, items)) + items = list(map(lambda x: self.item_cls(*[x[key] for key in self.key_list]), items)) + return items + + def visualize_data(self, data_dict: Dict[Path, List[Any]]): + columns = ['model', 'index'] + self.key_list + items = [] + index_dict = defaultdict(int) + for key, value in data_dict.items(): + model = key.parent.parent.stem + index = index_dict[model] + for item in value: + items.append((model, index, *item)) + index_dict[model] += 1 + + data_frame = pd.DataFrame(items, columns=columns) + + xticks = list(range(2, 11, 2)) + xvalues, xnames = zip(*[(i * 10000, '{}k'.format(10 * i)) for i in xticks]) + + sns.set(style="ticks", color_codes=True) + sns.relplot(x='step', y=self.query, kind='line', hue='model', ci='sd', data=data_frame) + plt.ylabel(self.query.replace('_', ' ')) + plt.xlabel('steps') + plt.xticks(list(xvalues), list(xnames)) + plt.savefig(str(dm.figure_root_dir / '{}_{}.png'.format(self.exp_name, self.mode))) + plt.close() + + +def fetch_visualizer(exp_name: str, mode: str) -> Visualizer: + if mode == 'train': + query = 'episode_reward' + key_list = ['episode_reward', 'episode', 'duration', 'step'] + elif mode == 'eval': + query = 'mean_episode_reward' + key_list = ['episode_reward', 'episode', 'mean_episode_reward', 'best_episode_reward', 'step'] + else: + raise TypeError('invalid mode: {}'.format(mode)) + return Visualizer(exp_name, mode, query, key_list) + + +if __name__ == '__main__': + exp_names = ['cartpole-swingup', 'walker-walk'] + modes = ['train', 'eval'] + + for exp_name, mode in product(exp_names, modes): + visualizer = fetch_visualizer(exp_name, mode) + visualizer.visualize()
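### note (editor): multi-view fusion sketch

The substantive change in this diff is that a pixel observation becomes a stack of n camera views shaped (n, C, H, W) instead of a single (C, H, W) image: dmc2gym renders one frame per entry in camera_ids, FrameStack concatenates the k stacked frames along each view's channel axis (axis=1), and the modified PixelEncoder folds the view axis into the batch, runs the shared conv stack once per view, max-pools the flattened conv features across views, and only then applies the fc head. A minimal, self-contained sketch of that fusion step follows; the class name MultiViewPool, the one-layer toy conv stack, and the 84x84 input size are illustrative assumptions, not code from the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiViewPool(nn.Module):
    """Shared per-view CNN whose features are max-pooled over the view axis."""

    def __init__(self, in_channels, feature_dim):
        super().__init__()
        # toy stand-in for PixelEncoder's conv stack
        self.conv = nn.Conv2d(in_channels, 32, 3, stride=2)
        self.fc = nn.Linear(32 * 41 * 41, feature_dim)  # 84x84 input -> 41x41 after the conv

    def forward(self, obs):
        assert obs.ndim == 5                                  # (B, n, C, H, W)
        B, n, C, H, W = obs.shape
        h = torch.relu(self.conv(obs.view(B * n, C, H, W)))   # fold views into the batch
        h = h.view(B, n, -1).permute(0, 2, 1)                 # (B, conv_features, n)
        # max over the view axis; squeeze(-1) keeps the batch dim even when B == 1,
        # whereas the diff's bare .squeeze() yields a 1-D tensor for a single
        # observation and works only because the downstream layers accept unbatched input
        h = F.max_pool1d(h, kernel_size=n).squeeze(-1)        # (B, conv_features)
        return self.fc(h)                                     # (B, feature_dim)


obs = torch.rand(2, 3, 9, 84, 84)  # batch of 2; 3 cameras; 3 stacked RGB frames per view
print(MultiViewPool(in_channels=9, feature_dim=50)(obs).shape)  # torch.Size([2, 50])

With this fusion the agent-side shapes in train.py follow directly: obs_shape = (env.num_camera, 3 * frame_stack, image_size, image_size). A two-camera run with the new flag would look something like the following (every flag except --camera_ids is unchanged from upstream CURL):

python train.py --domain_name walker --task_name walk --encoder_type pixel --camera_ids 0 1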