Dual camera stack #8

Open · wants to merge 14 commits into master
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
log/
tmp/
notebooks/
__pycache__/
2 changes: 2 additions & 0 deletions conda_env.yml
@@ -9,6 +9,8 @@ dependencies:
- absl-py
- pyparsing
- pillow=6.1
- pandas
- pip
- pip:
- termcolor
- git+git://github.com/deepmind/dm_control.git
108 changes: 58 additions & 50 deletions curl_sac.py
@@ -8,6 +8,8 @@
import utils
from encoder import make_encoder

logger = utils.get_logger(__name__)

LOG_FREQ = 10000


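This hunk adds `logger = utils.get_logger(__name__)`, but the matching change to utils.py is not shown in this diff. A minimal sketch of what such a helper could look like, assuming it simply wraps Python's standard logging module (the name, signature, and format string here are assumptions, not necessarily what this PR implements):

```python
# Hypothetical sketch of utils.get_logger; the real helper in this PR may differ.
import logging

def get_logger(name, level=logging.INFO):
    """Return a named logger with a single stream handler attached."""
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid stacking duplicate handlers on repeated calls
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')
        )
        logger.addHandler(handler)
        logger.setLevel(level)
    return logger
```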
@@ -46,9 +48,10 @@ def weight_init(m):

class Actor(nn.Module):
"""MLP actor network."""

def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters
self, obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters
):
super().__init__()

@@ -70,7 +73,7 @@ def __init__(
self.apply(weight_init)

def forward(
self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
):
obs = self.encoder(obs, detach=detach_encoder)

@@ -79,7 +82,7 @@ def forward(
# constrain log_std inside [log_std_min, log_std_max]
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
self.log_std_max - self.log_std_min
) * (log_std + 1)

self.outputs['mu'] = mu
@@ -116,6 +119,7 @@ def log(self, L, step, log_freq=LOG_FREQ):

class QFunction(nn.Module):
"""MLP for q-function."""

def __init__(self, obs_dim, action_dim, hidden_dim):
super().__init__()

@@ -134,13 +138,13 @@ def forward(self, obs, action):

class Critic(nn.Module):
"""Critic network, employes two q-functions."""

def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, num_layers, num_filters
self, obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, num_layers, num_filters
):
super().__init__()


self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim, num_layers,
num_filters, output_logits=True
@@ -193,7 +197,7 @@ def __init__(self, obs_shape, z_dim, batch_size, critic, critic_target, output_t

self.encoder = critic.encoder

self.encoder_target = critic_target.encoder
self.encoder_target = critic_target.encoder

self.W = nn.Parameter(torch.rand(z_dim, z_dim))
self.output_type = output_type
@@ -227,37 +231,39 @@ def compute_logits(self, z_a, z_pos):
logits = logits - torch.max(logits, 1)[0][:, None]
return logits


class CurlSacAgent(object):
"""CURL representation learning with SAC."""

def __init__(
self,
obs_shape,
action_shape,
device,
hidden_dim=256,
discount=0.99,
init_temperature=0.01,
alpha_lr=1e-3,
alpha_beta=0.9,
actor_lr=1e-3,
actor_beta=0.9,
actor_log_std_min=-10,
actor_log_std_max=2,
actor_update_freq=2,
critic_lr=1e-3,
critic_beta=0.9,
critic_tau=0.005,
critic_target_update_freq=2,
encoder_type='pixel',
encoder_feature_dim=50,
encoder_lr=1e-3,
encoder_tau=0.005,
num_layers=4,
num_filters=32,
cpc_update_freq=1,
log_interval=100,
detach_encoder=False,
curl_latent_dim=128
self,
obs_shape,
action_shape,
device,
hidden_dim=256,
discount=0.99,
init_temperature=0.01,
alpha_lr=1e-3,
alpha_beta=0.9,
actor_lr=1e-3,
actor_beta=0.9,
actor_log_std_min=-10,
actor_log_std_max=2,
actor_update_freq=2,
critic_lr=1e-3,
critic_beta=0.9,
critic_tau=0.005,
critic_target_update_freq=2,
encoder_type='pixel',
encoder_feature_dim=50,
encoder_lr=1e-3,
encoder_tau=0.005,
num_layers=4,
num_filters=32,
cpc_update_freq=1,
log_interval=100,
detach_encoder=False,
curl_latent_dim=128
):
self.device = device
self.discount = discount
@@ -297,7 +303,7 @@ def __init__(
self.log_alpha.requires_grad = True
# set target entropy to -|A|
self.target_entropy = -np.prod(action_shape)

# optimizers
self.actor_optimizer = torch.optim.Adam(
self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
@@ -314,7 +320,8 @@ def __init__(
if self.encoder_type == 'pixel':
# create CURL encoder (the 128 batch size is probably unnecessary)
self.CURL = CURL(obs_shape, encoder_feature_dim,
self.curl_latent_dim, self.critic,self.critic_target, output_type='continuous').to(self.device)
self.curl_latent_dim, self.critic, self.critic_target, output_type='continuous').to(
self.device)

# optimizer for critic encoder for reconstruction loss
self.encoder_optimizer = torch.optim.Adam(
@@ -341,9 +348,11 @@ def alpha(self):
return self.log_alpha.exp()

def select_action(self, obs):
# logger.info(obs.shape)
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
# logger.info(obs.shape)
mu, _, _, _ = self.actor(
obs, compute_pi=False, compute_log_pi=False
)
@@ -352,7 +361,7 @@ def select_action(self, obs):
def sample_action(self, obs):
if obs.shape[-1] != self.image_size:
obs = utils.center_crop_image(obs, self.image_size)

with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
@@ -362,6 +371,7 @@ def sample_action(self, obs):
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
with torch.no_grad():
_, policy_action, log_pi, _ = self.actor(next_obs)
# logger.info((obs.shape, next_obs.shape))
target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
target_V = torch.min(target_Q1,
target_Q2) - self.alpha.detach() * log_pi
@@ -375,7 +385,6 @@ def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
if step % self.log_interval == 0:
L.log('train_critic/loss', critic_loss, step)


# Optimize the critic
self.critic_optimizer.zero_grad()
critic_loss.backward()
@@ -395,8 +404,8 @@ def update_actor_and_alpha(self, obs, L, step):
L.log('train_actor/loss', actor_loss, step)
L.log('train_actor/target_entropy', self.target_entropy, step)
entropy = 0.5 * log_std.shape[1] * \
(1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)
if step % self.log_interval == 0:
(1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)
if step % self.log_interval == 0:
L.log('train_actor/entropy', entropy.mean(), step)

# optimize the actor
@@ -416,14 +425,14 @@ def update_actor_and_alpha(self, obs, L, step):
self.log_alpha_optimizer.step()

def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step):

z_a = self.CURL.encode(obs_anchor)
z_pos = self.CURL.encode(obs_pos, ema=True)

logits = self.CURL.compute_logits(z_a, z_pos)
labels = torch.arange(logits.shape[0]).long().to(self.device)
loss = self.cross_entropy_loss(logits, labels)

self.encoder_optimizer.zero_grad()
self.cpc_optimizer.zero_grad()
loss.backward()
@@ -433,16 +442,16 @@ def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step):
if step % self.log_interval == 0:
L.log('train/curl_loss', loss, step)


def update(self, replay_buffer, L, step):
if self.encoder_type == 'pixel':
obs, action, reward, next_obs, not_done, cpc_kwargs = replay_buffer.sample_cpc()
else:
obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()

if step % self.log_interval == 0:
L.log('train/batch_reward', reward.mean(), step)

# logger.info((obs.shape, next_obs.shape))
self.update_critic(obs, action, reward, next_obs, not_done, L, step)

if step % self.actor_update_freq == 0:
@@ -459,10 +468,10 @@ def update(self, replay_buffer, L, step):
self.critic.encoder, self.critic_target.encoder,
self.encoder_tau
)

if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel':
obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs["obs_pos"]
self.update_cpc(obs_anchor, obs_pos,cpc_kwargs, L, step)
self.update_cpc(obs_anchor, obs_pos, cpc_kwargs, L, step)

def save(self, model_dir, step):
torch.save(
@@ -484,4 +493,3 @@ def load(self, model_dir, step):
self.critic.load_state_dict(
torch.load('%s/critic_%s.pt' % (model_dir, step))
)

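For context on the contrastive update touched throughout this file: `update_cpc` encodes an anchor and an augmented positive, scores them with the bilinear `compute_logits` shown above, and trains with a cross-entropy loss whose targets are the diagonal (each anchor's positive is the sample at the same batch index). A self-contained sketch of that computation, using random tensors as stand-ins for the encoder outputs:

```python
# Sketch of the CURL contrastive (InfoNCE-style) objective used in update_cpc.
import torch
import torch.nn as nn

batch_size, z_dim = 32, 50
z_a = torch.randn(batch_size, z_dim)    # stand-in for CURL.encode(obs_anchor)
z_pos = torch.randn(batch_size, z_dim)  # stand-in for CURL.encode(obs_pos, ema=True)
W = nn.Parameter(torch.rand(z_dim, z_dim))

# Bilinear similarity scores: logits[i, j] = z_a[i] . W z_pos[j]
Wz = torch.matmul(W, z_pos.T)           # (z_dim, B)
logits = torch.matmul(z_a, Wz)          # (B, B)
logits = logits - torch.max(logits, 1)[0][:, None]  # subtract row max for numerical stability

# Positives sit on the diagonal, so the labels are simply 0..B-1.
labels = torch.arange(logits.shape[0]).long()
loss = nn.CrossEntropyLoss()(logits, labels)
```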
52 changes: 52 additions & 0 deletions dmc2gym/__init__.py
@@ -0,0 +1,52 @@
import gym
from gym.envs.registration import register


def make(
domain_name,
task_name,
seed=1,
visualize_reward=True,
from_pixels=False,
height=84,
width=84,
camera_ids=(0, ),
frame_skip=1,
episode_length=1000,
environment_kwargs=None,
time_limit=None,
channels_first=True
):
env_id = 'dmc_%s_%s-v1' % (domain_name, task_name)

if from_pixels:
assert not visualize_reward, 'cannot use visualize reward when learning from pixels'

# shorten episode length
max_episode_steps = (episode_length + frame_skip - 1) // frame_skip

if env_id not in gym.envs.registry.env_specs:
task_kwargs = {}
if seed is not None:
task_kwargs['random'] = seed
if time_limit is not None:
task_kwargs['time_limit'] = time_limit
register(
id=env_id,
entry_point='dmc2gym.wrappers:DMCWrapper',
kwargs=dict(
domain_name=domain_name,
task_name=task_name,
task_kwargs=task_kwargs,
environment_kwargs=environment_kwargs,
visualize_reward=visualize_reward,
from_pixels=from_pixels,
height=height,
width=width,
camera_ids=camera_ids,
frame_skip=frame_skip,
channels_first=channels_first,
),
max_episode_steps=max_episode_steps,
)
return gym.make(env_id)
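The new `camera_ids` argument is the entry point for the dual-camera stack this PR adds. The wrapper side (dmc2gym/wrappers.py) is not shown in this diff, so the observation layout below is an assumption; presumably frames from each camera are stacked along the channel axis when `from_pixels=True`. A hypothetical usage sketch (domain, task, and camera ids are only examples):

```python
# Hypothetical usage of the new make() signature; names and camera ids are examples.
import dmc2gym

env = dmc2gym.make(
    domain_name='cartpole',
    task_name='swingup',
    seed=1,
    visualize_reward=False,   # must be False when from_pixels=True (see the assert above)
    from_pixels=True,
    height=84,
    width=84,
    camera_ids=(0, 1),        # render two cameras instead of the default single camera 0
    frame_skip=8,
)

obs = env.reset()
# With channels_first=True and two RGB cameras, obs.shape is presumably (2 * 3, 84, 84).
print(obs.shape)
```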