
Commit 072f321

Author: Vincent Moens
[Feature] replay_buffer_chunk
ghstack-source-id: 1edfde0
Pull Request resolved: #2388
1 parent 2b70284 commit 072f321

File tree: 6 files changed (+126, −23 lines)
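
This commit adds a `replay_buffer_chunk` flag to the multiprocessed collectors (`MultiSyncDataCollector` and `MultiaSyncDataCollector`). When a `ReplayBuffer` is passed to the collector, the flag decides where workers write into the shared buffer: with `replay_buffer_chunk=True` (the default) the inner single-process collector extends the buffer with whole rollout chunks, while with `False` each worker extends it at the outer loop level after clearing the batch's dim names. Below is a minimal usage sketch in the spirit of the new tests; the `"CartPole-v1"` id and the exact import paths are my assumptions, not part of the diff.

```python
from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, StepCounter
from torchrl.envs.utils import RandomPolicy

make_env = lambda: GymEnv("CartPole-v1").append_transform(StepCounter())
rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
collector = MultiSyncDataCollector(
    [make_env, make_env],
    RandomPolicy(make_env().action_spec),
    replay_buffer=rb,
    total_frames=256,
    frames_per_batch=32,
    replay_buffer_chunk=True,  # the flag introduced by this commit
)
for batch in collector:
    # with a replay buffer attached, data lands in rb and the iterator yields None
    assert batch is None
collector.shutdown()
assert len(rb) == 256
```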

test/test_collector.py

Lines changed: 72 additions & 19 deletions
```diff
@@ -69,6 +69,7 @@
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
+    LazyMemmapStorage,
     LazyTensorStorage,
     NonTensor,
     ReplayBuffer,
@@ -2799,44 +2800,86 @@ def test_collector_rb_sync(self):
         del collector, env
         assert assert_allclose_td(rbdata0, rbdata1)

-    def test_collector_rb_multisync(self):
-        env = GymEnv(CARTPOLE_VERSIONED())
-        env.set_seed(0)
+    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("env_creator", [False, True])
+    @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
+    def test_collector_rb_multisync(
+        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+    ):
+        if not env_creator:
+            env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
+            env.set_seed(0)
+            action_spec = env.action_spec
+            env = lambda env=env: env
+        else:
+            env = EnvCreator(
+                lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
+                    StepCounter()
+                )
+            )
+            action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

-        rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
-        rb.add(env.rand_step(env.reset()))
-        rb.empty()
+        if storagetype == LazyMemmapStorage:
+            storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

         collector = MultiSyncDataCollector(
-            [lambda: env, lambda: env],
-            RandomPolicy(env.action_spec),
+            [env, env],
+            RandomPolicy(action_spec),
             replay_buffer=rb,
             total_frames=256,
-            frames_per_batch=16,
+            frames_per_batch=32,
+            replay_buffer_chunk=replay_buffer_chunk,
         )
         torch.manual_seed(0)
         pred_len = 0
         for c in collector:
-            pred_len += 16
+            pred_len += 32
             assert c is None
             assert len(rb) == pred_len
         collector.shutdown()
         assert len(rb) == 256
+        if not replay_buffer_chunk:
+            steps_counts = rb["step_count"].squeeze().split(16)
+            collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
+            for step_count, ids in zip(steps_counts, collector_ids):
+                step_countdiff = step_count.diff()
+                idsdiff = ids.diff()
+                assert (
+                    (step_countdiff == 1) | (step_countdiff < 0)
+                ).all(), steps_counts
+                assert (idsdiff >= 0).all()
+
+    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("env_creator", [False, True])
+    @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
+    def test_collector_rb_multiasync(
+        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+    ):
+        if not env_creator:
+            env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
+            env.set_seed(0)
+            action_spec = env.action_spec
+            env = lambda env=env: env
+        else:
+            env = EnvCreator(
+                lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
+                    StepCounter()
+                )
+            )
+            action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

-    def test_collector_rb_multiasync(self):
-        env = GymEnv(CARTPOLE_VERSIONED())
-        env.set_seed(0)
-
-        rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
-        rb.add(env.rand_step(env.reset()))
-        rb.empty()
+        if storagetype == LazyMemmapStorage:
+            storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

         collector = MultiaSyncDataCollector(
-            [lambda: env, lambda: env],
-            RandomPolicy(env.action_spec),
+            [env, env],
+            RandomPolicy(action_spec),
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=16,
+            replay_buffer_chunk=replay_buffer_chunk,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -2846,6 +2889,16 @@ def test_collector_rb_multiasync(self):
             assert len(rb) >= pred_len
         collector.shutdown()
         assert len(rb) == 256
+        if not replay_buffer_chunk:
+            steps_counts = rb["step_count"].squeeze().split(16)
+            collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
+            for step_count, ids in zip(steps_counts, collector_ids):
+                step_countdiff = step_count.diff()
+                idsdiff = ids.diff()
+                assert (
+                    (step_countdiff == 1) | (step_countdiff < 0)
+                ).all(), steps_counts
+                assert (idsdiff >= 0).all()


 if __name__ == "__main__":
```
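
The `if not replay_buffer_chunk` assertions above check that each 16-frame slice of the buffer (16 frames per batch per worker) is a contiguous rollout from a single worker: step counts grow by exactly one within an episode and only drop at a reset, and trajectory ids never decrease. A standalone illustration of the predicate, with made-up tensor values:

```python
import torch

step_count = torch.tensor([3, 4, 5, 0, 1, 2])  # an episode reset after step 5
diff = step_count.diff()
# +1 while an episode is running, negative exactly at the reset boundary
assert ((diff == 1) | (diff < 0)).all()
```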

torchrl/collectors/collectors.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -54,6 +54,7 @@
 from torchrl.data.tensor_specs import TensorSpec
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
 from torchrl.envs.common import _do_nothing, EnvBase
+from torchrl.envs.env_creator import EnvCreator
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 from torchrl.envs.utils import (
     _aggregate_end_of_traj,
@@ -1469,6 +1470,7 @@ def __init__(
         set_truncated: bool = False,
         use_buffers: bool | None = None,
         replay_buffer: ReplayBuffer | None = None,
+        replay_buffer_chunk: bool = True,
     ):
         exploration_type = _convert_exploration_type(
             exploration_mode=exploration_mode, exploration_type=exploration_type
@@ -1513,6 +1515,8 @@ def __init__(

         self._use_buffers = use_buffers
         self.replay_buffer = replay_buffer
+        self._check_replay_buffer_init()
+        self.replay_buffer_chunk = replay_buffer_chunk
         if (
             replay_buffer is not None
             and hasattr(replay_buffer, "shared")
@@ -1659,6 +1663,21 @@ def _get_weight_fn(weights=policy_weights):
         )
         self.cat_results = cat_results

+    def _check_replay_buffer_init(self):
+        try:
+            if not self.replay_buffer._storage.initialized:
+                if isinstance(self.create_env_fn, EnvCreator):
+                    fake_td = self.create_env_fn.tensordict
+                else:
+                    fake_td = self.create_env_fn[0](
+                        **self.create_env_kwargs[0]
+                    ).fake_tensordict()
+                fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)
+
+                self.replay_buffer._storage._init(fake_td)
+        except AttributeError:
+            pass
+
     @classmethod
     def _total_workers_from_env(cls, env_creators):
         if isinstance(env_creators, (tuple, list)):
@@ -1793,6 +1812,7 @@ def _run_processes(self) -> None:
                 "set_truncated": self.set_truncated,
                 "use_buffers": self._use_buffers,
                 "replay_buffer": self.replay_buffer,
+                "replay_buffer_chunk": self.replay_buffer_chunk,
                 "traj_pool": traj_pool,
             }
             proc = _ProcessNoWarn(
@@ -2802,6 +2822,7 @@ def _main_async_collector(
     set_truncated: bool = False,
     use_buffers: bool | None = None,
     replay_buffer: ReplayBuffer | None = None,
+    replay_buffer_chunk: bool = True,
     traj_pool: _TrajectoryPool = None,
 ) -> None:
     pipe_parent.close()
@@ -2823,11 +2844,11 @@ def _main_async_collector(
         env_device=env_device,
         exploration_type=exploration_type,
         reset_when_done=reset_when_done,
-        return_same_td=True,
+        return_same_td=replay_buffer is None,
         interruptor=interruptor,
         set_truncated=set_truncated,
         use_buffers=use_buffers,
-        replay_buffer=replay_buffer,
+        replay_buffer=replay_buffer if replay_buffer_chunk else None,
         traj_pool=traj_pool,
     )
     use_buffers = inner_collector._use_buffers
@@ -2893,6 +2914,10 @@ def _main_async_collector(
                 continue

         if replay_buffer is not None:
+            if not replay_buffer_chunk:
+                next_data.names = None
+                replay_buffer.extend(next_data)
+
             try:
                 queue_out.put((idx, j), timeout=_TIMEOUT)
                 if verbose:
```
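
The new `_check_replay_buffer_init` pre-initializes an uninitialized buffer storage in the main process, using a spec-based fake rollout from the env plus a `("collector", "traj_ids")` entry, so that the storage tensors already exist (and can be shared) before worker processes are spawned. A rough standalone equivalent, assuming a Gym env id and using the same private `_storage._init` hook as the diff:

```python
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")  # assumed env, stands in for create_env_fn
rb = ReplayBuffer(storage=LazyTensorStorage(256))

if not rb._storage.initialized:
    fake_td = env.fake_tensordict()  # template of one step, no env interaction
    fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)
    rb._storage._init(fake_td)  # allocate the storage tensors up front

assert rb._storage.initialized
```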

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -364,6 +364,11 @@ def __len__(self) -> int:
         with self._replay_lock:
             return len(self._storage)

+    @property
+    def write_count(self):
+        """The total number of items written so far in the buffer through add and extend."""
+        return self._writer._write_count
+
     def __repr__(self) -> str:
         from torchrl.envs.transforms import Compose
```
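
`write_count` exposes the writer's new cumulative counter; unlike `len(rb)`, it keeps growing once the circular storage wraps around. A small sketch of the distinction:

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(10))
rb.extend(TensorDict({"obs": torch.randn(25, 3)}, batch_size=[25]))
print(len(rb))         # 10: capped at the storage capacity
print(rb.write_count)  # 25: every write counts, including overwritten entries
```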

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -562,7 +562,7 @@ def assert_is_sharable(tensor):
             raise RuntimeError(STORAGE_ERR)

     if is_tensor_collection(storage):
-        storage.apply(assert_is_sharable)
+        storage.apply(assert_is_sharable, filter_empty=True)
     else:
         tree_map(storage, assert_is_sharable)
```

torchrl/data/replay_buffers/writers.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -163,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
             single_data=data
         )
+        self._write_count += 1
         # Replicate index requires the shape of the storage to be known
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(_cursor, data)
@@ -191,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
         )
         # we need to update the cursor first to avoid race conditions between workers
         self._cursor = (batch_size + cur_size) % max_size_along0
+        self._write_count += batch_size
         # Replicate index requires the shape of the storage to be known
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(index, data)
@@ -222,6 +224,20 @@ def _cursor(self, value):
         _cursor_value = self._cursor_value = mp.Value("i", 0)
         _cursor_value.value = value

+    @property
+    def _write_count(self):
+        _write_count = self.__dict__.get("_write_count_value", None)
+        if _write_count is None:
+            _write_count = self._write_count_value = mp.Value("i", 0)
+        return _write_count.value
+
+    @_write_count.setter
+    def _write_count(self, value):
+        _write_count = self.__dict__.get("_write_count_value", None)
+        if _write_count is None:
+            _write_count = self._write_count_value = mp.Value("i", 0)
+        _write_count.value = value
+
     def __getstate__(self):
         state = super().__getstate__()
         if get_spawning_popen() is None:
@@ -249,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         # we need to update the cursor first to avoid race conditions between workers
         max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
         self._cursor = (index + 1) % max_size_along_dim0
+        self._write_count += 1
         if not is_tensorclass(data):
             data.set(
                 "index",
@@ -275,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
         )
         # we need to update the cursor first to avoid race conditions between workers
         self._cursor = (batch_size + cur_size) % max_size_along_dim0
+        self._write_count += batch_size
         # storage must convert the data to the appropriate format if needed
         if not is_tensorclass(data):
             data.set(
@@ -469,6 +487,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         index = self.get_insert_index(data)
         if index is not None:
             data.set("index", index)
+            self._write_count += 1
             # Replicate index requires the shape of the storage to be known
             # Other than that, a "flat" (1d) index is ok to write the data
             self._storage.set(index, data)
@@ -488,6 +507,7 @@ def extend(self, data: TensorDictBase) -> None:
         for data_idx, sample in enumerate(data):
             storage_idx = self.get_insert_index(sample)
             if storage_idx is not None:
+                self._write_count += 1
                 data_to_replace[storage_idx] = data_idx

         # -1 will be interpreted as invalid by prioritized buffers
```
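
The `_write_count` property mirrors the existing `_cursor` pattern: the integer lives in a lazily created `mp.Value("i", 0)` so that writer state survives pickling into collector worker processes and updates are visible across them. A minimal, non-TorchRL sketch of why a plain Python int would not work here (each process would only bump its own copy):

```python
import multiprocessing as mp

def bump(counter, n):
    # each worker increments the shared integer living in shared memory
    for _ in range(n):
        with counter.get_lock():
            counter.value += 1

if __name__ == "__main__":
    counter = mp.Value("i", 0)
    workers = [mp.Process(target=bump, args=(counter, 100)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(counter.value)  # 400: increments from all workers are visible
```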

torchrl/envs/env_creator.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None:
             del state_dict[key]

     @property
-    def meta_data(self):
+    def meta_data(self) -> EnvMetaData:
         if self._meta_data is None:
             raise RuntimeError(
                 "meta_data is None in EnvCreator. " "Make sure init_() has been called."
```
