Add torchcodec cpu #798
base: main
Conversation
Really nice work Jade! Thanks :)
Let's wait for the next version of torchcodec then!
In the meantime, could you try reproducing results on pusht and aloha transfer cube, and add the commands you used and the success rates to the README?
Thanks!
""" | ||
if backend == "torchcodec": | ||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) | ||
elif backend == "pyav": |
To not break backward compatibility with video_reader
elif backend == "pyav": | |
elif backend in ["pyav", "video_reader"]: |
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec."""
    video_path = str(video_path)
    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device)
    loaded_frames = []
    loaded_ts = []
    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
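For reference, a hypothetical call to the function above could look like the sketch below; the file name, timestamps, and tolerance are placeholders, not values from this PR.

# Placeholder inputs, for illustration only.
frames = decode_video_frames_torchcodec(
    video_path="episode_000000.mp4",
    timestamps=[0.0, 1 / 30, 2 / 30],
    tolerance_s=1e-4,
)
print(frames.shape, frames.dtype)  # -> torch.Size([3, 3, H, W]) for 3 RGB frames, torch.float32 in [0, 1]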
Sounds good to not refactor into a single function decode_video_frames with pyav and torchcodec as backends, since we plan to deprecate pyav.
    Args:
        video_path (Path): Path to the video file.
        query_ts (list[float]): List of timestamps to extract frames.
The args in this docstring are wrong: they don't match the function signature.
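For example, an Args block matching the actual signature of decode_video_frames_torchcodec could read as follows (a sketch of a possible fix, not the author's wording):

    Args:
        video_path (Path | str): Path to the video file.
        timestamps (list[float]): Timestamps (in seconds) of the frames to extract.
        tolerance_s (float): Maximum allowed deviation in seconds between a requested timestamp
            and the timestamp of the frame actually returned.
        device (str): Device passed to the torchcodec VideoDecoder. Defaults to "cpu".
        log_loaded_timestamps (bool): Whether to log the timestamps of the loaded frames.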
device: str = "cpu", | ||
log_loaded_timestamps: bool = False, | ||
) -> torch.Tensor: | ||
"""Loads frames associated with the requested timestamps of a video using torchcodec.""" |
"""Loads frames associated with the requested timestamps of a video using torchcodec.""" | |
"""Loads frames associated with the requested timestamps of a video using torchcodec. | |
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors. | |
Note: Video benefits from inter-frame compression. Instead of storing every frame individually, | |
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to | |
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame, | |
and all subsequent frames until reaching the requested frame. The number of key frames in a video | |
can be adjusted during encoding to take into account decoding time and video size in bytes. | |
""" |
-        video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
-            a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+        video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
+            You can also use the 'pyav' decoder used by Torchvision.
Suggested change:
-            You can also use the 'pyav' decoder used by Torchvision.
+            You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
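Selecting the decoder through this parameter would then look roughly like the sketch below; the import path and repo_id are assumptions, and only the video_backend values come from the docstring above.

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed

# Example repo_id; pick any dataset containing videos.
dataset = LeRobotDataset("lerobot/aloha_sim_insertion_human_image", video_backend="torchcodec")
# or, to keep the previous behavior:
dataset = LeRobotDataset("lerobot/aloha_sim_insertion_human_image", video_backend="pyav")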
What this does
This PR replaces torchvision CPU decoding with torchcodec CPU decoding.
Also added a decode_video_frames function that wraps multiple backends, instead of calling decode_video_frames_BACKENDNAME separately. This makes it more efficient and allows us to add more decoders later on!

How it was tested
Tested and benchmarked the decoders on different datasets/policies.
How to checkout & try? (for the reviewer)
Just run the training script with a dataset containing videos to decode.
example:
Benchmarks
Ran one benchmark on the lerobot/aloha_sim_insertion_human_image dataset.
Comparison: PyAV vs TorchCodec (CPU)
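A rough way to time the two backends yourself is sketched below; this is not the benchmark script behind the numbers reported here, and the video path, timestamps, tolerance, and decode_video_frames signature are assumptions.

import time

def time_backend(backend: str, n_iters: int = 50) -> float:
    # Average wall-clock time per decode call for the given backend.
    start = time.perf_counter()
    for _ in range(n_iters):
        decode_video_frames(
            "episode_000000.mp4",
            timestamps=[0.0, 1 / 30, 2 / 30, 3 / 30],
            tolerance_s=1e-4,
            backend=backend,
        )
    return (time.perf_counter() - start) / n_iters

for backend in ["pyav", "torchcodec"]:
    print(f"{backend}: {time_backend(backend) * 1e3:.1f} ms per call")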
What's left
Remove/suppress libdav1d logs (they're noisy) -> there's no env variable to disable those for now, but they'll be deactivated in the next version of torchcodec.