Add HIT policy #548

Open · wants to merge 1 commit into main
3 changes: 2 additions & 1 deletion lerobot/__init__.py
@@ -188,6 +188,7 @@
 # lists all available policies from `lerobot/common/policies`
 available_policies = [
     "act",
+    "hit",
     "diffusion",
     "tdmpc",
     "vqbet",
@@ -216,7 +217,7 @@

 # keys and values refer to yaml files
 available_policies_per_env = {
-    "aloha": ["act"],
+    "aloha": ["act", "hit"],
     "pusht": ["diffusion", "vqbet"],
     "xarm": ["tdmpc"],
     "koch_real": ["act_koch_real"],
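For reference, a minimal sketch of what this registration change means at the package level, assuming the two lists above keep being exposed as module attributes of `lerobot`, as they are in the current codebase:

import lerobot

# "hit" is now listed in the global policy registry...
assert "hit" in lerobot.available_policies
# ...and advertised as compatible with the Aloha environments.
assert "hit" in lerobot.available_policies_per_env["aloha"]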
5 changes: 5 additions & 0 deletions lerobot/common/policies/factory.py
@@ -61,6 +61,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.act.modeling_act import ACTPolicy
 
         return ACTPolicy, ACTConfig
+    elif name == "hit":
+        from lerobot.common.policies.hit.configuration_hit import HITConfig
+        from lerobot.common.policies.hit.modeling_hit import HITPolicy
+
+        return HITPolicy, HITConfig
     elif name == "vqbet":
         from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
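A hedged usage sketch of the factory branch added above; the `HITPolicy(config)` call mirrors how the other policies in this factory are constructed and is an assumption about the new class's signature:

from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("hit")  # returns (HITPolicy, HITConfig)
config = config_cls()        # HITConfig with the Aloha-oriented defaults defined below
policy = policy_cls(config)  # assumed constructor, mirroring e.g. ACTPolicy(config)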
155 changes: 155 additions & 0 deletions lerobot/common/policies/hit/configuration_hit.py
@@ -0,0 +1,155 @@
#!/usr/bin/env python

# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field


@dataclass
class HITConfig:
"""Configuration class for the Humanoid Imitation Transformer policy.

Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.

Notes on the inputs and outputs:
- Either:
- At least one key starting with "observation.image is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
- "action" is required as an output key.

Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
`None` means no pretrained weights.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution.
pre_norm: Whether to use "pre-norm" in the transformer blocks.
dim_model: The transformer blocks' main hidden dimension.
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
layers.
feedforward_activation: The activation to use in the transformer block's feed-forward layers.
n_layers: The number of transformer layers to use for the transformer encoder.
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
more information on how ensembling works, please see `ACTTemporalEnsembler`.
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50

    input_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "observation.images.top": [3, 480, 640],
            "observation.state": [14],
        }
    )
    output_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "action": [14],
        }
    )

    # Normalization / Unnormalization
    input_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.images.top": "mean_std",
            "observation.state": "mean_std",
        }
    )
    output_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {
            "action": "mean_std",
        }
    )

    # Architecture.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    replace_final_stride_with_dilation: int = False
    # Transformer layers.
    pre_norm: bool = False
    dim_model: int = 512
    n_heads: int = 8
    dim_feedforward: int = 3200
    feedforward_activation: str = "gelu"
    n_layers: int = 4
    # Auxiliary future image feature prediction loss.
    predict_horizon: int = 50
    feature_loss_weight: float = 0.005

    # Inference.
    # Note: the original HIT didn't use temporal ensembling.
    temporal_ensemble_coeff: float | None = None

    # Training and loss computation.
    dropout: float = 0.1
    kl_weight: float = 10.0

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )
        if (
            not any(k.startswith("observation.image") for k in self.input_shapes)
            and "observation.environment_state" not in self.input_shapes
        ):
            raise ValueError("You must provide at least one image or the environment state among the inputs.")