
Commit 605b4aa

Author: Vincent Moens (committed)
[Feature] Prevent loading existing mmap files in storages if they already exist
ghstack-source-id: 63bcb1e
Pull Request resolved: #2438
1 parent: 36545af · commit: 605b4aa
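In user terms, the change makes a second storage pointed at an already-populated scratch_dir fail fast instead of silently reusing the memmap files on disk. A minimal sketch of the new behavior, mirroring the test added below (the /tmp/scratch path is illustrative, standing in for the test's tmpdir):

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, ReplayBuffer

# A first buffer writes its memmap files under scratch_dir.
rb = ReplayBuffer(storage=LazyMemmapStorage(10, scratch_dir="/tmp/scratch"))
rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))

# A second storage over the same directory now raises a RuntimeError
# mentioning `existsok` on the first write, rather than reusing the files.
rb2 = ReplayBuffer(storage=LazyMemmapStorage(10, scratch_dir="/tmp/scratch"))
# rb2.extend(TensorDict(a=torch.randn(3), batch_size=[3]))  # RuntimeError

# Opting in with existsok=True reopens the existing files as-is.
rb3 = ReplayBuffer(
    storage=LazyMemmapStorage(10, scratch_dir="/tmp/scratch", existsok=True)
)
rb3.extend(TensorDict(a=torch.randn(3), batch_size=[3]))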

File tree: 2 files changed (+22, -1 lines)


test/test_rb.py

Lines changed: 14 additions & 0 deletions
@@ -546,6 +546,20 @@ def test_errors(self, storage_type):
         ):
             storage_type(data, max_size=4)
 
+    def test_existsok_lazymemmap(self, tmpdir):
+        storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storage0)
+        rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+        storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storage1)
+        with pytest.raises(RuntimeError, match="existsok"):
+            rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+        storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True)
+        rb = ReplayBuffer(storage=storage2)
+        rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
     @pytest.mark.parametrize(
         "data_type", ["tensor", "tensordict", "tensorclass", "pytree"]
     )
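The test exercises the three cases in sequence: a fresh write to an empty scratch_dir succeeds, a second storage over the same directory raises a RuntimeError whose message points at existsok, and passing existsok=True explicitly restores the old reuse behavior.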

torchrl/data/replay_buffers/storages.py

Lines changed: 8 additions & 1 deletion
@@ -923,6 +923,8 @@ class LazyMemmapStorage(LazyTensorStorage):
     Args:
         max_size (int): size of the storage, i.e. maximum number of elements stored
             in the buffer.
+
+    Keyword Args:
         scratch_dir (str or path): directory where memmap-tensors will be written.
         device (torch.device, optional): device where the sampled tensors will be
             stored and sent. Default is :obj:`torch.device("cpu")`.
@@ -933,6 +935,9 @@ class LazyMemmapStorage(LazyTensorStorage):
             measuring the storage size. For instance, a storage of shape ``[3, 4]``
             has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
             Defaults to ``1``.
+        existsok (bool, optional): whether the tensors may already exist on disk.
+            Defaults to ``False``, in which case an error is raised if any tensor
+            already exists. If ``True``, the tensors are opened as is, not overwritten.
 
     .. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to where the storage is
         already stored to avoid executing long copies of data that is already stored on disk.
@@ -1009,10 +1014,12 @@ def __init__(
         scratch_dir=None,
         device: torch.device = "cpu",
         ndim: int = 1,
+        existsok: bool = False,
     ):
         super().__init__(max_size, ndim=ndim)
         self.initialized = False
         self.scratch_dir = None
+        self.existsok = existsok
         if scratch_dir is not None:
             self.scratch_dir = str(scratch_dir)
             if self.scratch_dir[-1] != "/":
@@ -1108,7 +1115,7 @@ def max_size_along_dim0(data_shape):
         if is_tensor_collection(data):
             out = data.clone().to(self.device)
             out = out.expand(max_size_along_dim0(data.shape))
-            out = out.memmap_like(prefix=self.scratch_dir)
+            out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
             for key, tensor in sorted(
                 out.items(include_nested=True, leaves_only=True), key=str
             ):
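The gating itself happens one level down: LazyMemmapStorage merely forwards the flag to tensordict's TensorDict.memmap_like, which is what actually touches the files. A sketch of the same pattern at the tensordict level, assuming memmap_like accepts existsok as the diff above relies on (the ./scratch prefix is illustrative):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(10, 3)}, batch_size=[10])

# First call lays out memmap files shaped like `td` under ./scratch.
mm = td.memmap_like(prefix="./scratch")

# Reopening the same prefix with existsok=True declares pre-existing files
# acceptable and opens them as-is; this is the flag the storage now forwards.
mm2 = td.memmap_like(prefix="./scratch", existsok=True)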
