Skip to content

Commit 2d748f6

Browse files
committed
[Feature] empty_lazy for lazy tensor storages
ghstack-source-id: 5207da9
Pull-Request-resolved: #2955
1 parent 36f34da commit 2d748f6

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

test/test_rb.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,9 @@ def extend_and_sample(data):
870870
"`TensorStorage._rand_given_ndim` can be removed."
871871
)
872872

873-
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
873+
@pytest.mark.parametrize(
874+
"storage_type", [partial(LazyTensorStorage, empty_lazy=True), LazyMemmapStorage]
875+
)
874876
def test_extend_lazystack(self, storage_type):
875877

876878
rb = ReplayBuffer(
@@ -881,7 +883,8 @@ def test_extend_lazystack(self, storage_type):
881883
td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5)
882884
ltd = LazyStackedTensorDict(td1, td2, stack_dim=1)
883885
rb.extend(ltd)
884-
rb.sample(3)
886+
s = rb.sample(3)
887+
assert isinstance(s, LazyStackedTensorDict)
885888
assert len(rb) == 5
886889

887890
@pytest.mark.parametrize("device_data", get_default_devices())

torchrl/data/replay_buffers/storages.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ def set(
827827
if not self.initialized:
828828
if not isinstance(cursor, INT_CLASSES):
829829
if is_tensor_collection(data):
830-
self._init(data[0])
830+
self._init(data, shape=data.shape[1:])
831831
else:
832832
self._init(tree_map(lambda x: x[0], data))
833833
else:
@@ -873,7 +873,7 @@ def set( # noqa: F811
873873
)
874874
if not self.initialized:
875875
if not isinstance(cursor, INT_CLASSES):
876-
self._init(data[0])
876+
self._init(data, shape=data.shape[1:])
877877
else:
878878
self._init(data)
879879
if not isinstance(cursor, (*INT_CLASSES, slice)):
@@ -993,6 +993,15 @@ class LazyTensorStorage(TensorStorage):
993993
Defaults to ``False``.
994994
consolidated (bool, optional): if ``True``, the storage will be consolidated after
995995
its first expansion. Defaults to ``False``.
996+
empty_lazy (bool, optional): if ``True``, any lazy tensordict in the first tensordict
997+
passed to the storage will be emptied of its content. This can be used to store
998+
ragged data or content with exclusive keys (e.g., when some but not all environments
999+
provide extra data to be stored in the buffer).
1000+
Setting `empty_lazy` to `True` requires :meth:`~.extend` to be called first (a call to `add`
1001+
will result in an exception).
1002+
Recall that data stored in lazy stacks is not stored contiguously in memory: indexing can be
1003+
slower than contiguous data and serialization is more hazardous. Use with caution!
1004+
Defaults to ``False``.
9961005
9971006
Examples:
9981007
>>> data = TensorDict({
@@ -1054,6 +1063,7 @@ def __init__(
10541063
ndim: int = 1,
10551064
compilable: bool = False,
10561065
consolidated: bool = False,
1066+
empty_lazy: bool = False,
10571067
):
10581068
super().__init__(
10591069
storage=None,
@@ -1062,11 +1072,13 @@ def __init__(
10621072
ndim=ndim,
10631073
compilable=compilable,
10641074
)
1075+
self.empty_lazy = empty_lazy
10651076
self.consolidated = consolidated
10661077

10671078
def _init(
10681079
self,
10691080
data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821
1081+
shape: torch.Size | None = None,
10701082
) -> None:
10711083
if not self._compilable:
10721084
# TODO: Investigate why this seems to have a performance impact with
@@ -1087,8 +1099,14 @@ def max_size_along_dim0(data_shape):
10871099

10881100
if is_tensor_collection(data):
10891101
out = data.to(self.device)
1090-
out: TensorDictBase = torch.empty_like(
1091-
out.expand(max_size_along_dim0(data.shape))
1102+
if self.empty_lazy and shape is None:
1103+
raise RuntimeError(
1104+
"Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
1105+
)
1106+
elif shape is None:
1107+
shape = data.shape
1108+
out: TensorDictBase = out.new_empty(
1109+
max_size_along_dim0(shape), empty_lazy=self.empty_lazy
10921110
)
10931111
if self.consolidated:
10941112
out = out.consolidate()
@@ -1286,7 +1304,9 @@ def load_state_dict(self, state_dict):
12861304
self.initialized = state_dict["initialized"]
12871305
self._len = state_dict["_len"]
12881306

1289-
def _init(self, data: TensorDictBase | torch.Tensor) -> None:
1307+
def _init(
1308+
self, data: TensorDictBase | torch.Tensor, *, shape: torch.Size | None = None
1309+
) -> None:
12901310
torchrl_logger.debug("Creating a MemmapStorage...")
12911311
if self.device == "auto":
12921312
self.device = data.device
@@ -1304,8 +1324,14 @@ def max_size_along_dim0(data_shape):
13041324
return (self.max_size, *data_shape)
13051325

13061326
if is_tensor_collection(data):
1327+
if shape is None:
1328+
# Within add()
1329+
shape = data.shape
1330+
else:
1331+
# Get the first element - we don't care about empty_lazy in memmap storages
1332+
data = data[0]
13071333
out = data.clone().to(self.device)
1308-
out = out.expand(max_size_along_dim0(data.shape))
1334+
out = out.expand(max_size_along_dim0(shape))
13091335
out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
13101336
if torchrl_logger.isEnabledFor(logging.DEBUG):
13111337
for key, tensor in sorted(

0 commit comments

Comments (0)