Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torchcodec cpu #798

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
14 changes: 6 additions & 8 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_video_frames_torchvision,
decode_video_frames,
encode_video_frames,
get_video_info,
)
Expand Down Expand Up @@ -462,8 +462,8 @@ def __init__(
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
You can also use the 'pyav' decoder used by Torchvision.
"""
super().__init__()
self.repo_id = repo_id
Expand All @@ -473,7 +473,7 @@ def __init__(
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "pyav"
self.video_backend = video_backend if video_backend else "torchcodec"
self.delta_indices = None

# Unused attributes
Expand Down Expand Up @@ -707,9 +707,7 @@ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchvision(
video_path, query_ts, self.tolerance_s, self.video_backend
)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
item[vid_key] = frames.squeeze(0)

return item
Expand Down Expand Up @@ -1029,7 +1027,7 @@ def create(
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav"
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
return obj


Expand Down
88 changes: 88 additions & 0 deletions lerobot/common/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,33 @@
import torchvision
from datasets.features.features import register_feature
from PIL import Image
from torchcodec.decoders import VideoDecoder


def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "torchcodec",
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
query_ts (list[float]): List of timestamps to extract frames.
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend == "pyav":
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")


def decode_video_frames_torchvision(
Expand Down Expand Up @@ -127,6 +154,67 @@ def decode_video_frames_torchvision(
return closest_frames


def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Each requested timestamp is mapped to the nearest frame index using the video's
    average fps, decoded in one batch, then matched back to the query timestamps.
    Frames farther than ``tolerance_s`` from their query trigger an assertion error.
    """
    video_path = str(video_path)
    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device)

    # average fps is needed to turn timestamps into frame indices
    fps = decoder.metadata.average_fps
    # NOTE(review): round(ts * fps) could exceed the last frame index for a timestamp
    # at the very end of the video — TODO confirm torchcodec's behavior in that case.
    frame_indices = [round(ts * fps) for ts in timestamps]

    # decode all requested frames in a single batch
    batch = decoder.get_frames_at(indices=frame_indices)

    frames: list[torch.Tensor] = []
    frame_ts: list[float] = []
    for frame, pts in zip(batch.data, batch.pts_seconds, strict=False):
        frames.append(frame)
        frame_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(frame_ts)

    # pairwise |query - loaded| distances; pick the closest loaded frame per query
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # gather the closest frames (and their timestamps) for each query
    closest_frames = torch.stack([frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to not refactor into a single function decode_video_frames with pyav and torchcodec as backend since we plan to deprecate pyav.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! and this keeps it compatible, note that lerobot models contain a "video_backend": "pyav" key inside their train_config.json



def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ dependencies = [
"termcolor>=2.4.0",
"torch>=2.2.1",
"torchvision>=0.21.0",
"torchcodec>=0.1.0",
"wandb>=0.16.3",
"zarr>=2.17.0",
]
Expand Down