diff --git a/test/test_rb.py b/test/test_rb.py
index f55d99b8c8b..8368aef58f9 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -870,7 +870,9 @@ def extend_and_sample(data):
             "`TensorStorage._rand_given_ndim` can be removed."
         )
 
-    @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
+    @pytest.mark.parametrize(
+        "storage_type", [partial(LazyTensorStorage, empty_lazy=True), LazyMemmapStorage]
+    )
     def test_extend_lazystack(self, storage_type):
 
         rb = ReplayBuffer(
@@ -881,9 +883,24 @@ def test_extend_lazystack(self, storage_type):
         td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5)
         ltd = LazyStackedTensorDict(td1, td2, stack_dim=1)
         rb.extend(ltd)
-        rb.sample(3)
+        s = rb.sample(3)
+        assert isinstance(s, LazyStackedTensorDict)
         assert len(rb) == 5
 
+    def test_extend_empty_lazy(self):
+
+        rb = ReplayBuffer(
+            storage=LazyTensorStorage(6, empty_lazy=True),
+            batch_size=2,
+        )
+        td1 = TensorDict(a=torch.rand(4, 8), batch_size=4)
+        td2 = TensorDict(a=torch.rand(3, 8), batch_size=3)
+        ltd = LazyStackedTensorDict(td1, td2, stack_dim=0)
+        rb.extend(ltd)
+        s = rb.sample(3)
+        assert isinstance(s, LazyStackedTensorDict)
+        assert len(rb) == 2
+
     @pytest.mark.parametrize("device_data", get_default_devices())
     @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
     @pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index f839ef253d9..68799e13ba2 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -827,7 +827,7 @@ def set(
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
                 if is_tensor_collection(data):
-                    self._init(data[0])
+                    self._init(data, shape=data.shape[1:])
                 else:
                     self._init(tree_map(lambda x: x[0], data))
             else:
@@ -873,7 +873,7 @@ def set(  # noqa: F811
             )
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
-                self._init(data[0])
+                self._init(data, shape=data.shape[1:])
             else:
                 self._init(data)
         if not isinstance(cursor, (*INT_CLASSES, slice)):
@@ -993,6 +993,15 @@ class LazyTensorStorage(TensorStorage):
             Defaults to ``False``.
         consolidated (bool, optional): if ``True``, the storage will be consolidated after
             its first expansion. Defaults to ``False``.
+        empty_lazy (bool, optional): if ``True``, any lazy tensordict in the first tensordict
+            passed to the storage will be emptied of its content. This can be used to store
+            ragged data or content with exclusive keys (e.g., when some but not all environments
+            provide extra data to be stored in the buffer).
+            Setting `empty_lazy` to `True` requires :meth:`~.extend` to be called first (a call to `add`
+            will result in an exception).
+            Recall that data stored in lazy stacks is not stored contiguously in memory: indexing can be
+            slower than contiguous data and serialization is more hazardous. Use with caution!
+            Defaults to ``False``.
 
     Examples:
         >>> data = TensorDict({
@@ -1054,6 +1063,7 @@ def __init__(
         ndim: int = 1,
         compilable: bool = False,
         consolidated: bool = False,
+        empty_lazy: bool = False,
     ):
         super().__init__(
             storage=None,
@@ -1062,11 +1072,13 @@ def __init__(
             ndim=ndim,
             compilable=compilable,
         )
+        self.empty_lazy = empty_lazy
         self.consolidated = consolidated
 
     def _init(
         self,
         data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
+        shape: torch.Size | None = None,
     ) -> None:
         if not self._compilable:
             # TODO: Investigate why this seems to have a performance impact with
@@ -1087,8 +1099,21 @@ def max_size_along_dim0(data_shape):
 
         if is_tensor_collection(data):
             out = data.to(self.device)
-            out: TensorDictBase = torch.empty_like(
-                out.expand(max_size_along_dim0(data.shape))
+            if self.empty_lazy:
+                if shape is None:
+                    # shape is None in add
+                    raise RuntimeError(
+                        "Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
+                    )
+                out: TensorDictBase = torch.empty_like(
+                    out.expand(max_size_along_dim0(data.shape))
+                )
+            elif shape is None:
+                shape = data.shape
+            else:
+                out = out[0]
+            out: TensorDictBase = out.new_empty(
+                max_size_along_dim0(shape), empty_lazy=self.empty_lazy
             )
             if self.consolidated:
                 out = out.consolidate()
@@ -1286,7 +1311,9 @@ def load_state_dict(self, state_dict):
         self.initialized = state_dict["initialized"]
         self._len = state_dict["_len"]
 
-    def _init(self, data: TensorDictBase | torch.Tensor) -> None:
+    def _init(
+        self, data: TensorDictBase | torch.Tensor, *, shape: torch.Size | None = None
+    ) -> None:
         torchrl_logger.debug("Creating a MemmapStorage...")
         if self.device == "auto":
             self.device = data.device
@@ -1304,8 +1331,14 @@ def max_size_along_dim0(data_shape):
             return (self.max_size, *data_shape)
 
         if is_tensor_collection(data):
+            if shape is None:
+                # Within add()
+                shape = data.shape
+            else:
+                # Get the first element - we don't care about empty_lazy in memmap storages
+                data = data[0]
             out = data.clone().to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
+            out = out.expand(max_size_along_dim0(shape))
             out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
             if torchrl_logger.isEnabledFor(logging.DEBUG):
                 for key, tensor in sorted(
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 19e2ad7ec7d..fb61ef15e13 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -7308,12 +7308,13 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
         return reward_spec
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        time_dim = [i for i, name in enumerate(tensordict.names) if name == "time"]
-        if not time_dim:
+        try:
+            time_dim = list(tensordict.names).index("time")
+        except ValueError:
             raise ValueError(
                 "At least one dimension of the tensordict must be named 'time' in offline mode"
             )
-        time_dim = time_dim[0] - 1
+        time_dim = time_dim - 1
         for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
             reward = tensordict[in_key]
             cumsum = reward.cumsum(time_dim)
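A minimal usage sketch of the new `empty_lazy` option, mirroring the `test_extend_empty_lazy` test added above. It assumes the same imports as the test file (`ReplayBuffer`/`LazyTensorStorage` from `torchrl.data`, `TensorDict`/`LazyStackedTensorDict` from `tensordict`); it is an illustration of the intended use, not part of the patch.

# Sketch based on test_extend_empty_lazy: ragged tensordicts are stored lazily
# and come back out of sample() as a LazyStackedTensorDict.
import torch
from tensordict import LazyStackedTensorDict, TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(
    storage=LazyTensorStorage(6, empty_lazy=True),
    batch_size=2,
)
# Two elements with different leading shapes, so they can only be stacked lazily.
td1 = TensorDict(a=torch.rand(4, 8), batch_size=4)
td2 = TensorDict(a=torch.rand(3, 8), batch_size=3)
# With empty_lazy=True, the first write must go through extend(), not add().
rb.extend(LazyStackedTensorDict(td1, td2, stack_dim=0))
sample = rb.sample(3)
assert isinstance(sample, LazyStackedTensorDict)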