-
Notifications
You must be signed in to change notification settings - Fork 865
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add transformer model to diffusion policy #481
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,7 +98,7 @@ class DiffusionConfig: | |
|
||
# Inputs / output structure. | ||
n_obs_steps: int = 2 | ||
horizon: int = 16 | ||
horizon: int = 10 | ||
n_action_steps: int = 8 | ||
|
||
input_shapes: dict[str, list[int]] = field( | ||
|
@@ -134,7 +134,7 @@ class DiffusionConfig: | |
down_dims: tuple[int, ...] = (512, 1024, 2048) | ||
kernel_size: int = 5 | ||
n_groups: int = 8 | ||
diffusion_step_embed_dim: int = 128 | ||
diffusion_step_embed_dim: int = 256 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this change should be reverted right? To keep in line with the default PushT policy. |
||
use_film_scale_modulation: bool = True | ||
# Noise scheduler. | ||
noise_scheduler_type: str = "DDPM" | ||
|
@@ -145,6 +145,14 @@ class DiffusionConfig: | |
prediction_type: str = "epsilon" | ||
clip_sample: bool = True | ||
clip_sample_range: float = 1.0 | ||
# Transformer | ||
use_transformer: bool = True | ||
n_layer: int = 8 | ||
n_head: int = 4 | ||
p_drop_emb: float = 0.0 | ||
p_drop_attn: float = 0.3 | ||
causal_attn: bool = True | ||
n_cond_layers: int = 0 | ||
|
||
# Inference | ||
num_inference_steps: int | None = None | ||
|
@@ -200,7 +208,7 @@ def __post_init__(self): | |
# Check that the horizon size and U-Net downsampling is compatible. | ||
# U-Net downsamples by 2 with each stage. | ||
downsampling_factor = 2 ** len(self.down_dims) | ||
if self.horizon % downsampling_factor != 0: | ||
if not self.use_transformer and self.horizon % downsampling_factor != 0: | ||
raise ValueError( | ||
"The horizon should be an integer multiple of the downsampling factor (which is determined " | ||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
|
||
import math | ||
from collections import deque | ||
from typing import Callable | ||
from typing import Callable, Tuple | ||
|
||
import einops | ||
import numpy as np | ||
|
@@ -188,7 +188,12 @@ def __init__(self, config: DiffusionConfig): | |
self._use_env_state = True | ||
global_cond_dim += config.input_shapes["observation.environment_state"][0] | ||
|
||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) | ||
if config.use_transformer: | ||
self.net = TransformerForDiffusion(config, cond_dim=global_cond_dim) | ||
else: | ||
self.net = DiffusionConditionalUnet1d( | ||
config, global_cond_dim=global_cond_dim * config.n_obs_steps | ||
) | ||
|
||
self.noise_scheduler = _make_noise_scheduler( | ||
config.noise_scheduler_type, | ||
|
@@ -206,6 +211,20 @@ def __init__(self, config: DiffusionConfig): | |
else: | ||
self.num_inference_steps = config.num_inference_steps | ||
|
||
def get_optimizer( | ||
self, | ||
transformer_weight_decay: float = 1e-3, | ||
rgb_encoder_weight_decay: float = 1e-6, | ||
learning_rate: float = 1e-4, | ||
betas: Tuple[float, float] = [0.9, 0.95], | ||
) -> torch.optim.Optimizer: | ||
optim_groups = self.net.get_optim_groups(weight_decay=transformer_weight_decay) | ||
optim_groups.append( | ||
{"params": self.rgb_encoder.parameters(), "weight_decay": rgb_encoder_weight_decay} | ||
) | ||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) | ||
return optimizer | ||
|
||
# ========= inference ============ | ||
def conditional_sample( | ||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None | ||
|
@@ -225,7 +244,7 @@ def conditional_sample( | |
|
||
for t in self.noise_scheduler.timesteps: | ||
# Predict model output. | ||
model_output = self.unet( | ||
model_output = self.net( | ||
sample, | ||
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), | ||
global_cond=global_cond, | ||
|
@@ -324,7 +343,7 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: | |
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps) | ||
|
||
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise). | ||
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond) | ||
pred = self.net(noisy_trajectory, timesteps, global_cond=global_cond) | ||
|
||
# Compute the loss. | ||
# The target is either the original trajectory, or the noise. | ||
|
@@ -749,3 +768,264 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor: | |
out = self.conv2(out) | ||
out = out + self.residual_conv(x) | ||
return out | ||
|
||
|
||
class TransformerForDiffusion(nn.Module): | ||
def __init__(self, config: DiffusionConfig, cond_dim: int): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May I please ask for either more descriptive variable names or comments describing what the variables mean here? I've highlighted at least 1 or 2 specific asks below, but I realized it might be better to make this general request. Please check ACT for inspiration |
||
super().__init__() | ||
self.config = config | ||
|
||
# compute number of tokens for main trunk and condition encoder | ||
if config.n_obs_steps is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. config.n_obs_steps is None does not seem to be allowed according to the type hinting and documentation. So perhaps it doesn't make sense to handle it here, right? |
||
config.n_obs_steps = config.horizon | ||
|
||
t = config.horizon | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, is it okay if we just leave this as config.horizon rather than binding it to another much less descriptive variable name? |
||
t_cond = 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So what is |
||
t_cond += config.n_obs_steps | ||
|
||
input_dim = config.output_shapes["action"][0] | ||
# input embedding stem | ||
self.input_emb = nn.Linear(input_dim, config.diffusion_step_embed_dim) | ||
self.pos_emb = nn.Parameter(torch.zeros(1, t, config.diffusion_step_embed_dim)) | ||
self.drop = nn.Dropout(config.p_drop_emb) | ||
|
||
# cond encoder | ||
self.time_emb = DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
self.cond_obs_emb = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line appears to be redundant. |
||
|
||
self.cond_obs_emb = nn.Linear(cond_dim, config.diffusion_step_embed_dim) | ||
|
||
self.cond_pos_emb = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line appears to be redundant. |
||
self.encoder = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two lines also. |
||
self.decoder = None | ||
|
||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, t_cond, config.diffusion_step_embed_dim)) | ||
if config.n_cond_layers > 0: | ||
encoder_layer = nn.TransformerEncoderLayer( | ||
d_model=config.diffusion_step_embed_dim, | ||
nhead=config.n_head, | ||
dim_feedforward=4 * config.diffusion_step_embed_dim, | ||
dropout=config.p_drop_attn, | ||
activation="gelu", | ||
batch_first=True, | ||
norm_first=True, | ||
) | ||
self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=config.n_cond_layers) | ||
else: | ||
self.encoder = nn.Sequential( | ||
nn.Linear(config.diffusion_step_embed_dim, 4 * config.diffusion_step_embed_dim), | ||
nn.Mish(), | ||
nn.Linear(4 * config.diffusion_step_embed_dim, config.diffusion_step_embed_dim), | ||
) | ||
# decoder | ||
decoder_layer = nn.TransformerDecoderLayer( | ||
d_model=config.diffusion_step_embed_dim, | ||
nhead=config.n_head, | ||
dim_feedforward=4 * config.diffusion_step_embed_dim, | ||
dropout=config.p_drop_attn, | ||
activation="gelu", | ||
batch_first=True, | ||
norm_first=True, # important for stability | ||
) | ||
self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=config.n_layer) | ||
|
||
# attention mask | ||
if config.causal_attn: | ||
# causal mask to ensure that attention is only applied to the left in the input sequence | ||
# torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT | ||
# therefore, the upper triangle should be -inf and others (including diag) should be 0. | ||
sz = t | ||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) | ||
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) | ||
self.register_buffer("mask", mask) | ||
|
||
# assume conditioning over time and observation both | ||
p, q = torch.meshgrid(torch.arange(t), torch.arange(t_cond), indexing="ij") | ||
mask = p >= (q - 1) # add one dimension since time is the first token in cond | ||
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) | ||
self.register_buffer("memory_mask", mask) | ||
else: | ||
self.mask = None | ||
self.memory_mask = None | ||
|
||
# decoder head | ||
self.ln_f = nn.LayerNorm(config.diffusion_step_embed_dim) | ||
self.head = nn.Linear(config.diffusion_step_embed_dim, input_dim) | ||
|
||
# constants | ||
self.t = t | ||
self.t_cond = t_cond | ||
self.horizon = config.horizon | ||
self.n_obs_steps = config.n_obs_steps | ||
|
||
# init | ||
self.apply(self._init_weights) | ||
# logger.info( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please remove this commented code? |
||
# "number of parameters: %e", sum(p.numel() for p in self.parameters()) | ||
# ) | ||
|
||
def _init_weights(self, module): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there anything we can say in a docstring to summarize the weight initialization strategy used here? |
||
ignore_types = ( | ||
nn.Dropout, | ||
DiffusionSinusoidalPosEmb, | ||
nn.TransformerEncoderLayer, | ||
nn.TransformerDecoderLayer, | ||
nn.TransformerEncoder, | ||
nn.TransformerDecoder, | ||
nn.ModuleList, | ||
nn.Mish, | ||
nn.Sequential, | ||
) | ||
if isinstance(module, (nn.Linear, nn.Embedding)): | ||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) | ||
if isinstance(module, nn.Linear) and module.bias is not None: | ||
torch.nn.init.zeros_(module.bias) | ||
elif isinstance(module, nn.MultiheadAttention): | ||
weight_names = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"] | ||
for name in weight_names: | ||
weight = getattr(module, name) | ||
if weight is not None: | ||
torch.nn.init.normal_(weight, mean=0.0, std=0.02) | ||
|
||
bias_names = ["in_proj_bias", "bias_k", "bias_v"] | ||
for name in bias_names: | ||
bias = getattr(module, name) | ||
if bias is not None: | ||
torch.nn.init.zeros_(bias) | ||
elif isinstance(module, nn.LayerNorm): | ||
torch.nn.init.zeros_(module.bias) | ||
torch.nn.init.ones_(module.weight) | ||
elif isinstance(module, TransformerForDiffusion): | ||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) | ||
if module.cond_obs_emb is not None: | ||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02) | ||
elif isinstance(module, ignore_types): | ||
# no param | ||
pass | ||
else: | ||
raise RuntimeError("Unaccounted module {}".format(module)) | ||
|
||
def get_optim_groups(self, weight_decay: float = 1e-3): | ||
""" | ||
This long function is unfortunately doing something very simple and is being very defensive: | ||
We are separating out all parameters of the model into two buckets: those that will experience | ||
weight decay for regularization and those that won't (biases, and layernorm/embedding weights). | ||
We are then returning the PyTorch optimizer object. | ||
""" | ||
|
||
# separate out all parameters to those that will and won't experience regularizing weight decay | ||
decay = set() | ||
no_decay = set() | ||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention) | ||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) | ||
for mn, m in self.named_modules(): | ||
for pn, _ in m.named_parameters(): | ||
fpn = "{}.{}".format(mn, pn) if mn else pn # full param name | ||
|
||
if pn.endswith("bias"): | ||
# all biases will not be decayed | ||
no_decay.add(fpn) | ||
elif pn.startswith("bias"): | ||
# MultiheadAttention bias starts with "bias" | ||
no_decay.add(fpn) | ||
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): | ||
# weights of whitelist modules will be weight decayed | ||
decay.add(fpn) | ||
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): | ||
# weights of blacklist modules will NOT be weight decayed | ||
no_decay.add(fpn) | ||
|
||
# special case the position embedding parameter in the root GPT module as not decayed | ||
no_decay.add("pos_emb") | ||
# no_decay.add("_dummy_variable") | ||
if self.cond_pos_emb is not None: | ||
no_decay.add("cond_pos_emb") | ||
|
||
# validate that we considered every parameter | ||
# param_dict = {pn: p for pn, p in self.named_parameters()} | ||
param_dict = dict(self.named_parameters()) | ||
inter_params = decay & no_decay | ||
union_params = decay | no_decay | ||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( | ||
str(inter_params) | ||
) | ||
assert ( | ||
len(param_dict.keys() - union_params) == 0 | ||
), "parameters {} were not separated into either decay/no_decay set!".format( | ||
str(param_dict.keys() - union_params), | ||
) | ||
|
||
# create the pytorch optimizer object | ||
optim_groups = [ | ||
{ | ||
"params": [param_dict[pn] for pn in sorted(decay)], | ||
"weight_decay": weight_decay, | ||
}, | ||
{ | ||
"params": [param_dict[pn] for pn in sorted(no_decay)], | ||
"weight_decay": 0.0, | ||
}, | ||
] | ||
return optim_groups | ||
|
||
def configure_optimizers( | ||
self, | ||
learning_rate: float = 1e-4, | ||
weight_decay: float = 1e-3, | ||
betas: Tuple[float, float] = (0.9, 0.95), | ||
): | ||
optim_groups = self.get_optim_groups(weight_decay=weight_decay) | ||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) | ||
return optimizer | ||
|
||
def forward(self, sample: torch.Tensor, timestep: torch.Tensor, global_cond: torch.Tensor, **kwargs): | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we please tidy up this docstring?
|
||
x: (B,T,input_dim) | ||
timestep: (B,) | ||
global_cond: (B, global_cond_dim) | ||
output: (B,T,input_dim) | ||
""" | ||
# 1. time | ||
timesteps = timestep | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason you assign another variable to the same object, and with such a similar name? |
||
batch_size = sample.shape[0] | ||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need this at the moment. |
||
timesteps = timesteps.expand(batch_size) | ||
time_emb = self.time_emb(timesteps).unsqueeze(1) | ||
# (B,1,n_emb) | ||
|
||
cond = einops.rearrange(global_cond, "b (s n) ... -> b s (n ...)", b=batch_size, s=self.n_obs_steps) | ||
# (B,To,n_cond) | ||
|
||
# process input | ||
input_emb = self.input_emb(sample) | ||
|
||
# encoder | ||
cond_embeddings = time_emb | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: mind dropping this line and just putting |
||
# (B,1,n_emb) | ||
|
||
cond_obs_emb = self.cond_obs_emb(cond) | ||
# (B,To,n_emb) | ||
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1) | ||
# (B,To + 1,n_emb) | ||
|
||
tc = cond_embeddings.shape[1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
position_embeddings = self.cond_pos_emb[:, :tc, :] # each position maps to a (learnable) vector | ||
x = self.drop(cond_embeddings + position_embeddings) | ||
x = self.encoder(x) | ||
memory = x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we really have to add another variable into the namespace here? |
||
# (B,T_cond,n_emb) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make comments like this either in-line, or on the line preceding the line of code of concern? I think putting code then comment on the next line is rather unconventional. |
||
|
||
# decoder | ||
token_embeddings = input_emb | ||
t = token_embeddings.shape[1] | ||
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector | ||
x = self.drop(token_embeddings + position_embeddings) | ||
# (B,T,n_emb) | ||
x = self.decoder(tgt=x, memory=memory, tgt_mask=self.mask, memory_mask=self.memory_mask) | ||
# (B,T,n_emb) | ||
|
||
# head | ||
x = self.ln_f(x) | ||
x = self.head(x) | ||
# (B,T,n_inp) | ||
return x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this change should be reverted right? To keep in line with the default PushT policy.