Skip to content

Commit 08cb766

Browse files
committed
Update
[ghstack-poisoned]
1 parent 6750343 commit 08cb766

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

test/test_rb.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
_has_gym = importlib.util.find_spec("gym") is not None
117117
_has_snapshot = importlib.util.find_spec("torchsnapshot") is not None
118118
_os_is_windows = sys.platform == "win32"
119+
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
119120

120121
torch_2_3 = version.parse(
121122
".".join([str(s) for s in version.parse(str(torch.__version__)).release])
@@ -404,14 +405,16 @@ def data_iter():
404405
) if cond else contextlib.nullcontext():
405406
rb.extend(data2)
406407

408+
@pytest.mark.skipif(
409+
TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
410+
)
411+
# Compiling on Windows requires the "cl" compiler to be installed.
412+
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
413+
# Our Windows CI jobs do not have "cl", so skip this test.
414+
@pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
407415
def test_extend_sample_recompile(
408416
self, rb_type, sampler, writer, storage, size, datatype
409417
):
410-
if _os_is_windows:
411-
# Compiling on Windows requires the "cl" compiler to be installed.
412-
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
413-
# Our Windows CI jobs do not have "cl", so skip this test.
414-
pytest.skip("This test does not support Windows.")
415418
if rb_type is not ReplayBuffer:
416419
pytest.skip(
417420
"Only replay buffer of type 'ReplayBuffer' is currently supported."
@@ -429,7 +432,7 @@ def test_extend_sample_recompile(
429432
if datatype == "tensordict":
430433
pytest.skip("'tensordict' datatype is not currently supported.")
431434

432-
torch._dynamo.reset()
435+
torch._dynamo.reset_code_caches()
433436

434437
storage_size = 10 * size
435438
rb = self._get_rb(

torchrl/data/replay_buffers/storages.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,19 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
144144
def _empty(self):
145145
...
146146

147-
# NOTE: This property is used to enable compiled Storages. A `len(self)`
148-
# call can cause recompiles, but for some reason, wrapping the call in a
149-
# `property` decorated function avoids the recompiles.
147+
# NOTE: This property is used to enable compiled Storages. Calling
148+
# `len(self)` on a TensorStorage should normally cause a graph break since
149+
# it uses a `mp.Value`, and it does cause a break when the `len(self)` call
150+
# happens within a method of TensorStorage itself. However, when the
151+
# `len(self)` call happens in the Storage base class, for an unknown reason
152+
# the compiler doesn't seem to recognize that there should be a graph break,
153+
# and the lack of a break causes a recompile each time `len(self)` is called
154+
# in this context. Also for an unknown reason, we can force the graph break
155+
# to happen if we wrap the `len(self)` call with a `property`-decorated
156+
# function. Finally, it seems like changing
157+
# `TensorStorage._len_value` from `mp.Value` to int should remove any need
158+
# to recompile, but for yet another unknown reason recompiles happen anyway.
159+
# Ideally, this should all be investigated and understood in the future.
150160
@property
151161
def len(self):
152162
return len(self)

0 commit comments

Comments
 (0)