Skip to content

Commit 10735a7

Browse files
committed
[Feature] Avoid some recompiles of ReplayBuffer.extend\sample
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: 02c3066 Pull Request resolved: #2504
1 parent 5244a90 commit 10735a7

File tree

4 files changed

+131
-3
lines changed

4 files changed

+131
-3
lines changed

test/_utils_internal.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from __future__ import annotations
66

77
import contextlib
8+
import logging
89
import os
910

1011
import os.path
1112
import time
13+
import unittest
1214
from functools import wraps
1315

1416
# Get relative file path
@@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
204206
return deco_retry
205207

206208

209+
# After calling this function, any log record whose name contains 'record_name'
210+
# and is emitted from the logger that has qualified name 'logger_qname' is
211+
# appended to the 'records' list.
212+
# NOTE: This function is based on testing utilities for 'torch._logging'
213+
def capture_log_records(records, logger_qname, record_name):
214+
assert isinstance(records, list)
215+
logger = logging.getLogger(logger_qname)
216+
217+
class EmitWrapper:
218+
def __init__(self, old_emit):
219+
self.old_emit = old_emit
220+
221+
def __call__(self, record):
222+
nonlocal records
223+
self.old_emit(record)
224+
if record_name in record.name:
225+
records.append(record)
226+
227+
for handler in logger.handlers:
228+
new_emit = EmitWrapper(handler.emit)
229+
contextlib.ExitStack().enter_context(
230+
unittest.mock.patch.object(handler, "emit", new_emit)
231+
)
232+
233+
207234
@pytest.fixture
208235
def dtype_fixture():
209236
dtype = torch.get_default_dtype()

test/test_rb.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
import pytest
1818
import torch
1919

20-
from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
20+
from _utils_internal import (
21+
capture_log_records,
22+
CARTPOLE_VERSIONED,
23+
get_default_devices,
24+
make_tc,
25+
)
2126

2227
from mocking_classes import CountingEnv
2328
from packaging import version
@@ -399,6 +404,73 @@ def data_iter():
399404
) if cond else contextlib.nullcontext():
400405
rb.extend(data2)
401406

407+
def test_extend_sample_recompile(
408+
self, rb_type, sampler, writer, storage, size, datatype
409+
):
410+
if _os_is_windows:
411+
# Compiling on Windows requires "cl" compiler to be installed.
412+
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
413+
# Our Windows CI jobs do not have "cl", so skip this test.
414+
pytest.skip("This test does not support Windows.")
415+
if rb_type is not ReplayBuffer:
416+
pytest.skip(
417+
"Only replay buffer of type 'ReplayBuffer' is currently supported."
418+
)
419+
if sampler is not RandomSampler:
420+
pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
421+
if storage is not LazyTensorStorage:
422+
pytest.skip(
423+
"Only storage of type 'LazyTensorStorage' is currently supported."
424+
)
425+
if writer is not RoundRobinWriter:
426+
pytest.skip(
427+
"Only writer of type 'RoundRobinWriter' is currently supported."
428+
)
429+
if datatype == "tensordict":
430+
pytest.skip("'tensordict' datatype is not currently supported.")
431+
432+
torch._dynamo.reset_code_caches()
433+
434+
storage_size = 10 * size
435+
rb = self._get_rb(
436+
rb_type=rb_type,
437+
sampler=sampler,
438+
writer=writer,
439+
storage=storage,
440+
size=storage_size,
441+
)
442+
data_size = size
443+
data = self._get_data(datatype, size=data_size)
444+
445+
@torch.compile
446+
def extend_and_sample(data):
447+
rb.extend(data)
448+
return rb.sample()
449+
450+
# Number of times to extend the replay buffer
451+
num_extend = 30
452+
453+
# NOTE: The first two calls to 'extend' and 'sample' currently cause
454+
# recompilations, so avoid capturing those for now.
455+
num_extend_before_capture = 2
456+
457+
for _ in range(num_extend_before_capture):
458+
extend_and_sample(data)
459+
460+
try:
461+
torch._logging.set_logs(recompiles=True)
462+
records = []
463+
capture_log_records(records, "torch._dynamo", "recompiles")
464+
465+
for _ in range(num_extend - num_extend_before_capture):
466+
extend_and_sample(data)
467+
468+
assert len(rb) == storage_size
469+
assert len(records) == 0
470+
471+
finally:
472+
torch._logging.set_logs()
473+
402474
def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
403475
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
404476
pytest.skip(

test/test_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import torch
1616

17-
from _utils_internal import get_default_devices
17+
from _utils_internal import capture_log_records, get_default_devices
1818
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
1919

2020
from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
@@ -380,6 +380,31 @@ def test_rng_decorator(device):
380380
torch.testing.assert_close(s0b, s1b)
381381

382382

383+
# Check that 'capture_log_records' captures records emitted when torch
384+
# recompiles a function.
385+
def test_capture_log_records_recompile():
386+
torch.compiler.reset()
387+
388+
# This function recompiles each time it is called with a different string
389+
# input.
390+
@torch.compile
391+
def str_to_tensor(s):
392+
return bytes(s, "utf8")
393+
394+
str_to_tensor("a")
395+
396+
try:
397+
torch._logging.set_logs(recompiles=True)
398+
records = []
399+
capture_log_records(records, "torch._dynamo", "recompiles")
400+
str_to_tensor("b")
401+
402+
finally:
403+
torch._logging.set_logs()
404+
405+
assert len(records) == 1
406+
407+
383408
if __name__ == "__main__":
384409
args, unknown = argparse.ArgumentParser().parse_known_args()
385410
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/replay_buffers/storages.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,14 @@ def _empty(self):
146146

147147
def _rand_given_ndim(self, batch_size):
148148
# a method to return random indices given the storage ndim
149+
if isinstance(self, TensorStorage):
150+
storage_len = self._len
151+
else:
152+
storage_len = len(self)
149153
if self.ndim == 1:
150154
return torch.randint(
151155
0,
152-
len(self),
156+
storage_len,
153157
(batch_size,),
154158
generator=self._rng,
155159
device=getattr(self, "device", None),

0 commit comments

Comments
 (0)