
Add torchcodec cpu #798

Open
wants to merge 5 commits into main

Conversation

jadechoghari

@jadechoghari jadechoghari commented Mar 3, 2025

What this does

This PR replaces torchvision CPU decoding with torchcodec CPU decoding.
It also adds a `decode_video_frames` function that wraps multiple backends, so callers no longer need to invoke `decode_video_frames_BACKENDNAME` separately. This simplifies the call sites and allows us to add more decoders later on!
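A minimal sketch of such a multi-backend wrapper is shown below. The per-backend function names follow the PR description, but their bodies here are placeholders, not the actual implementations:

```python
# Sketch of the multi-backend wrapper described above. The backend function
# names are assumptions based on the PR description; real signatures may differ.

def decode_video_frames_torchcodec(video_path, timestamps, tolerance_s):
    # Placeholder: the real function decodes with torchcodec's VideoDecoder.
    return "torchcodec"

def decode_video_frames_pyav(video_path, timestamps, tolerance_s):
    # Placeholder: the real function decodes with torchvision/pyav.
    return "pyav"

def decode_video_frames(video_path, timestamps, tolerance_s, backend="torchcodec"):
    """Dispatch decoding to the backend selected by `backend`."""
    if backend == "torchcodec":
        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
    elif backend in ["pyav", "video_reader"]:  # keep backward compat with video_reader
        return decode_video_frames_pyav(video_path, timestamps, tolerance_s)
    else:
        raise ValueError(f"Unsupported video backend: {backend}")
```

With this shape, adding a new decoder is a matter of adding one branch (or a registry entry) rather than a new public function.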

How it was tested

Tested and benchmarked the decoders on different datasets/policies.

How to checkout & try? (for the reviewer)

Just run the training script, with a dataset containing videos to decode.
Example:

```shell
python lerobot/scripts/train.py \
    --output_dir=outputs/train/act_aloha_insertion \
    --policy.type=act \
    --dataset.repo_id=lerobot/aloha_sim_insertion_human \
    --env.type=aloha \
    --env.task=AlohaInsertion-v0
```

Benchmarks

Ran one benchmark on the lerobot/aloha_sim_insertion_human_image dataset.

Comparison: PyAV vs TorchCodec (CPU)

| Metric | PyAV | TorchCodec-CPU |
|---|---|---|
| Video to Images Load Time Ratio | 1.87 | 1.25 |
| Avg MSE | 5.14e-05 | 4.88e-05 |
| Avg PSNR | 43.17 | 43.37 |
| Avg SSIM | 0.995 | 0.995 |
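For reference, the MSE and PSNR rows can be computed per frame pair with helpers like these (a NumPy sketch, not the exact benchmark script behind the numbers above; SSIM would typically come from a library such as scikit-image):

```python
import numpy as np

def frame_mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared error between two frames in [0, 1] float range."""
    return float(np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2))

def frame_psnr(a: np.ndarray, b: np.ndarray, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB; higher means the frames are closer."""
    m = frame_mse(a, b)
    if m == 0:
        return float("inf")
    return float(10.0 * np.log10(max_val**2 / m))
```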

What's left

Remove/suppress libdav1d logs (they're noisy). There's currently no environment variable to disable them, but they will be deactivated in the next version of torchcodec.

@jadechoghari jadechoghari marked this pull request as draft March 3, 2025 06:49
@jadechoghari jadechoghari marked this pull request as ready for review March 3, 2025 07:32
@Cadene Cadene self-requested a review March 4, 2025 08:31
@Cadene Cadene left a comment

Really nice work Jade! Thanks :)

Let's wait for the next version of torchcodec then!

In the meantime, could you try reproducing results on pusht and aloha transfer cube, and add the commands you used and the success rates to the README?

Thanks!

"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend == "pyav":
To not break backward compatibility with video_reader

Suggested change:

```diff
-    elif backend == "pyav":
+    elif backend in ["pyav", "video_reader"]:
```

Comment on lines +157 to +215
```python
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec."""
    video_path = str(video_path)
    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device)
    loaded_frames = []
    loaded_ts = []
    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
```
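The timestamp matching at the core of this function (the `torch.cdist` + `argmin` step followed by the tolerance check) can be illustrated in pure Python, without torchcodec (a sketch of the same logic, with a hypothetical helper name):

```python
def match_timestamps(query_ts: list[float], loaded_ts: list[float],
                     tolerance_s: float) -> list[int]:
    """For each query timestamp, return the index of the closest loaded
    timestamp, failing loudly if it is farther away than tolerance_s."""
    indices = []
    for q in query_ts:
        # nearest-neighbor search over the decoded frame timestamps
        idx = min(range(len(loaded_ts)), key=lambda i: abs(loaded_ts[i] - q))
        dist = abs(loaded_ts[idx] - q)
        if dist >= tolerance_s:
            raise ValueError(
                f"Query timestamp {q} is {dist:.4f}s away from the closest "
                f"decoded frame, which exceeds tolerance_s={tolerance_s}."
            )
        indices.append(idx)
    return indices
```

The tolerance check is what surfaces synchronization problems in a dataset: a query timestamp with no decoded frame within `tolerance_s` indicates the recorded timestamps and the video stream have drifted apart.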
Sounds good to not refactor into a single function decode_video_frames with pyav and torchcodec as backend since we plan to deprecate pyav.

Comment on lines +42 to +44
```python
    Args:
        video_path (Path): Path to the video file.
        query_ts (list[float]): List of timestamps to extract frames.
```
args are wrong

device: str = "cpu",
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec."""
Suggested change:

```diff
-    """Loads frames associated with the requested timestamps of a video using torchcodec."""
+    """Loads frames associated with the requested timestamps of a video using torchcodec.
+
+    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
+
+    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
+    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
+    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
+    and all subsequent frames until reaching the requested frame. The number of key frames in a video
+    can be adjusted during encoding to take into account decoding time and video size in bytes.
+    """
```

```diff
-    video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
-        a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+    video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
+        You can also use the 'pyav' decoder used by Torchvision.
```
Suggested change:

```diff
-        You can also use the 'pyav' decoder used by Torchvision.
+        You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
```
