Add streaming to LeRobotDataset #740

Draft · wants to merge 1 commit into main
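For context, a minimal usage sketch of the streaming mode this draft adds. It assumes the constructor signature from the diff below; the repo id is illustrative.

```python
from torch.utils.data import DataLoader

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# With streaming=True, the parquet/video files are not downloaded up front;
# samples are fetched progressively from the Hub while iterating.
dataset = LeRobotDataset("lerobot/pusht", streaming=True)

# Each item is a dict of tensors (plus strings), so the usual DataLoader setup applies.
loader = DataLoader(dataset, batch_size=32)
batch = next(iter(loader))
```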
43 changes: 33 additions & 10 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -50,6 +50,7 @@
     get_hf_features_from_features,
     get_hub_safe_version,
     hf_transform_to_torch,
+    item_to_torch,
     load_episodes,
     load_info,
     load_stats,
@@ -214,6 +215,9 @@ def get_task_index(self, task: str) -> int:
         task_index = self.task_to_task_index.get(task, None)
         return task_index if task_index is not None else self.total_tasks
 
+    def html_root(self) -> str:
+        return f"https://huggingface.co/datasets/{self.repo_id}/resolve/main"
+
     def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length
@@ -334,6 +338,7 @@ def __init__(
         download_videos: bool = True,
         local_files_only: bool = False,
         video_backend: str | None = None,
+        streaming: bool = False,
     ):
         """
         2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -431,6 +436,8 @@ def __init__(
                 will be made. Defaults to False.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+            streaming (bool, optional): If True, the data files are not downloaded; instead, the data is
+                streamed progressively while iterating over the dataset. Defaults to False.
         """
         super().__init__()
         self.repo_id = repo_id
@@ -440,10 +447,11 @@ def __init__(
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.video_backend = video_backend if video_backend else "pyav"
-        self.delta_indices = None
         self.local_files_only = local_files_only
+        self.streaming = streaming
 
         # Unused attributes
+        self.delta_indices = None
         self.image_writer = None
         self.episode_buffer = None
 
@@ -456,16 +464,21 @@ def __init__(
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
 
         # Load actual data
-        self.download_episodes(download_videos)
+        if not self.streaming:
+            self.download_episodes(download_videos)
         self.hf_dataset = self.load_hf_dataset()
+        if self.streaming:
+            self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
 
         # Check timestamps
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+        if not self.streaming:
+            check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
 
         # Setup delta_indices
         if self.delta_timestamps is not None:
-            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
+            if not self.streaming:
+                check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
         # Available stats implies all videos have been encoded and dataset is iterable
@@ -550,13 +563,14 @@ def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         if self.episodes is None:
             path = str(self.root / "data")
-            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
+            hf_dataset = load_dataset("parquet", data_dir=path, split="train", streaming=self.streaming)
         else:
             files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            hf_dataset = load_dataset("parquet", data_files=files, split="train")
+            hf_dataset = load_dataset("parquet", data_files=files, split="train", streaming=self.streaming)
 
-        # TODO(aliberts): hf_dataset.set_format("torch")
-        hf_dataset.set_transform(hf_transform_to_torch)
+        if not self.streaming:
+            # TODO(aliberts): hf_dataset.set_format("torch")
+            hf_dataset.set_transform(hf_transform_to_torch)
 
         return hf_dataset

@@ -632,7 +646,8 @@ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -
         """
         item = {}
         for vid_key, query_ts in query_timestamps.items():
-            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+            root = self.meta.html_root() if self.streaming else self.root
+            video_path = Path(root) / self.meta.get_video_file_path(ep_idx, vid_key)
             frames = decode_video_frames_torchvision(
                 video_path, query_ts, self.tolerance_s, self.video_backend
             )
@@ -649,7 +664,15 @@ def __len__(self):
         return self.num_frames
 
     def __getitem__(self, idx) -> dict:
-        item = self.hf_dataset[idx]
+        if self.streaming:
+            try:
+                item = next(self.hf_dataset_iter)
+            except StopIteration:
+                self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
+                item = next(self.hf_dataset_iter)
+            item = item_to_torch(item)
+        else:
+            item = self.hf_dataset[idx]
         ep_idx = item["episode_index"].item()
 
         query_indices = None
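As context for the `__getitem__` change above: `load_dataset(..., streaming=True)` returns a `datasets.IterableDataset`, which supports sequential iteration but not random access, so the streaming branch ignores the requested `idx` and draws the next sample from a shuffled iterator, restarting it once exhausted. A standalone sketch of that pattern (repo id and buffer size are illustrative):

```python
from datasets import load_dataset

# An IterableDataset: samples are fetched lazily and cannot be indexed as dataset[idx].
dataset = load_dataset("lerobot/pusht", split="train", streaming=True)

# shuffle(buffer_size=...) keeps a reservoir of that many samples and yields
# them in random order; the iterator is exhausted after one pass over the data.
it = iter(dataset.shuffle(buffer_size=1000))

def next_sample():
    """Return the next shuffled sample, reshuffling and restarting when exhausted."""
    global it
    try:
        return next(it)
    except StopIteration:
        it = iter(dataset.shuffle(buffer_size=1000))
        return next(it)

sample = next_sample()
```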
12 changes: 12 additions & 0 deletions lerobot/common/datasets/utils.py
@@ -205,6 +205,18 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict
 
 
+def item_to_torch(item: dict):
+    for key, value in item.items():
+        if isinstance(value, PILImage.Image):
+            to_tensor = transforms.ToTensor()
+            item[key] = to_tensor(value)
+        elif value is None or isinstance(value, str):
+            pass
+        else:
+            item[key] = torch.tensor(value)
+    return item
+
+
 def _get_major_minor(version: str) -> tuple[int]:
     split = version.strip("v").split(".")
     return int(split[0]), int(split[1])
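To illustrate the new helper: `item_to_torch` converts a single streamed sample in place, turning PIL images into CHW float tensors via `torchvision.transforms.ToTensor`, passing strings and `None` through unchanged, and wrapping everything else (ints, floats, lists) in `torch.tensor`. A quick sketch with a made-up sample dict:

```python
from PIL import Image as PILImage

from lerobot.common.datasets.utils import item_to_torch

# A made-up sample resembling what the streaming iterator yields.
item = {
    "observation.image": PILImage.new("RGB", (96, 96)),
    "action": [0.1, -0.2],
    "task": "push the block",  # strings pass through unchanged
    "episode_index": 3,
}

item = item_to_torch(item)
print(item["observation.image"].shape)  # torch.Size([3, 96, 96])
print(item["action"])                   # tensor([ 0.1000, -0.2000])
print(item["episode_index"].item())     # 3, as used in __getitem__
```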