
Commit db01a8e

Author: Vincent Moens (committed)

[Feature] Make benchmarked losses compatible with torch.compile

ghstack-source-id: fcdc6e7
Pull Request resolved: #2405

1 parent e82a69f, commit db01a8e

File tree

13 files changed: +522 -158 lines

benchmarks/test_objectives_benchmarks.py

Lines changed: 209 additions & 17 deletions
Large diffs are not rendered by default.
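The benchmark diff itself is not shown here. As a rough, hypothetical illustration of the pattern being benchmarked, a loss callable can be wrapped with torch.compile and timed against eager execution; the toy loss below is a stand-in, not code from the benchmark file.

# Hypothetical sketch of an eager-vs-compiled loss timing; the toy loss is a
# stand-in for a TorchRL loss module's forward pass, not the benchmark code.
import torch
from torch.utils.benchmark import Timer

def toy_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return (pred - target).pow(2).mean()

compiled_loss = torch.compile(toy_loss)

pred, target = torch.randn(1024, 8), torch.randn(1024, 8)
compiled_loss(pred, target)  # warm-up call triggers compilation

eager = Timer("toy_loss(pred, target)", globals=globals()).blocked_autorange()
compiled = Timer("compiled_loss(pred, target)", globals=globals()).blocked_autorange()
print(f"eager: {eager.median:.2e}s  compiled: {compiled.median:.2e}s")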

torchrl/envs/transforms/transforms.py

Lines changed: 6 additions & 2 deletions
@@ -39,9 +39,13 @@
     unravel_key,
     unravel_key_list,
 )
-from tensordict._C import _unravel_key_to_tuple
 from tensordict.nn import dispatch, TensorDictModuleBase
-from tensordict.utils import expand_as_right, expand_right, NestedKey
+from tensordict.utils import (
+    _unravel_key_to_tuple,
+    expand_as_right,
+    expand_right,
+    NestedKey,
+)
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map


torchrl/modules/distributions/continuous.py

Lines changed: 83 additions & 20 deletions
@@ -5,13 +5,16 @@
 from __future__ import annotations

 import warnings
+import weakref
 from numbers import Number
 from typing import Dict, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
 from torch import distributions as D, nn
+from torch.compiler import assume_constant_result
 from torch.distributions import constraints
+from torch.distributions.transforms import _InverseTransform

 from torchrl.modules.distributions.truncated_normal import (
     TruncatedNormal as _TruncatedNormal,
@@ -20,8 +23,8 @@
 from torchrl.modules.distributions.utils import (
     _cast_device,
     FasterTransformedDistribution,
-    safeatanh,
-    safetanh,
+    safeatanh_noeps,
+    safetanh_noeps,
 )
 from torchrl.modules.utils import mappings

@@ -92,19 +95,21 @@ class SafeTanhTransform(D.TanhTransform):
     """TanhTransform subclass that ensured that the transformation is numerically invertible."""

     def _call(self, x: torch.Tensor) -> torch.Tensor:
-        if x.dtype.is_floating_point:
-            eps = torch.finfo(x.dtype).resolution
-        else:
-            raise NotImplementedError(f"No tanh transform for {x.dtype} inputs.")
-        return safetanh(x, eps)
+        return safetanh_noeps(x)

     def _inverse(self, y: torch.Tensor) -> torch.Tensor:
-        if y.dtype.is_floating_point:
-            eps = torch.finfo(y.dtype).resolution
-        else:
-            raise NotImplementedError(f"No inverse tanh for {y.dtype} inputs.")
-        x = safeatanh(y, eps)
-        return x
+        return safeatanh_noeps(y)
+
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransform(self)
+            if not torch.compiler.is_dynamo_compiling():
+                self._inv = weakref.ref(inv)
+        return inv


 class NormalParamWrapper(nn.Module):
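Two things change in this hunk: the dtype branching in `_call`/`_inverse` is delegated to the new `safetanh_noeps`/`safeatanh_noeps` helpers, and the `inv` property now skips the `weakref` caching while Dynamo is tracing. As an assumption about the helpers (not copied from torchrl), the idea is that the clamping epsilon is derived from the input dtype inside the helper rather than passed in:

# Conceptual sketch only (an assumption, not torchrl's implementation): the
# *_noeps helpers pick the clamping epsilon from the input dtype internally,
# removing the Python-level branching the old _call/_inverse performed.
import torch

def safetanh_noeps_sketch(x: torch.Tensor) -> torch.Tensor:
    eps = torch.finfo(x.dtype).resolution
    return x.tanh().clamp(-1 + eps, 1 - eps)

def safeatanh_noeps_sketch(y: torch.Tensor) -> torch.Tensor:
    eps = torch.finfo(y.dtype).resolution
    return y.clamp(-1 + eps, 1 - eps).atanh()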
@@ -316,6 +321,33 @@ def log_prob(self, value, **kwargs):
         return lp


+class _PatchedComposeTransform(D.ComposeTransform):
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
+            if not torch.compiler.is_dynamo_compiling():
+                self._inv = weakref.ref(inv)
+                inv._inv = weakref.ref(self)
+        return inv
+
+
+class _PatchedAffineTransform(D.AffineTransform):
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransform(self)
+            if not torch.compiler.is_dynamo_compiling():
+                self._inv = weakref.ref(inv)
+        return inv
+
+
 class TanhNormal(FasterTransformedDistribution):
     """Implements a TanhNormal distribution with location scaling.

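These subclasses exist because the upstream `ComposeTransform.inv` and `Transform.inv` always write `self._inv = weakref.ref(inv)` on access, and that mutation breaks Dynamo tracing; the patched versions only cache the inverse outside of compilation. In eager code they behave like their `torch.distributions` parents, as this small illustrative sketch (not from the diff) shows:

# Illustrative: the patched transforms compose and invert exactly like their
# torch.distributions parents; only the weakref caching of `.inv` is guarded.
import torch
from torch import distributions as D
from torchrl.modules.distributions.continuous import (
    _PatchedAffineTransform,
    _PatchedComposeTransform,
)

t = _PatchedComposeTransform(
    [D.TanhTransform(), _PatchedAffineTransform(loc=1.0, scale=2.0)]
)
x = torch.randn(4)
y = t(x)               # 2 * tanh(x) + 1, i.e. values in (-1, 3)
x_back = t.inv(y)      # inverse applies the parts in reverse order
assert t.inv is t.inv  # the weakref cache is still populated in eager mode
print(torch.allclose(x, x_back, atol=1e-4))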
@@ -344,6 +376,8 @@ class TanhNormal(FasterTransformedDistribution):
             as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two etc.
         tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
             value is kept. Default is ``False``;
+        safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
+            This will currently break with :func:`torch.compile`.
     """

     arg_constraints = {
@@ -369,6 +403,7 @@ def __init__(
         high: Union[torch.Tensor, Number] = 1.0,
         event_dims: int | None = None,
         tanh_loc: bool = False,
+        safe_tanh: bool = True,
         **kwargs,
     ):
         if "max" in kwargs:
@@ -419,13 +454,22 @@ def __init__(
         self.low = low
         self.high = high

-        t = SafeTanhTransform()
+        if safe_tanh:
+            if torch.compiler.is_dynamo_compiling():
+                _err_compile_safetanh()
+            t = SafeTanhTransform()
+        else:
+            t = D.TanhTransform()
         # t = D.TanhTransform()
-        if self.non_trivial_max or self.non_trivial_min:
-            t = D.ComposeTransform(
+        if torch.compiler.is_dynamo_compiling() or (
+            self.non_trivial_max or self.non_trivial_min
+        ):
+            t = _PatchedComposeTransform(
                 [
                     t,
-                    D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2),
+                    _PatchedAffineTransform(
+                        loc=(high + low) / 2, scale=(high - low) / 2
+                    ),
                 ]
             )
         self._t = t
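The new `safe_tanh` switch lets a `TanhNormal` that must run under `torch.compile` fall back to the plain `D.TanhTransform` instead of the compile-breaking safe variant. A short usage sketch (the sampling code is illustrative, not from the diff):

# Usage sketch: safe_tanh=False swaps SafeTanhTransform for D.TanhTransform so
# the distribution can be constructed inside compiled code without raising.
import torch
from torchrl.modules.distributions import TanhNormal

loc, scale = torch.zeros(3), torch.ones(3)
dist = TanhNormal(loc, scale, low=-2.0, high=2.0, safe_tanh=False)
action = dist.rsample()
log_prob = dist.log_prob(action)
print(action, log_prob)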
@@ -446,7 +490,9 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
         # loc must be rescaled if tanh_loc
-        if self.non_trivial_max or self.non_trivial_min:
+        if torch.compiler.is_dynamo_compiling() or (
+            self.non_trivial_max or self.non_trivial_min
+        ):
             loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
@@ -466,6 +512,10 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         base = D.Normal(self.loc, self.scale)
         super().__init__(base, self._t)

+    @property
+    def support(self):
+        return D.constraints.real()
+
     @property
     def root_dist(self):
         bd = self
@@ -696,10 +746,10 @@ def __init__(
         loc = self.update(param)

         if self.non_trivial:
-            t = D.ComposeTransform(
+            t = _PatchedComposeTransform(
                 [
                     t,
-                    D.AffineTransform(
+                    _PatchedAffineTransform(
                         loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2
                     ),
                 ]
@@ -761,3 +811,16 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:


 uniform_sample_delta = _uniform_sample_delta
+
+
+def _err_compile_safetanh():
+    raise RuntimeError(
+        "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass "
+        "safe_tanh=False. "
+        "If you are using a ProbabilisticTensorDictModule, this can be done via "
+        "`distribution_kwargs={'safe_tanh': False}`. "
+        "See https://github.com/pytorch/pytorch/issues/133529 for more details."
+    )
+
+
+_warn_compile_safetanh = assume_constant_result(_err_compile_safetanh)
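Following the error message's suggestion, a policy wrapped in a `ProbabilisticTensorDictModule` can forward the flag through `distribution_kwargs`. A hedged sketch (the backbone, shapes, and keys are illustrative; whether the call compiles without graph breaks depends on the torch and tensordict versions in use):

# Sketch of the distribution_kwargs={'safe_tanh': False} route mentioned in the
# error message; the network and keys below are illustrative only.
import torch
from tensordict import TensorDict
from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule,
    TensorDictModule,
)
from torchrl.modules.distributions import TanhNormal

backbone = TensorDictModule(
    torch.nn.Sequential(torch.nn.Linear(4, 4), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
policy = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    distribution_kwargs={"safe_tanh": False},
)

td = backbone(TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8]))
compiled_policy = torch.compile(policy)
print(compiled_policy(td)["action"].shape)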
