Commit 0f29c7e

kurtamohler authored and vmoens committed
[Feature] Avoid some recompiles of ReplayBuffer.extend/sample
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data is either a tensor or a pytree.

ghstack-source-id: d306cb9
Pull Request resolved: #2504
1 parent 5244a90 commit 0f29c7e
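
For context, here is a minimal sketch (not part of the commit) of the configuration this change targets: a `ReplayBuffer` backed by `LazyTensorStorage` with `RoundRobinWriter` and `RandomSampler`, with extend/sample driven through `torch.compile`. The storage size, data shape, and batch size below are arbitrary examples.

    import torch
    from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer, RoundRobinWriter

    rb = ReplayBuffer(
        storage=LazyTensorStorage(1000),
        sampler=RandomSampler(),
        writer=RoundRobinWriter(),
        batch_size=32,
    )

    @torch.compile
    def extend_and_sample(data):
        rb.extend(data)
        return rb.sample()

    data = torch.randn(100, 4)
    for _ in range(10):
        # After the initial warm-up compilations, back-to-back calls like
        # these should no longer trigger recompiles with this change in place.
        extend_and_sample(data)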

File tree

test/_utils_internal.py
test/test_rb.py
test/test_utils.py
torchrl/data/replay_buffers/storages.py

4 files changed: +153 -3 lines changed

test/_utils_internal.py
Lines changed: 27 additions & 0 deletions

@@ -5,10 +5,12 @@
 from __future__ import annotations

 import contextlib
+import logging
 import os

 import os.path
 import time
+import unittest
 from functools import wraps

 # Get relative file path

@@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
     return deco_retry


+# After calling this function, any log record whose name contains 'record_name'
+# and is emitted from the logger that has qualified name 'logger_qname' is
+# appended to the 'records' list.
+# NOTE: This function is based on testing utilities for 'torch._logging'
+def capture_log_records(records, logger_qname, record_name):
+    assert isinstance(records, list)
+    logger = logging.getLogger(logger_qname)
+
+    class EmitWrapper:
+        def __init__(self, old_emit):
+            self.old_emit = old_emit
+
+        def __call__(self, record):
+            nonlocal records
+            self.old_emit(record)
+            if record_name in record.name:
+                records.append(record)
+
+    for handler in logger.handlers:
+        new_emit = EmitWrapper(handler.emit)
+        contextlib.ExitStack().enter_context(
+            unittest.mock.patch.object(handler, "emit", new_emit)
+        )
+
+
 @pytest.fixture
 def dtype_fixture():
     dtype = torch.get_default_dtype()
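
For orientation, a minimal sketch of how this helper is meant to be used (the same pattern appears in the tests added below). The compiled `double` function and the input shapes are invented for the example, and importing from `_utils_internal` assumes the test directory is on the import path.

    import torch

    from _utils_internal import capture_log_records

    @torch.compile
    def double(x):
        return x * 2

    double(torch.ones(3))  # first call compiles

    records = []
    try:
        torch._logging.set_logs(recompiles=True)
        capture_log_records(records, "torch._dynamo", "recompiles")
        double(torch.ones(4, 4))  # a shape/rank change may trigger a recompile
    finally:
        torch._logging.set_logs()

    print(len(records))  # number of "recompiles" records observed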

test/test_rb.py
Lines changed: 76 additions & 1 deletion

@@ -17,7 +17,12 @@
 import pytest
 import torch

-from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
+from _utils_internal import (
+    capture_log_records,
+    CARTPOLE_VERSIONED,
+    get_default_devices,
+    make_tc,
+)

 from mocking_classes import CountingEnv
 from packaging import version

@@ -111,6 +116,7 @@
 _has_gym = importlib.util.find_spec("gym") is not None
 _has_snapshot = importlib.util.find_spec("torchsnapshot") is not None
 _os_is_windows = sys.platform == "win32"
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

 torch_2_3 = version.parse(
     ".".join([str(s) for s in version.parse(str(torch.__version__)).release])

@@ -399,6 +405,75 @@ def data_iter():
         ) if cond else contextlib.nullcontext():
             rb.extend(data2)

+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+    )
+    # Compiling on Windows requires "cl" compiler to be installed.
+    # <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
+    # Our Windows CI jobs do not have "cl", so skip this test.
+    @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    def test_extend_sample_recompile(
+        self, rb_type, sampler, writer, storage, size, datatype
+    ):
+        if rb_type is not ReplayBuffer:
+            pytest.skip(
+                "Only replay buffer of type 'ReplayBuffer' is currently supported."
+            )
+        if sampler is not RandomSampler:
+            pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
+        if storage is not LazyTensorStorage:
+            pytest.skip(
+                "Only storage of type 'LazyTensorStorage' is currently supported."
+            )
+        if writer is not RoundRobinWriter:
+            pytest.skip(
+                "Only writer of type 'RoundRobinWriter' is currently supported."
+            )
+        if datatype == "tensordict":
+            pytest.skip("'tensordict' datatype is not currently supported.")
+
+        torch._dynamo.reset_code_caches()
+
+        storage_size = 10 * size
+        rb = self._get_rb(
+            rb_type=rb_type,
+            sampler=sampler,
+            writer=writer,
+            storage=storage,
+            size=storage_size,
+        )
+        data_size = size
+        data = self._get_data(datatype, size=data_size)
+
+        @torch.compile
+        def extend_and_sample(data):
+            rb.extend(data)
+            return rb.sample()
+
+        # Number of times to extend the replay buffer
+        num_extend = 30
+
+        # NOTE: The first two calls to 'extend' and 'sample' currently cause
+        # recompilations, so avoid capturing those for now.
+        num_extend_before_capture = 2
+
+        for _ in range(num_extend_before_capture):
+            extend_and_sample(data)
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend - num_extend_before_capture):
+                extend_and_sample(data)

+            assert len(rb) == storage_size
+            assert len(records) == 0
+
+        finally:
+            torch._logging.set_logs()
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(

test/test_utils.py
Lines changed: 32 additions & 1 deletion

@@ -14,11 +14,14 @@

 import torch

-from _utils_internal import get_default_devices
+from _utils_internal import capture_log_records, get_default_devices
+from packaging import version
 from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

 from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend

+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+

 @pytest.mark.parametrize("value", ["True", "1", "true"])
 def test_get_binary_env_var_positive(value):

@@ -380,6 +383,34 @@ def test_rng_decorator(device):
     torch.testing.assert_close(s0b, s1b)


+# Check that 'capture_log_records' captures records emitted when torch
+# recompiles a function.
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+)
+def test_capture_log_records_recompile():
+    torch.compiler.reset()
+
+    # This function recompiles each time it is called with a different string
+    # input.
+    @torch.compile
+    def str_to_tensor(s):
+        return bytes(s, "utf8")
+
+    str_to_tensor("a")
+
+    try:
+        torch._logging.set_logs(recompiles=True)
+        records = []
+        capture_log_records(records, "torch._dynamo", "recompiles")
+        str_to_tensor("b")
+
+    finally:
+        torch._logging.set_logs()
+
+    assert len(records) == 1
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/replay_buffers/storages.py
Lines changed: 18 additions & 1 deletion

@@ -144,12 +144,29 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     def _empty(self):
         ...

+    # NOTE: This property is used to enable compiled Storages. Calling
+    # `len(self)` on a TensorStorage should normally cause a graph break since
+    # it uses a `mp.Value`, and it does cause a break when the `len(self)` call
+    # happens within a method of TensorStorage itself. However, when the
+    # `len(self)` call happens in the Storage base class, for an unknown reason
+    # the compiler doesn't seem to recognize that there should be a graph break,
+    # and the lack of a break causes a recompile each time `len(self)` is called
+    # in this context. Also for an unknown reason, we can force the graph break
+    # to happen if we wrap the `len(self)` call with a `property`-decorated
+    # function. For another unknown reason, if we change
+    # `TensorStorage._len_value` from `mp.Value` to int, it seems like there
+    # should no longer be any need to recompile, but recompiles happen anyway.
+    # Ideally, this should all be investigated and understood in the future.
+    @property
+    def len(self):
+        return len(self)
+
     def _rand_given_ndim(self, batch_size):
         # a method to return random indices given the storage ndim
         if self.ndim == 1:
             return torch.randint(
                 0,
-                len(self),
+                self.len,
                 (batch_size,),
                 generator=self._rng,
                 device=getattr(self, "device", None),
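
To make the pattern in the comment above concrete, here is a standalone, hypothetical sketch (not TorchRL code): a storage-like class whose length lives in an `mp.Value`, with the `len(self)` read routed through a `property` so the graph break lands in a predictable place. `TinyStorage` and its attributes are invented for illustration.

    import multiprocessing as mp

    import torch

    class TinyStorage:
        def __init__(self, capacity):
            self._data = torch.zeros(capacity)
            # Shared counter; reading it from compiled code breaks the graph.
            self._len_value = mp.Value("i", 0)

        def __len__(self):
            return self._len_value.value

        @property
        def len(self):
            # Mirrors TensorStorage.len above: wrapping len(self) in a property
            # gives the compiler a consistent place to break the graph.
            return len(self)

        def write(self, n):
            # Pretend we stored n new elements.
            self._len_value.value = min(self._len_value.value + n, len(self._data))

        def rand_indices(self, batch_size):
            # Analogous to _rand_given_ndim: read self.len rather than len(self).
            return torch.randint(0, self.len, (batch_size,))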
