
Commit 98d2933

[CI] Add benchmarks to test runs
ghstack-source-id: 8d83ae8
Pull Request resolved: #2410
1 parent 0a410ff commit 98d2933

6 files changed: +31 −16 lines changed

.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ dependencies:
   - pytest-cov
   - pytest-mock
   - pytest-instafail
+  - pytest-benchmark
   - pytest-rerunfailures
   - pytest-timeout
   - expecttest
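
For context, pytest-benchmark provides a `benchmark` fixture that repeatedly calls the function under test and reports timing statistics, which is what the benchmarks run below rely on. A minimal sketch of such a test; the benchmarked workload here is a hypothetical stand-in, not one of this repository's benchmarks:

# Minimal pytest-benchmark sketch. The `benchmark` fixture comes from the
# pytest-benchmark plugin added above; the policy forward pass is an
# illustrative toy workload.
import torch


def _rollout_step(policy, obs):
    # Toy workload: a single forward pass through a small policy network.
    with torch.no_grad():
        return policy(obs)


def test_policy_step_speed(benchmark):
    policy = torch.nn.Linear(8, 2)
    obs = torch.randn(32, 8)
    # `benchmark` calls the function many times and records min/mean/stddev.
    result = benchmark(_rollout_step, policy, obs)
    assert result.shape == (32, 2)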

.github/unittest/linux/scripts/run_all.sh

Lines changed: 6 additions & 2 deletions
Original file	Diff

@@ -88,8 +88,7 @@ conda deactivate
 conda activate "${env_dir}"

 echo "installing gymnasium"
-pip3 install "gymnasium"
-pip3 install ale_py
+pip3 install "gymnasium[atari,accept-rom-license]"
 pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
 pip3 install "mujoco" -U

@@ -189,9 +188,14 @@ export MKL_THREADING_LAYER=GNU
 export CKPT_BACKEND=torch
 export MAX_IDLE_COUNT=100
 export BATCHED_PIPE_TIMEOUT=60
+export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1

 pytest test/smoke_test.py -v --durations 200
 pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
+
+# Check that benchmarks run
+python -m pytest benchmarks
+
 if [ "${CU_VERSION:-}" != cpu ] ; then
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
     --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
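
The new TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 export appears to map onto the torch._dynamo config flag of the same name; a hedged sketch of toggling it from Python instead of via the environment, assuming the attribute name mirrors the env var (true for recent PyTorch releases, but worth checking against the installed version):

# Sketch only: set the same Dynamo option programmatically rather than via the
# TORCHDYNAMO_INLINE_INBUILT_NN_MODULES environment variable. The config
# attribute name is an assumption, not taken from this diff.
import torch._dynamo

torch._dynamo.config.inline_inbuilt_nn_modules = True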

.github/unittest/linux_libs/scripts_gym/batch_scripts.sh

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ do
   conda activate ./cloned_env

   echo "Testing gym version: ${GYM_VERSION}"
-  pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION
+  pip3 install 'gymnasium[atari,accept-rom-license]'==$GYM_VERSION

   $DIR/run_test.sh
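Dropping the ale-py extra relies on gymnasium[atari] pulling in the ALE bindings itself. A hedged smoke check of that assumption; the environment id is a standard ALE id used for illustration, not something taken from this diff:

# Hedged smoke check: verify the Atari/ALE stack installed by
# gymnasium[atari,accept-rom-license] is importable and an ALE env builds.
import gymnasium as gym

env = gym.make("ALE/Pong-v5")
obs, info = env.reset(seed=0)
env.close()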

torchrl/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -21,6 +21,10 @@

 from ._extension import _init_extension

+try:
+    from torch.compiler import is_dynamo_compiling
+except Exception:
+    from torch._dynamo import is_compiling as is_dynamo_compiling

 try:
     from .version import __version__

@@ -69,7 +73,7 @@ def _inv(self):
     inv = self._inv()
     if inv is None:
         inv = _InverseTransform(self)
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
     return inv


@@ -84,7 +88,7 @@ def _inv(self):
     inv = self._inv()
     if inv is None:
         inv = ComposeTransform([p.inv for p in reversed(self.parts)])
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
             inv._inv = weakref.ref(self)
     else:
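
The try/except import gives a single is_dynamo_compiling name that resolves to the public torch.compiler helper on newer PyTorch and falls back to the private torch._dynamo one on older releases. A standalone sketch of the caching pattern the patched _inv uses, with hypothetical class names, showing why the weakref write is skipped while Dynamo traces:

# Standalone sketch (hypothetical classes, not torchrl's patched Transform):
# cache a derived object through a weakref in eager mode only. Under Dynamo
# tracing the cache write is skipped so the side effect on `self` is not
# captured in the compiled graph.
import weakref

try:
    from torch.compiler import is_dynamo_compiling
except Exception:  # older PyTorch: fall back to the private helper
    from torch._dynamo import is_compiling as is_dynamo_compiling


class _Inverse:
    def __init__(self, fwd):
        self.fwd = fwd


class _Forward:
    def __init__(self):
        self._inv_ref = lambda: None  # empty "weakref" sentinel

    @property
    def inv(self):
        inv = self._inv_ref()
        if inv is None:
            inv = _Inverse(self)
            if not is_dynamo_compiling():
                self._inv_ref = weakref.ref(inv)
        return inv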

torchrl/modules/distributions/continuous.py

Lines changed: 11 additions & 10 deletions
@@ -33,6 +33,11 @@
 )
 from torchrl.modules.utils import mappings

+try:
+    from torch.compiler import is_dynamo_compiling
+except Exception:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 # speeds up distribution construction
 D.Distribution.set_default_validate_args(False)


@@ -112,7 +117,7 @@ def inv(self):
         inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv


@@ -334,7 +339,7 @@ def inv(self):
         inv = self._inv()
         if inv is None:
             inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
                 inv._inv = weakref.ref(self)
         return inv

@@ -348,7 +353,7 @@ def inv(self):
         inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv


@@ -460,15 +465,13 @@ def __init__(
         self.high = high

         if safe_tanh:
-            if torch.compiler.is_dynamo_compiling():
+            if is_dynamo_compiling():
                 _err_compile_safetanh()
             t = SafeTanhTransform()
         else:
             t = D.TanhTransform()
         # t = D.TanhTransform()
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             t = _PatchedComposeTransform(
                 [
                     t,

@@ -495,9 +498,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
             # loc must be rescaled if tanh_loc
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
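
The loc update above shifts by (high - low) / 2 + low, which equals the interval midpoint (high + low) / 2. A hedged sketch of the underlying tanh-plus-rescale composition using stock torch.distributions classes (torchrl uses patched variants of these transforms, so this is illustrative only):

# Sketch with stock torch.distributions classes: squash a Normal through tanh,
# then rescale (-1, 1) onto (low, high) with an affine map of scale
# (high - low) / 2 and shift (high + low) / 2, the same midpoint that appears
# in the update method above.
import torch
import torch.distributions as D

low, high = torch.tensor(-2.0), torch.tensor(3.0)
base = D.Normal(torch.zeros(4), torch.ones(4))
transforms = [
    D.TanhTransform(),
    D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2),
]
dist = D.TransformedDistribution(base, transforms)
sample = dist.sample()  # values lie in (low, high)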

torchrl/objectives/utils.py

Lines changed: 6 additions & 1 deletion
@@ -26,6 +26,11 @@
     raise err_ft from err
 from torchrl.envs.utils import step_mdp

+try:
+    from torch.compiler import is_dynamo_compiling
+except Exception:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _GAMMA_LMBDA_DEPREC_ERROR = (
     "Passing gamma / lambda parameters through the loss constructor "
     "is a deprecated feature. To customize your value function, "

@@ -460,7 +465,7 @@ def _cache_values(func):

     @functools.wraps(func)
     def new_func(self, netname=None):
-        if torch.compiler.is_dynamo_compiling():
+        if is_dynamo_compiling():
            if netname is not None:
                return func(self, netname)
            else:
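
A simplified sketch of what a _cache_values-style decorator does with the compile-time branch from this diff: under Dynamo the wrapper just calls through, while in eager mode it memoizes per netname on the instance. Names and cache layout here are illustrative assumptions, not torchrl's actual implementation:

# Illustrative sketch, not torchrl's exact _cache_values: skip the instance
# cache while Dynamo is compiling and memoize the result per `netname`
# otherwise.
import functools

try:
    from torch.compiler import is_dynamo_compiling
except Exception:
    from torch._dynamo import is_compiling as is_dynamo_compiling


def _cache_values_sketch(func):
    attr_name = "_cache_" + func.__name__

    @functools.wraps(func)
    def new_func(self, netname=None):
        if is_dynamo_compiling():
            # No caching under compile: writes to `self` would be traced.
            return func(self, netname) if netname is not None else func(self)
        cache = self.__dict__.setdefault(attr_name, {})
        if netname not in cache:
            cache[netname] = func(self, netname) if netname is not None else func(self)
        return cache[netname]

    return new_func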
