 from __future__ import annotations

 import warnings
+import weakref
 from numbers import Number
 from typing import Dict, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
 from torch import distributions as D, nn
+from torch.compiler import assume_constant_result
 from torch.distributions import constraints
+from torch.distributions.transforms import _InverseTransform

 from torchrl.modules.distributions.truncated_normal import (
     TruncatedNormal as _TruncatedNormal,
 )

 from torchrl.modules.distributions.utils import (
     _cast_device,
     FasterTransformedDistribution,
-    safeatanh,
-    safetanh,
+    safeatanh_noeps,
+    safetanh_noeps,
 )
 from torchrl.modules.utils import mappings

@@ -92,19 +95,21 @@ class SafeTanhTransform(D.TanhTransform):
     """TanhTransform subclass that ensures that the transformation is numerically invertible."""

     def _call(self, x: torch.Tensor) -> torch.Tensor:
-        if x.dtype.is_floating_point:
-            eps = torch.finfo(x.dtype).resolution
-        else:
-            raise NotImplementedError(f"No tanh transform for {x.dtype} inputs.")
-        return safetanh(x, eps)
+        return safetanh_noeps(x)

     def _inverse(self, y: torch.Tensor) -> torch.Tensor:
-        if y.dtype.is_floating_point:
-            eps = torch.finfo(y.dtype).resolution
-        else:
-            raise NotImplementedError(f"No inverse tanh for {y.dtype} inputs.")
-        x = safeatanh(y, eps)
-        return x
+        return safeatanh_noeps(y)
+
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransform(self)
+            if not torch.compiler.is_dynamo_compiling():
+                self._inv = weakref.ref(inv)
+        return inv


 class NormalParamWrapper(nn.Module):
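The `inv` override above follows the stock `torch.distributions.Transform.inv` property, except that the weakref-based cache is only written when Dynamo is not tracing, presumably because the weakref caching does not play well with torch.compile (the diff points to pytorch/pytorch#133529 further down). The same pattern is repeated in the `_PatchedComposeTransform` and `_PatchedAffineTransform` helpers added below. A minimal eager-mode sketch of the caching behaviour, assuming this file is `torchrl/modules/distributions/continuous.py`:

from torchrl.modules.distributions.continuous import SafeTanhTransform

t = SafeTanhTransform()
inv_a = t.inv
inv_b = t.inv
# Outside of Dynamo tracing, the inverse is cached through self._inv (a weakref),
# so repeated accesses return the same _InverseTransform object.
assert inv_a is inv_b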
@@ -316,6 +321,33 @@ def log_prob(self, value, **kwargs):
         return lp


+class _PatchedComposeTransform(D.ComposeTransform):
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
+            if not torch.compiler.is_dynamo_compiling():
+                self._inv = weakref.ref(inv)
+                inv._inv = weakref.ref(self)
+        return inv
+
+
+class _PatchedAffineTransform(D.AffineTransform):
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransform(self)
+            if not torch.compiler.is_dynamo_compiling():
+                self._inv = weakref.ref(inv)
+        return inv
+
+
 class TanhNormal(FasterTransformedDistribution):
     """Implements a TanhNormal distribution with location scaling.

@@ -344,6 +376,8 @@ class TanhNormal(FasterTransformedDistribution):
             as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two etc.
         tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
             value is kept. Default is ``False``;
+        safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
+            This will currently break with :func:`torch.compile`.
     """

     arg_constraints = {
@@ -369,6 +403,7 @@ def __init__(
         high: Union[torch.Tensor, Number] = 1.0,
         event_dims: int | None = None,
         tanh_loc: bool = False,
+        safe_tanh: bool = True,
         **kwargs,
     ):
         if "max" in kwargs:
@@ -419,13 +454,22 @@ def __init__(
         self.low = low
         self.high = high

-        t = SafeTanhTransform()
+        if safe_tanh:
+            if torch.compiler.is_dynamo_compiling():
+                _err_compile_safetanh()
+            t = SafeTanhTransform()
+        else:
+            t = D.TanhTransform()
         # t = D.TanhTransform()
-        if self.non_trivial_max or self.non_trivial_min:
-            t = D.ComposeTransform(
+        if torch.compiler.is_dynamo_compiling() or (
+            self.non_trivial_max or self.non_trivial_min
+        ):
+            t = _PatchedComposeTransform(
                 [
                     t,
-                    D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2),
+                    _PatchedAffineTransform(
+                        loc=(high + low) / 2, scale=(high - low) / 2
+                    ),
                 ]
             )
         self._t = t
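Up to the compile-related patching, the chain assembled here is the usual tanh-then-affine map onto ``(low, high)``; for reference, a sketch of the same construction with plain torch.distributions objects:

import torch
from torch import distributions as D

low, high = -2.0, 2.0
base = D.Normal(torch.zeros(3), torch.ones(3))
t = D.ComposeTransform(
    [
        D.TanhTransform(),
        D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2),
    ]
)
dist = D.TransformedDistribution(base, t)
sample = dist.rsample()  # samples lie strictly inside (low, high)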
@@ -446,7 +490,9 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
         # loc must be rescaled if tanh_loc
-        if self.non_trivial_max or self.non_trivial_min:
+        if torch.compiler.is_dynamo_compiling() or (
+            self.non_trivial_max or self.non_trivial_min
+        ):
             loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
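Note that the shift applied to ``loc`` above is just the midpoint of the bounds, i.e. the ``loc`` of the affine transform: ``(high - low) / 2 + low == (high + low) / 2``.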
@@ -466,6 +512,10 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         base = D.Normal(self.loc, self.scale)
         super().__init__(base, self._t)

+    @property
+    def support(self):
+        return D.constraints.real()
+
     @property
     def root_dist(self):
         bd = self
@@ -696,10 +746,10 @@ def __init__(
         loc = self.update(param)

         if self.non_trivial:
-            t = D.ComposeTransform(
+            t = _PatchedComposeTransform(
                 [
                     t,
-                    D.AffineTransform(
+                    _PatchedAffineTransform(
                         loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2
                     ),
                 ]
@@ -761,3 +811,16 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:


 uniform_sample_delta = _uniform_sample_delta
+
+
+def _err_compile_safetanh():
+    raise RuntimeError(
+        "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass "
+        "safe_tanh=False. "
+        "If you are using a ProbabilisticTensorDictModule, this can be done via "
+        "`distribution_kwargs={'safe_tanh': False}`. "
+        "See https://github.com/pytorch/pytorch/issues/133529 for more details."
+    )
+
+
+_warn_compile_safetanh = assume_constant_result(_err_compile_safetanh)
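For context, a generic sketch (not taken from this diff) of what ``torch.compiler.assume_constant_result`` does: Dynamo does not trace into the wrapped callable, and its return value is treated as a constant while compiling:

import torch


def get_flag() -> bool:
    # Imagine this reads external state that should not be traced or guarded on.
    return True


get_flag = torch.compiler.assume_constant_result(get_flag)


@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
    # get_flag() is not traced; its result is assumed constant.
    if get_flag():
        return x + 1
    return x - 1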