[Feature] empty_lazy for lazy tensor storages #2955

Open: wants to merge 3 commits into base branch gh/vmoens/141/base.
21 changes: 19 additions & 2 deletions test/test_rb.py
@@ -870,7 +870,9 @@ def extend_and_sample(data):
"`TensorStorage._rand_given_ndim` can be removed."
)

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize(
"storage_type", [partial(LazyTensorStorage, empty_lazy=True), LazyMemmapStorage]
)
def test_extend_lazystack(self, storage_type):

rb = ReplayBuffer(
@@ -881,9 +883,24 @@ def test_extend_lazystack(self, storage_type):
td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5)
ltd = LazyStackedTensorDict(td1, td2, stack_dim=1)
rb.extend(ltd)
rb.sample(3)
s = rb.sample(3)
assert isinstance(s, LazyStackedTensorDict)
assert len(rb) == 5

def test_extend_empty_lazy(self):

rb = ReplayBuffer(
storage=LazyTensorStorage(6, empty_lazy=True),
batch_size=2,
)
td1 = TensorDict(a=torch.rand(4, 8), batch_size=4)
td2 = TensorDict(a=torch.rand(3, 8), batch_size=3)
ltd = LazyStackedTensorDict(td1, td2, stack_dim=0)
rb.extend(ltd)
s = rb.sample(3)
assert isinstance(s, LazyStackedTensorDict)
assert len(rb) == 2

@pytest.mark.parametrize("device_data", get_default_devices())
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])
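For reference, a minimal usage sketch of the `empty_lazy` flag exercised by the tests above. This is a sketch based on `test_extend_empty_lazy`, assuming the TorchRL and tensordict APIs as used in this diff; the asserted behavior mirrors the test.

import torch
from tensordict import LazyStackedTensorDict, TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

# With empty_lazy=True the storage keeps the lazy-stack layout, so ragged
# entries (4 and 3 transitions here) can be stored side by side.
rb = ReplayBuffer(
    storage=LazyTensorStorage(6, empty_lazy=True),
    batch_size=2,
)
td1 = TensorDict(a=torch.rand(4, 8), batch_size=4)
td2 = TensorDict(a=torch.rand(3, 8), batch_size=3)
rb.extend(LazyStackedTensorDict(td1, td2, stack_dim=0))  # extend, not add

sample = rb.sample(3)
assert isinstance(sample, LazyStackedTensorDict)
assert len(rb) == 2  # one entry per stacked (ragged) tensordict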
45 changes: 39 additions & 6 deletions torchrl/data/replay_buffers/storages.py
@@ -827,7 +827,7 @@ def set(
if not self.initialized:
if not isinstance(cursor, INT_CLASSES):
if is_tensor_collection(data):
self._init(data[0])
self._init(data, shape=data.shape[1:])
else:
self._init(tree_map(lambda x: x[0], data))
else:
@@ -873,7 +873,7 @@ def set( # noqa: F811
)
if not self.initialized:
if not isinstance(cursor, INT_CLASSES):
self._init(data[0])
self._init(data, shape=data.shape[1:])
else:
self._init(data)
if not isinstance(cursor, (*INT_CLASSES, slice)):
@@ -993,6 +993,15 @@ class LazyTensorStorage(TensorStorage):
Defaults to ``False``.
consolidated (bool, optional): if ``True``, the storage will be consolidated after
its first expansion. Defaults to ``False``.
empty_lazy (bool, optional): if ``True``, any lazy tensordict in the first tensordict
passed to the storage will be emptied of its content. This can be used to store
ragged data or content with exclusive keys (e.g., when some but not all environments
provide extra data to be stored in the buffer).
Setting `empty_lazy` to `True` requires :meth:`~.extend` to be called first (a call to `add`
will result in an exception).
Recall that data stored in lazy stacks is not stored contiguously in memory: indexing can be
slower than for contiguous data, and serialization is more hazardous. Use with caution!
Defaults to ``False``.

Examples:
>>> data = TensorDict({
@@ -1054,6 +1063,7 @@ def __init__(
ndim: int = 1,
compilable: bool = False,
consolidated: bool = False,
empty_lazy: bool = False,
):
super().__init__(
storage=None,
@@ -1062,11 +1072,13 @@ def __init__(
ndim=ndim,
compilable=compilable,
)
self.empty_lazy = empty_lazy
self.consolidated = consolidated

def _init(
self,
data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821
shape: torch.Size | None = None,
) -> None:
if not self._compilable:
# TODO: Investigate why this seems to have a performance impact with
@@ -1087,8 +1099,21 @@ def max_size_along_dim0(data_shape):

if is_tensor_collection(data):
out = data.to(self.device)
out: TensorDictBase = torch.empty_like(
out.expand(max_size_along_dim0(data.shape))
if self.empty_lazy:
if shape is None:
# shape is None in add
raise RuntimeError(
"Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
)
out: TensorDictBase = torch.empty_like(
out.expand(max_size_along_dim0(data.shape))
)
elif shape is None:
shape = data.shape
else:
out = out[0]
out: TensorDictBase = out.new_empty(
max_size_along_dim0(shape), empty_lazy=self.empty_lazy
)
if self.consolidated:
out = out.consolidate()
@@ -1286,7 +1311,9 @@ def load_state_dict(self, state_dict):
self.initialized = state_dict["initialized"]
self._len = state_dict["_len"]

def _init(self, data: TensorDictBase | torch.Tensor) -> None:
def _init(
self, data: TensorDictBase | torch.Tensor, *, shape: torch.Size | None = None
) -> None:
torchrl_logger.debug("Creating a MemmapStorage...")
if self.device == "auto":
self.device = data.device
@@ -1304,8 +1331,14 @@ def max_size_along_dim0(data_shape):
return (self.max_size, *data_shape)

if is_tensor_collection(data):
if shape is None:
# Within add()
shape = data.shape
else:
# Get the first element - we don't care about empty_lazy in memmap storages
data = data[0]
out = data.clone().to(self.device)
out = out.expand(max_size_along_dim0(data.shape))
out = out.expand(max_size_along_dim0(shape))
out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
if torchrl_logger.isEnabledFor(logging.DEBUG):
for key, tensor in sorted(
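A rough illustration of the `set()`/`_init()` change above, as a sketch under assumed shapes rather than the library code: `extend` now hands `_init` the full batch together with its trailing shape (`data.shape[1:]`), whereas the old code initialized from `data[0]`. For regular, non-ragged data the two are equivalent, which the snippet below checks; for lazy stacks with ragged entries or exclusive keys, indexing the first element would drop exactly the structure that `empty_lazy=True` is meant to preserve.

import torch
from tensordict import TensorDict

# A batch such as extend() might receive: 5 elements, each with batch size 3.
batch = TensorDict({"a": torch.rand(5, 3, 8)}, batch_size=[5, 3])

per_element_shape = batch.shape[1:]  # what set() now passes as `shape`
first_element = batch[0]             # what the old code passed to _init()
assert first_element.shape == per_element_shape == torch.Size([3])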
7 changes: 4 additions & 3 deletions torchrl/envs/transforms/transforms.py
@@ -7308,12 +7308,13 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
return reward_spec

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
time_dim = [i for i, name in enumerate(tensordict.names) if name == "time"]
if not time_dim:
try:
time_dim = list(tensordict.names).index("time")
except ValueError:
raise ValueError(
"At least one dimension of the tensordict must be named 'time' in offline mode"
)
time_dim = time_dim[0] - 1
time_dim = time_dim - 1
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
reward = tensordict[in_key]
cumsum = reward.cumsum(time_dim)
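Finally, a small standalone illustration of the dimension-name lookup refactor above: the list comprehension plus emptiness check is replaced by `list(names).index("time")` guarded by `ValueError`, raising the same descriptive error when no dimension is named "time" (the existing `- 1` offset is kept unchanged). The names used below are hypothetical.

names = ["batch", "time"]

# Old approach: collect all matching indices, then test for emptiness.
matches = [i for i, name in enumerate(names) if name == "time"]
assert matches and matches[0] == 1

# New approach: index() raises ValueError when "time" is absent, which the
# transform converts into its descriptive error message.
try:
    time_dim = list(names).index("time")
except ValueError:
    raise ValueError(
        "At least one dimension of the tensordict must be named 'time' in offline mode"
    )
assert time_dim == matches[0]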