@@ -827,7 +827,7 @@ def set(
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
                 if is_tensor_collection(data):
-                    self._init(data[0])
+                    self._init(data, shape=data.shape[1:])
                 else:
                     self._init(tree_map(lambda x: x[0], data))
             else:
@@ -873,7 +873,7 @@ def set(  # noqa: F811
                 )
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
-                self._init(data[0])
+                self._init(data, shape=data.shape[1:])
             else:
                 self._init(data)
         if not isinstance(cursor, (*INT_CLASSES, slice)):
@@ -993,6 +993,15 @@ class LazyTensorStorage(TensorStorage):
             Defaults to ``False``.
         consolidated (bool, optional): if ``True``, the storage will be consolidated after
             its first expansion. Defaults to ``False``.
+        empty_lazy (bool, optional): if ``True``, any lazy tensordict in the first tensordict
+            passed to the storage will be emptied of its content. This can be used to store
+            ragged data or content with exclusive keys (e.g., when some but not all environments
+            provide extra data to be stored in the buffer).
+            Setting `empty_lazy` to `True` requires :meth:`~.extend` to be called first (a call to `add`
+            will result in an exception).
+            Recall that data stored in lazy stacks is not stored contiguously in memory: indexing can be
+            slower than for contiguous data, and serialization is more fragile. Use with caution!
+            Defaults to ``False``.

     Examples:
         >>> data = TensorDict({
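For context, here is a minimal usage sketch of the new flag (not part of the diff). It assumes this PR's `empty_lazy` argument and uses `TensorDict.lazy_stack` from tensordict; key names such as "obs" and "info" are illustrative only.

# Hedged sketch: storing a ragged batch with an exclusive key, assuming this PR.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

td0 = TensorDict({"obs": torch.randn(4)}, batch_size=[])
td1 = TensorDict({"obs": torch.randn(4), "info": torch.randn(2)}, batch_size=[])  # extra key
batch = TensorDict.lazy_stack([td0, td1])  # lazy stack; "info" is exclusive to td1

rb = ReplayBuffer(storage=LazyTensorStorage(100, empty_lazy=True))
rb.extend(batch)  # extend() must initialize the storage when empty_lazy=True
sample = rb.sample(2)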
@@ -1054,6 +1063,7 @@ def __init__(
         ndim: int = 1,
         compilable: bool = False,
         consolidated: bool = False,
+        empty_lazy: bool = False,
     ):
         super().__init__(
             storage=None,
@@ -1062,11 +1072,13 @@ def __init__(
             ndim=ndim,
             compilable=compilable,
         )
+        self.empty_lazy = empty_lazy
         self.consolidated = consolidated

     def _init(
         self,
         data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
+        shape: torch.Size | None = None,
     ) -> None:
         if not self._compilable:
             # TODO: Investigate why this seems to have a performance impact with
@@ -1087,8 +1099,14 @@ def max_size_along_dim0(data_shape):

         if is_tensor_collection(data):
             out = data.to(self.device)
-            out: TensorDictBase = torch.empty_like(
-                out.expand(max_size_along_dim0(data.shape))
+            if self.empty_lazy and shape is None:
+                raise RuntimeError(
+                    "Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
+                )
+            elif shape is None:
+                shape = data.shape
+            out: TensorDictBase = out.new_empty(
+                max_size_along_dim0(shape), empty_lazy=self.empty_lazy
             )
             if self.consolidated:
                 out = out.consolidate()
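The branch above distinguishes the two entry points: `extend` reaches `_init` with the per-item `shape` (see the `set` changes earlier in this diff), while `add` leaves `shape=None`, which is rejected when `empty_lazy=True`. A short sketch of that error path (not part of the diff, assuming this PR's behavior):

# Hedged sketch of the add-before-extend error path, assuming this PR.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100, empty_lazy=True))
try:
    rb.add(TensorDict({"obs": torch.randn(4)}, batch_size=[]))  # add() -> _init(shape=None)
except RuntimeError:
    pass  # empty_lazy=True requires the storage to be initialized through extend()
rb.extend(TensorDict({"obs": torch.randn(2, 4)}, batch_size=[2]))  # extend() -> _init(shape=torch.Size([4]))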
@@ -1286,7 +1304,9 @@ def load_state_dict(self, state_dict):
         self.initialized = state_dict["initialized"]
         self._len = state_dict["_len"]

-    def _init(self, data: TensorDictBase | torch.Tensor) -> None:
+    def _init(
+        self, data: TensorDictBase | torch.Tensor, *, shape: torch.Size | None = None
+    ) -> None:
         torchrl_logger.debug("Creating a MemmapStorage...")
         if self.device == "auto":
             self.device = data.device
@@ -1304,8 +1324,14 @@ def max_size_along_dim0(data_shape):
             return (self.max_size, *data_shape)

         if is_tensor_collection(data):
+            if shape is None:
+                # Within add()
+                shape = data.shape
+            else:
+                # Get the first element - we don't care about empty_lazy in memmap storages
+                data = data[0]
             out = data.clone().to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
+            out = out.expand(max_size_along_dim0(shape))
             out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
             if torchrl_logger.isEnabledFor(logging.DEBUG):
                 for key, tensor in sorted(
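By contrast, the memmap storage keeps its contiguous behavior: when `extend` supplies a `shape`, the first element is taken with `data[0]` and `empty_lazy` plays no role. A minimal reminder of the unchanged usage (not part of the diff):

# Hedged sketch: LazyMemmapStorage is unaffected by empty_lazy and stays contiguous.
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyMemmapStorage(100))
rb.extend(TensorDict({"obs": torch.randn(8, 4)}, batch_size=[8]))  # memory-mapped, contiguous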