diff --git a/scripts/configs/hprior_awac/adroit.yaml b/scripts/configs/hprior_awac/adroit.yaml new file mode 100644 index 0000000..4b68802 --- /dev/null +++ b/scripts/configs/hprior_awac/adroit.yaml @@ -0,0 +1,105 @@ +algorithm: + class: Hindsight_PRIOR_AWAC + beta: 0.3333 + max_exp_clip: 100.0 + reward_reg: 0.0 + max_seq_len: 100 + world_steps: 10000 + prior_coef: 1000 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: pen-cloned-v1 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + world: + embed_dim: 256 + num_layers: 3 + num_heads: 1 + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256, 256] + reward_act: identity + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256, 256] + +rm_dataset: + - class: D4RLOfflineDataset + env: pen-cloned-v1 + batch_size: 64 + mode: trajectory + segment_length: 100 + padding_mode: none + - class: IPLComparisonOfflineDataset + env: pen-cloned-v1 + batch_size: 8 + mode: human +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: pen-cloned-v1 + batch_size: 256 + mode: transition + reward_normalize: true +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 60000 + rl_steps: 500000 + log_freq: 500 + profile_freq: 500 + eval_freq: 5000 + +rm_eval: + function: eval_world_model_and_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: pen-cloned-v1 + batch_size: 32 + mode: human + eval: false +rl_eval: + function: eval_offline + num_ep: 10 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 500000 + +processor: null diff --git a/scripts/configs/hprior_awac/gym.yaml b/scripts/configs/hprior_awac/gym.yaml new file mode 100644 index 0000000..f6dbf1c --- /dev/null +++ b/scripts/configs/hprior_awac/gym.yaml @@ -0,0 +1,106 @@ +algorithm: + class: Hindsight_PRIOR_AWAC + beta: 0.3333 + max_exp_clip: 100.0 + reward_reg: 0.0 + max_seq_len: 100 + world_steps: 10000 + prior_coef: 1000 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: hopper-medium-replay-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + world: + embed_dim: 256 + num_layers: 3 + num_heads: 1 + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256] + reward_act: sigmoid + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256] + +rm_dataset: + - class: D4RLOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 64 + mode: trajectory + segment_length: 100 + padding_mode: none + - class: IPLComparisonOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 8 + segment_length: null + mode: human +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 256 + mode: transition + reward_normalize: true +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + 
env_freq: null + rm_label: true + rm_steps: 60000 + rl_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 5000 + +rm_eval: + function: eval_world_model_and_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 32 + mode: human + eval: false +rl_eval: + function: eval_offline + num_ep: 10 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 1000000 + +processor: null diff --git a/scripts/configs/hprior_awac/metaworld.yaml b/scripts/configs/hprior_awac/metaworld.yaml new file mode 100644 index 0000000..3fef49f --- /dev/null +++ b/scripts/configs/hprior_awac/metaworld.yaml @@ -0,0 +1,104 @@ +algorithm: + class: Hindsight_PRIOR_AWAC + beta: 0.3333 + max_exp_clip: 100.0 + reward_reg: 0.0 + max_seq_len: 25 + world_steps: 10000 + prior_coef: 1000 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: button-press-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + world: + embed_dim: 256 + num_layers: 3 + num_heads: 1 + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256, 256] + reward_act: identity + ortho_init: true + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256, 256] + +rm_dataset: + - class: MetaworldOfflineDataset + env: button-press-v2 + batch_size: 16 + capacity: 5000 + - class: MetaworldComparisonDataset + env: button-press-v2 + batch_size: 16 + segment_length: null + capacity: 500 +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: MetaworldOfflineDataset + env: button-press-v2 + batch_size: 16 + capacity: 5000 +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 60000 + rl_steps: 500000 + log_freq: 500 + profile_freq: 500 + eval_freq: 10000 + +rm_eval: + function: eval_world_model_and_reward_model + eval_dataset_kwargs: + class: MetaworldComparisonDataset + env: button-press-v2 + batch_size: 32 + segment_length: null + capacity: 500 +rl_eval: + function: eval_offline + num_ep: 20 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 500000 + +processor: null diff --git a/scripts/configs/hprior_iql/adroit.yaml b/scripts/configs/hprior_iql/adroit.yaml new file mode 100644 index 0000000..dfdaede --- /dev/null +++ b/scripts/configs/hprior_iql/adroit.yaml @@ -0,0 +1,110 @@ +algorithm: + class: Hindsight_PRIOR_IQL + beta: 0.3333 + expectile: 0.75 + max_exp_clip: 100.0 + reward_reg: 0.0 + max_seq_len: 100 + world_steps: 10000 + prior_coef: 1000 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: pen-cloned-v1 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + world: + embed_dim: 256 + num_layers: 3 + num_heads: 1 + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256, 256] + reward_act: identity + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256, 256] + value: + class: Critic + 
ensemble_size: 1 + hidden_dims: [256, 256, 256] + +rm_dataset: + - class: D4RLOfflineDataset + env: pen-cloned-v1 + batch_size: 64 + mode: trajectory + segment_length: 100 + padding_mode: none + - class: IPLComparisonOfflineDataset + env: pen-cloned-v1 + batch_size: 8 + mode: human +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: pen-cloned-v1 + batch_size: 256 + mode: transition + reward_normalize: true +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 60000 + rl_steps: 500000 + log_freq: 500 + profile_freq: 500 + eval_freq: 5000 + +rm_eval: + function: eval_world_model_and_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: pen-cloned-v1 + batch_size: 32 + mode: human + eval: false +rl_eval: + function: eval_offline + num_ep: 10 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 500000 + +processor: null diff --git a/scripts/configs/hprior_iql/gym.yaml b/scripts/configs/hprior_iql/gym.yaml new file mode 100644 index 0000000..27f2beb --- /dev/null +++ b/scripts/configs/hprior_iql/gym.yaml @@ -0,0 +1,111 @@ +algorithm: + class: Hindsight_PRIOR_IQL + beta: 0.3333 + expectile: 0.7 + max_exp_clip: 100.0 + reward_reg: 0.0 + max_seq_len: 100 + world_steps: 10000 + prior_coef: 1000 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: hopper-medium-replay-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + world: + embed_dim: 256 + num_layers: 3 + num_heads: 1 + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256] + reward_act: sigmoid + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256] + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256] + +rm_dataset: + - class: D4RLOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 64 + mode: trajectory + segment_length: 100 + padding_mode: none + - class: IPLComparisonOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 8 + segment_length: null + mode: human +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 256 + mode: transition + reward_normalize: true +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 60000 + rl_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 5000 + +rm_eval: + function: eval_world_model_and_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: hopper-medium-replay-v2 + batch_size: 32 + mode: human + eval: false +rl_eval: + function: eval_offline + num_ep: 10 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 1000000 + +processor: null diff --git a/scripts/configs/hprior_iql/metaworld.yaml b/scripts/configs/hprior_iql/metaworld.yaml new file mode 100644 index 0000000..1fb00dd --- /dev/null +++ b/scripts/configs/hprior_iql/metaworld.yaml @@ -0,0 +1,109 @@ +algorithm: + class: Hindsight_PRIOR_IQL + beta: 0.3333 + expectile: 0.75 + max_exp_clip: 100.0 + reward_reg: 0.0 + max_seq_len: 25 + world_steps: 10000 + prior_coef: 1000 + rm_label: true + +checkpoint: null +seed: 0 
+name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: button-press-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + world: + embed_dim: 256 + num_layers: 3 + num_heads: 1 + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256, 256] + reward_act: identity + ortho_init: true + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256, 256] + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256, 256] + +rm_dataset: + - class: MetaworldOfflineDataset + env: button-press-v2 + batch_size: 16 + capacity: 5000 + - class: MetaworldComparisonDataset + env: button-press-v2 + batch_size: 16 + segment_length: null + capacity: 500 +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: MetaworldOfflineDataset + env: button-press-v2 + batch_size: 16 + capacity: 5000 +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 60000 + rl_steps: 500000 + log_freq: 500 + profile_freq: 500 + eval_freq: 10000 + +rm_eval: + function: eval_world_model_and_reward_model + eval_dataset_kwargs: + class: MetaworldComparisonDataset + env: button-press-v2 + batch_size: 32 + segment_length: null + capacity: 500 +rl_eval: + function: eval_offline + num_ep: 20 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 500000 + +processor: null diff --git a/wiserl/algorithm/__init__.py b/wiserl/algorithm/__init__.py index bc92a28..0989bfc 100644 --- a/wiserl/algorithm/__init__.py +++ b/wiserl/algorithm/__init__.py @@ -4,6 +4,8 @@ from wiserl.algorithm.cpl_kl import CPL_KL from wiserl.algorithm.hpl.hpl import HindsightPreferenceLearning from wiserl.algorithm.hpl.hpl_awac import HindsightPreferenceLearningAWAC +from wiserl.algorithm.hprior.hprior_iql import Hindsight_PRIOR_IQL +from wiserl.algorithm.hprior.hprior_awac import Hindsight_PRIOR_AWAC from wiserl.algorithm.ipl.ipl_awac import IPL_AWAC from wiserl.algorithm.ipl.ipl_iql import IPL_IQL from wiserl.algorithm.oracle_awac import OracleAWAC diff --git a/wiserl/algorithm/hprior/hprior_awac.py b/wiserl/algorithm/hprior/hprior_awac.py new file mode 100644 index 0000000..bf2e804 --- /dev/null +++ b/wiserl/algorithm/hprior/hprior_awac.py @@ -0,0 +1,224 @@ +import itertools +import os +from operator import itemgetter +from typing import Any, Dict, Optional, Type + +import torch +import torch.nn as nn + +import wiserl.module +from wiserl.algorithm.oracle_awac import OracleAWAC +from wiserl.module.net.attention.twm import TransformerBasedWorldModel +from wiserl.utils.misc import sync_target + + +class Hindsight_PRIOR_AWAC(OracleAWAC): + def __init__( + self, + *args, + beta: float = 0.3333, + max_exp_clip: float = 100.0, + discount: float = 0.99, + tau: float = 0.005, + target_freq: int = 1, + reward_reg: float = 0.0, + max_seq_len: int = 100, + world_steps: int = 50000, + prior_coef: float = 1.0, + rm_label: bool = True, + **kwargs + ) -> None: + self.max_seq_len = max_seq_len + super().__init__( + *args, + beta=beta, + max_exp_clip=max_exp_clip, + discount=discount, + tau=tau, + target_freq=target_freq, + **kwargs + ) + self.reward_reg = reward_reg + self.prior_coef = prior_coef + self.world_steps = world_steps + self.rm_label = rm_label + assert rm_label + self.obs_dim = 
self.observation_space.shape[0] + self.action_dim = self.action_space.shape[0] + + self.world_criterion = torch.nn.L1Loss(reduction="none") + self.reward_criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + self.prior_criterion = torch.nn.MSELoss(reduction="none") + + def setup_network(self, network_kwargs): + super().setup_network(network_kwargs) + # world model + world_kwargs = network_kwargs["world"] + world = TransformerBasedWorldModel( + obs_dim=self.observation_space.shape[0], + action_dim=self.action_space.shape[0], + embed_dim=world_kwargs["embed_dim"], + num_layers=world_kwargs["num_layers"], + seq_len=self.max_seq_len, + num_heads=world_kwargs["num_heads"], + ) + self.network["world"] = world + # reward model + reward_act = { + "identity": nn.Identity(), + "sigmoid": nn.Sigmoid(), + }.get(network_kwargs["reward"].pop("reward_act")) + reward = vars(wiserl.module)[network_kwargs["reward"].pop("class")]( + input_dim=self.observation_space.shape[0]+self.action_space.shape[0], + output_dim=1, + **network_kwargs["reward"] + ) + self.network["reward"] = nn.Sequential(self.network["encoder"], reward, reward_act) + + def setup_optimizers(self, optim_kwargs): + super().setup_optimizers(optim_kwargs) + default_kwargs = optim_kwargs.get("default", {}) + for attr in ["world", "reward"]: + kwargs = default_kwargs.copy() + kwargs.update(optim_kwargs.get(attr, {})) + optim = vars(torch.optim)[kwargs.pop("class")]( + self.network.__getattr__(attr).parameters(), **kwargs + ) + self.optim[attr] = optim + + def select_action(self, batch, deterministic: bool=True): + return super().select_action(batch, deterministic) + + def select_reward(self, batch, deterministic=False): + obs, action = batch["obs"], batch["action"] + reward = self.network.reward(torch.concat([obs, action], dim=-1)) + return reward.mean(0).detach() + + def pretrain_step(self, batches, step: int, total_steps: int) -> Dict: + traj_batch, pref_batch = batches + if step < self.world_steps: + return self.update_world( + obs=traj_batch["obs"], + action=traj_batch["action"], + next_obs=traj_batch["next_obs"], + ) + else: + return self.update_reward( + obs_1=pref_batch["obs_1"], + obs_2=pref_batch["obs_2"], + action_1=pref_batch["action_1"], + action_2=pref_batch["action_2"], + label=pref_batch["label"], + ) + + def update_world(self, obs, action, next_obs) -> Dict: + # obs [F_B, F_S, obs_dim], action [F_B, F_S, action_dim] + F_B, F_S = obs.shape[:2] + timestep = torch.arange(F_S, device=self.device).unsqueeze(0).expand(F_B, -1) + pred_obs, attentions = self.network.world(obs, action, timestep) + world_loss = self.world_criterion(pred_obs, next_obs).sum(0).mean() + + self.optim["world"].zero_grad() + world_loss.backward() + self.optim["world"].step() + + metrics = { + "loss/world_loss": world_loss.item() + } + return metrics + + def update_reward(self, obs_1, obs_2, action_1, action_2, label) -> Dict: + F_B, F_S = obs_1.shape[0:2] + all_obs = torch.concat([ + obs_1.reshape(-1, self.obs_dim), + obs_2.reshape(-1, self.obs_dim) + ]) + all_action = torch.concat([ + action_1.reshape(-1, self.action_dim), + action_2.reshape(-1, self.action_dim) + ]) + self.network.reward.train() + all_reward = self.network.reward(torch.concat([all_obs, all_action], dim=-1)) + r1, r2 = torch.chunk(all_reward, 2, dim=1) + E = r1.shape[0] + r1, r2 = r1.reshape(E, F_B, F_S, 1), r2.reshape(E, F_B, F_S, 1) + logits = r2.sum(dim=2) - r1.sum(dim=2) + labels = label.float().unsqueeze(0).expand_as(logits) + reward_loss = self.reward_criterion(logits, 
labels).sum(0).mean() + reg_loss = (r1**2).sum(0).mean() + (r2**2).sum(0).mean() + # hindsight prior loss by attention weights + obs = torch.concat([obs_1, obs_2], dim=0) + action = torch.concat([action_1, action_2], dim=0) + r = torch.concat([r1, r2], dim=1) + predicted_return = r.sum(dim=2) + _, attentions = self.network.world(obs, action, None) + # get last attentions -> [2 * F_B, num_layers, 2 * F_S] + attentions = torch.stack([attn[:, -1] for attn in attentions], dim=1) + # alpha = 1/L * sum_{l=1}^{L} (attn_{s_t}^l + attn_{a_t}^l) -> [2 * F_B, F_S] + attentions = attentions.reshape(*attentions.shape[:-1], -1, 2).sum(dim=-1) + prior_importance = attentions.mean(dim=1) + r_target = prior_importance * predicted_return + r_target = r_target.unsqueeze(-1) + prior_loss = self.prior_criterion(r, r_target).mean() + with torch.no_grad(): + reward_accuracy = ((logits > 0) == torch.round(labels)).float().mean() + + self.optim["reward"].zero_grad() + (reward_loss + self.reward_reg * reg_loss + self.prior_coef * prior_loss).backward() + self.optim["reward"].step() + + metrics = { + "loss/reward_loss": reward_loss.item(), + "loss/reward_reg_loss": reg_loss.item(), + "loss/prior_loss": prior_loss.item(), + "misc/reward_acc": reward_accuracy.item(), + "misc/reward_value": all_reward.mean().item() + } + return metrics + + def train_step(self, batches, step: int, total_steps: int) -> Dict: + rl_batch = batches[0] + obs, action, next_obs, terminal = itemgetter("obs", "action", "next_obs", "terminal")(rl_batch) + terminal = terminal.float() + if self.rm_label: + reward = itemgetter("reward")(rl_batch) + else: + with torch.no_grad(): + reward = self.select_reward({"obs": obs, "action": action}, deterministic=True) + + # compute the loss for actor + actor_loss, advantage = self.actor_loss(obs, action) + self.optim["actor"].zero_grad() + actor_loss.backward() + self.optim["actor"].step() + + # compute the loss for q, offset by 1 + q_loss, q_pred = self.q_loss(obs, action, next_obs, reward, terminal) + self.optim["critic"].zero_grad() + q_loss.backward() + self.optim["critic"].step() + + for _, scheduler in self.schedulers.items(): + scheduler.step() + + if step % self.target_freq == 0: + sync_target(self.network.critic, self.target_network.critic, tau=self.tau) + + metrics = { + "loss/q_loss": q_loss.item(), + "loss/actor_loss": actor_loss.item(), + "misc/q_pred": q_pred.mean().item(), + "misc/advantage": advantage.mean().item() + } + return metrics + + def load_pretrain(self, path): + for attr in ["world", "reward"]: + state_dict = torch.load(os.path.join(path, attr+".pt"), map_location=self.device) + self.network.__getattr__(attr).load_state_dict(state_dict) + + def save_pretrain(self, path): + os.makedirs(path, exist_ok=True) + for attr in ["world", "reward"]: + state_dict = self.network.__getattr__(attr).state_dict() + torch.save(state_dict, os.path.join(path, attr+".pt")) diff --git a/wiserl/algorithm/hprior/hprior_iql.py b/wiserl/algorithm/hprior/hprior_iql.py new file mode 100644 index 0000000..29d926c --- /dev/null +++ b/wiserl/algorithm/hprior/hprior_iql.py @@ -0,0 +1,240 @@ +import itertools +import os +from operator import itemgetter +from typing import Any, Dict, Optional, Type + +import numpy as np +import torch +import torch.nn as nn + +import wiserl.module +from wiserl.algorithm.oracle_iql import OracleIQL +from wiserl.module.net.attention.twm import TransformerBasedWorldModel +from wiserl.utils.misc import sync_target + + +class Hindsight_PRIOR_IQL(OracleIQL): + def __init__( + self, + *args, + 
expectile: float = 0.7, + beta: float = 0.3333, + max_exp_clip: float = 100.0, + discount: float = 0.99, + tau: float = 0.005, + target_freq: int = 1, + reward_reg: float = 0.0, + max_seq_len: int = 100, + world_steps: int = 50000, + prior_coef: float = 1.0, + rm_label: bool = True, + **kwargs + ) -> None: + self.max_seq_len = max_seq_len + super().__init__( + *args, + expectile=expectile, + beta=beta, + max_exp_clip=max_exp_clip, + discount=discount, + tau=tau, + target_freq=target_freq, + **kwargs + ) + self.reward_reg = reward_reg + self.prior_coef = prior_coef + self.world_steps = world_steps + self.rm_label = rm_label + assert rm_label + self.obs_dim = self.observation_space.shape[0] + self.action_dim = self.action_space.shape[0] + + self.world_criterion = torch.nn.L1Loss(reduction="none") + self.reward_criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + self.prior_criterion = torch.nn.MSELoss(reduction="none") + + def setup_network(self, network_kwargs): + super().setup_network(network_kwargs) + # world model + world_kwargs = network_kwargs["world"] + world = TransformerBasedWorldModel( + obs_dim=self.observation_space.shape[0], + action_dim=self.action_space.shape[0], + embed_dim=world_kwargs["embed_dim"], + num_layers=world_kwargs["num_layers"], + seq_len=self.max_seq_len, + num_heads=world_kwargs["num_heads"], + ) + self.network["world"] = world + # reward model + reward_act = { + "identity": nn.Identity(), + "sigmoid": nn.Sigmoid(), + }.get(network_kwargs["reward"].pop("reward_act")) + reward = vars(wiserl.module)[network_kwargs["reward"].pop("class")]( + input_dim=self.observation_space.shape[0]+self.action_space.shape[0], + output_dim=1, + **network_kwargs["reward"] + ) + self.network["reward"] = nn.Sequential(self.network["encoder"], reward, reward_act) + + def setup_optimizers(self, optim_kwargs): + super().setup_optimizers(optim_kwargs) + default_kwargs = optim_kwargs.get("default", {}) + for attr in ["world", "reward"]: + kwargs = default_kwargs.copy() + kwargs.update(optim_kwargs.get(attr, {})) + optim = vars(torch.optim)[kwargs.pop("class")]( + self.network.__getattr__(attr).parameters(), **kwargs + ) + self.optim[attr] = optim + + def select_action(self, batch, deterministic: bool=True): + return super().select_action(batch, deterministic) + + def select_reward(self, batch, deterministic=False): + obs, action = batch["obs"], batch["action"] + reward = self.network.reward(torch.concat([obs, action], dim=-1)) + return reward.mean(0).detach() + + def pretrain_step(self, batches, step: int, total_steps: int) -> Dict: + traj_batch, pref_batch = batches + if step < self.world_steps: + return self.update_world( + obs=traj_batch["obs"], + action=traj_batch["action"], + next_obs=traj_batch["next_obs"], + ) + else: + return self.update_reward( + obs_1=pref_batch["obs_1"], + obs_2=pref_batch["obs_2"], + action_1=pref_batch["action_1"], + action_2=pref_batch["action_2"], + label=pref_batch["label"], + ) + + def update_world(self, obs, action, next_obs) -> Dict: + # obs [F_B, F_S, obs_dim], action [F_B, F_S, action_dim] + F_B, F_S = obs.shape[:2] + timestep = torch.arange(F_S, device=self.device).unsqueeze(0).expand(F_B, -1) + pred_obs, attentions = self.network.world(obs, action, timestep) + world_loss = self.world_criterion(pred_obs, next_obs).sum(0).mean() + + self.optim["world"].zero_grad() + world_loss.backward() + self.optim["world"].step() + + metrics = { + "loss/world_loss": world_loss.item() + } + return metrics + + def update_reward(self, obs_1, obs_2, action_1, 
action_2, label) -> Dict: + F_B, F_S = obs_1.shape[0:2] + all_obs = torch.concat([ + obs_1.reshape(-1, self.obs_dim), + obs_2.reshape(-1, self.obs_dim) + ]) + all_action = torch.concat([ + action_1.reshape(-1, self.action_dim), + action_2.reshape(-1, self.action_dim) + ]) + self.network.reward.train() + all_reward = self.network.reward(torch.concat([all_obs, all_action], dim=-1)) + r1, r2 = torch.chunk(all_reward, 2, dim=1) + E = r1.shape[0] + r1, r2 = r1.reshape(E, F_B, F_S, 1), r2.reshape(E, F_B, F_S, 1) + logits = r2.sum(dim=2) - r1.sum(dim=2) + labels = label.float().unsqueeze(0).expand_as(logits) + reward_loss = self.reward_criterion(logits, labels).sum(0).mean() + reg_loss = (r1**2).sum(0).mean() + (r2**2).sum(0).mean() + # hindsight prior loss by attention weights + obs = torch.concat([obs_1, obs_2], dim=0) + action = torch.concat([action_1, action_2], dim=0) + r = torch.concat([r1, r2], dim=1) + predicted_return = r.sum(dim=2) + _, attentions = self.network.world(obs, action, None) + # get last attentions -> [2 * F_B, num_layers, 2 * F_S] + attentions = torch.stack([attn[:, -1] for attn in attentions], dim=1) + # alpha = 1/L * sum_{l=1}^{L} (attn_{s_t}^l + attn_{a_t}^l) -> [2 * F_B, F_S] + attentions = attentions.reshape(*attentions.shape[:-1], -1, 2).sum(dim=-1) + prior_importance = attentions.mean(dim=1) + r_target = prior_importance * predicted_return + r_target = r_target.unsqueeze(-1) + prior_loss = self.prior_criterion(r, r_target).mean() + with torch.no_grad(): + reward_accuracy = ((logits > 0) == torch.round(labels)).float().mean() + + self.optim["reward"].zero_grad() + (reward_loss + self.reward_reg * reg_loss + self.prior_coef * prior_loss).backward() + self.optim["reward"].step() + + metrics = { + "loss/reward_loss": reward_loss.item(), + "loss/reward_reg_loss": reg_loss.item(), + "loss/prior_loss": prior_loss.item(), + "misc/reward_acc": reward_accuracy.item(), + "misc/reward_value": all_reward.mean().item() + } + return metrics + + def train_step(self, batches, step: int, total_steps: int) -> Dict: + rl_batch = batches[0] + obs, action, next_obs, terminal = itemgetter("obs", "action", "next_obs", "terminal")(rl_batch) + terminal = terminal.float() + if self.rm_label: + reward = itemgetter("reward")(rl_batch) + else: + with torch.no_grad(): + reward = self.select_reward({"obs": obs, "action": action}, deterministic=True) + + with torch.no_grad(): + self.target_network.eval() + q_old = self.target_network.critic(obs, action) + q_old = torch.min(q_old, dim=0)[0] + + # compute the loss for value network + v_loss, v_pred = self.v_loss(obs.detach(), q_old) + self.optim["value"].zero_grad() + v_loss.backward() + self.optim["value"].step() + + # compute the loss for actor + actor_loss, advantage = self.actor_loss(obs, action, q_old, v_pred.detach()) + self.optim["actor"].zero_grad() + actor_loss.backward() + self.optim["actor"].step() + + # compute the loss for q + q_loss, q_pred = self.q_loss(obs, action, next_obs, reward, terminal) + self.optim["critic"].zero_grad() + q_loss.backward() + self.optim["critic"].step() + + for _, scheduler in self.schedulers.items(): + scheduler.step() + + if step % self.target_freq == 0: + sync_target(self.network.critic, self.target_network.critic, tau=self.tau) + + metrics = { + "loss/q_loss": q_loss.item(), + "loss/v_loss": v_loss.item(), + "loss/actor_loss": actor_loss.item(), + "misc/q_pred": q_pred.mean().item(), + "misc/v_pred": v_pred.mean().item(), + "misc/advantage": advantage.mean().item() + } + return metrics + + def load_pretrain(self, 
path): + for attr in ["world", "reward"]: + state_dict = torch.load(os.path.join(path, attr+".pt"), map_location=self.device) + self.network.__getattr__(attr).load_state_dict(state_dict) + + def save_pretrain(self, path): + os.makedirs(path, exist_ok=True) + for attr in ["world", "reward"]: + state_dict = self.network.__getattr__(attr).state_dict() + torch.save(state_dict, os.path.join(path, attr+".pt")) diff --git a/wiserl/eval/__init__.py b/wiserl/eval/__init__.py index f422d06..7492d07 100644 --- a/wiserl/eval/__init__.py +++ b/wiserl/eval/__init__.py @@ -1,6 +1,6 @@ from wiserl.eval.cliff import eval_cliffwalking_rm from wiserl.eval.offline import eval_offline -from wiserl.eval.reward_model import eval_reward_model +from wiserl.eval.reward_model import eval_reward_model, eval_world_model, eval_world_model_and_reward_model def eval_placeholder(*args, **kwargs): diff --git a/wiserl/eval/reward_model.py b/wiserl/eval/reward_model.py index 10c6e22..a2332fd 100644 --- a/wiserl/eval/reward_model.py +++ b/wiserl/eval/reward_model.py @@ -41,3 +41,42 @@ def eval_reward_model( "val_loss": torch.as_tensor(rm_eval_loss).mean().item(), "val_acc": torch.as_tensor(rm_eval_acc).mean().item(), } + + +@torch.no_grad() +def eval_world_model( + env: gym.Env, + algorithm: Algorithm, + eval_dataset_kwargs: Optional[Sequence[str]], +): + wm_eval_loss = [] + kwargs = eval_dataset_kwargs.copy() + eval_dataset_class = kwargs.pop("class") + eval_dataset = vars(wiserl.dataset)[eval_dataset_class]( + env.observation_space, + env.action_space, + **kwargs + ) + for batch in eval_dataset.create_sequential_iter(): + batch = algorithm.format_batch(batch) + obs = torch.concat([batch["obs_1"], batch["obs_2"]], dim=0) + action = torch.concat([batch["action_1"], batch["action_2"]], dim=0) + next_obs = torch.roll(obs, shifts=-1, dims=1) + timestep = torch.arange(obs.shape[1], device=algorithm.device).unsqueeze(0).expand(obs.shape[0], -1) + pred_obs, _ = algorithm.network.world(obs, action, timestep) + # use the seq_len - 2 timestep + world_loss = algorithm.world_criterion(pred_obs[:, -2], next_obs[:, -2]).mean(-1) + wm_eval_loss.extend(world_loss) + return {"world_eval_loss": torch.as_tensor(wm_eval_loss).mean().item()} + + +@torch.no_grad() +def eval_world_model_and_reward_model( + env: gym.Env, + algorithm: Algorithm, + eval_dataset_kwargs: Optional[Sequence[str]], +): + metrics = {} + metrics.update(eval_world_model(env, algorithm, eval_dataset_kwargs)) + metrics.update(eval_reward_model(env, algorithm, eval_dataset_kwargs)) + return metrics \ No newline at end of file diff --git a/wiserl/module/net/attention/gpt2.py b/wiserl/module/net/attention/gpt2.py index a791f8f..959367a 100644 --- a/wiserl/module/net/attention/gpt2.py +++ b/wiserl/module/net/attention/gpt2.py @@ -39,24 +39,25 @@ def forward( self, input: torch.Tensor, attention_mask: Optional[torch.Tensor]=None, - key_padding_mask: Optional[torch.Tensor]=None + key_padding_mask: Optional[torch.Tensor]=None, + need_weights: bool=False, ): if attention_mask is not None: attention_mask = attention_mask.to(torch.bool) residual = input input = self.ln1(input) - attn_output = self.attention( + attn_output, attn_output_weights = self.attention( query=input, key=input, value=input, - need_weights=False, + need_weights=need_weights, attn_mask=attention_mask, key_padding_mask=key_padding_mask - )[0] + ) residual = residual + self.drop(attn_output) # this is because pytorch MHV don't do dropout after final projection residual = residual + self.ff(self.ln2(residual)) - return 
residual + return residual, attn_output_weights class GPT2(BaseTransformer): @@ -96,7 +97,8 @@ def forward( timesteps: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, - do_embedding: bool=True + do_embedding: bool=True, + output_attentions: bool=False, ): B, L, *_ = inputs.shape if self.causal: @@ -111,7 +113,11 @@ def forward( inputs = self.input_embed(inputs) inputs = self.pos_encoding(inputs, timesteps) inputs = self.embed_dropout(inputs) + attentions = [] for i, block in enumerate(self.blocks): - inputs = block(inputs, attention_mask=mask, key_padding_mask=key_padding_mask) + inputs, weights = block(inputs, attention_mask=mask, key_padding_mask=key_padding_mask, need_weights=output_attentions) + attentions.append(weights) inputs = self.out_ln(inputs) + if output_attentions: + return inputs, tuple(attentions) return inputs diff --git a/wiserl/module/net/attention/twm.py b/wiserl/module/net/attention/twm.py new file mode 100644 index 0000000..f5af4e3 --- /dev/null +++ b/wiserl/module/net/attention/twm.py @@ -0,0 +1,69 @@ +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union + +import torch +import torch.nn as nn + +from wiserl.module.net.attention.base import BaseTransformer +from wiserl.module.net.attention.gpt2 import GPT2 +from wiserl.module.net.attention.positional_encoding import get_pos_encoding + + +class TransformerBasedWorldModel(BaseTransformer): + def __init__( + self, + obs_dim: int, + action_dim: int, + embed_dim: int, + num_layers: int, + seq_len: int, + num_heads: int=1, + attention_dropout: Optional[float]=0.1, + residual_dropout: Optional[float]=0.1, + embed_dropout: Optional[float]=0.1, + pos_encoding: str="embed", + ) -> None: + super().__init__() + self.backbone = GPT2( + input_dim=embed_dim, + embed_dim=embed_dim, + num_layers=num_layers, + num_heads=num_heads, + causal=True, + attention_dropout=attention_dropout, + residual_dropout=residual_dropout, + embed_dropout=embed_dropout, + pos_encoding="none", + seq_len=0 + ) + self.pos_encoding = get_pos_encoding(pos_encoding, embed_dim, seq_len) + self.obs_embed = nn.Linear(obs_dim, embed_dim) + self.act_embed = nn.Linear(action_dim, embed_dim) + self.embed_ln = nn.LayerNorm(embed_dim) + self.obs_head = nn.Linear(embed_dim, obs_dim) + + def forward( + self, + observations: torch.Tensor, + actions: torch.Tensor, + timesteps: torch.Tensor, + attention_mask: Optional[torch.Tensor]=None, + key_padding_mask: Optional[torch.Tensor]=None, + ): + B, L, *_ = observations.shape + state_embedding = self.pos_encoding(self.obs_embed(observations), timesteps) + action_embedding = self.pos_encoding(self.act_embed(actions), timesteps) + stacked_input = torch.stack([state_embedding, action_embedding], dim=2).reshape(B, 2*L, state_embedding.shape[-1]) + stacked_input = self.embed_ln(stacked_input) + if key_padding_mask is not None: + key_padding_mask = torch.stack([key_padding_mask, key_padding_mask], dim=2).reshape(B, 2*L) + out, attentions = self.backbone( + inputs=stacked_input, + timesteps=None, + attention_mask=attention_mask, + key_padding_mask=key_padding_mask, + do_embedding=False, + output_attentions=True, + ) + pred_obs = self.obs_head(out[:, 1::2]) # o[:t] + a[:t] -> s[t+1] + + return pred_obs, attentions
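For readers of `TransformerBasedWorldModel.forward`, the sketch below (pure PyTorch, not repository code) shows why `pred_obs` is read from `out[:, 1::2]`: states and actions are interleaved as (s_1, a_1, ..., s_T, a_T), so the odd positions are the action tokens, and under the causal mask the hidden state at a_t has attended to (s_1..s_t, a_1..a_t), making it the natural slot for predicting s_{t+1}.

```python
import torch

# Token layout used by TransformerBasedWorldModel: (s_1, a_1, s_2, a_2, ...).
B, T, D = 2, 5, 8
state_emb = torch.randn(B, T, D)
action_emb = torch.randn(B, T, D)
stacked = torch.stack([state_emb, action_emb], dim=2).reshape(B, 2 * T, D)

assert torch.equal(stacked[:, 0::2], state_emb)   # even positions hold the state tokens
assert torch.equal(stacked[:, 1::2], action_emb)  # odd positions hold the action tokens
# Hence pred_obs = self.obs_head(out[:, 1::2]) reads the prediction of s_{t+1}
# at each action token a_t.
```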
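A minimal smoke test of the world model's interface, assuming the hopper-medium-replay-v2 dimensions (11-dim observations, 3-dim actions) and the hyperparameters from `scripts/configs/hprior_iql/gym.yaml`; inputs are random and only shapes are inspected.

```python
import torch
from wiserl.module.net.attention.twm import TransformerBasedWorldModel

# Assumed dims for hopper-medium-replay-v2; adjust obs_dim/action_dim for other envs.
model = TransformerBasedWorldModel(
    obs_dim=11, action_dim=3, embed_dim=256, num_layers=3, seq_len=100, num_heads=1,
)
model.eval()

obs = torch.randn(4, 100, 11)
action = torch.randn(4, 100, 3)
timestep = torch.arange(100).unsqueeze(0).expand(4, -1)  # same construction as update_world

pred_obs, attentions = model(obs, action, timestep)
print(pred_obs.shape)   # torch.Size([4, 100, 11]): prediction of s_{t+1} at each step
print(len(attentions))  # one attention map per transformer layer, over the 2*100 interleaved tokens
```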
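The hindsight-prior target in `update_reward` redistributes each segment's predicted return over timesteps in proportion to the world model's attention from the final token onto each (state, action) pair, averaged over layers; the per-step reward predictions are then regressed toward that target with `prior_criterion`. The standalone sketch below mirrors that computation with a hypothetical name (`hindsight_prior_targets` is illustrative, not part of the repository API).

```python
import torch

def hindsight_prior_targets(attentions, rewards):
    """Illustrative only (hypothetical helper, not the repository API).

    attentions: per-layer attention maps over the interleaved
                (s_1, a_1, ..., s_T, a_T) tokens, each of shape [B, 2*T, 2*T].
    rewards:    per-step reward predictions of shape [E, B, T, 1] (E = ensemble size).
    """
    # Attention paid by the last token to every token, stacked over layers: [B, L, 2*T].
    last_token_attn = torch.stack([attn[:, -1] for attn in attentions], dim=1)
    # Merge each (state, action) token pair: alpha_t = attn_{s_t} + attn_{a_t} -> [B, L, T].
    per_step = last_token_attn.reshape(*last_token_attn.shape[:-1], -1, 2).sum(dim=-1)
    # Average over layers to get one importance weight per timestep: [B, T].
    importance = per_step.mean(dim=1)
    # Redistribute the predicted segment return (not detached, matching the diff): [E, B, T, 1].
    predicted_return = rewards.sum(dim=2)  # [E, B, 1]
    return (importance * predicted_return).unsqueeze(-1)

# Shape check with toy inputs: 3 layers, batch of 4 segments, T=10 steps, ensemble of 1.
attn = [torch.softmax(torch.randn(4, 20, 20), dim=-1) for _ in range(3)]
rewards = torch.randn(1, 4, 10, 1)
print(hindsight_prior_targets(attn, rewards).shape)  # torch.Size([1, 4, 10, 1])
```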