Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
kurtamohler committed Feb 13, 2025
1 parent e8442d3 commit c1d894f
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class Transform(nn.Module):
"""

invertible = False
disable_inv_on_reset = False

def __init__(
self,
Expand Down Expand Up @@ -1020,8 +1021,9 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
)
# Inputs might be transformed, so need to apply inverse transform
# before passing to the env reset function.
with _set_missing_tolerance(self.transform, True):
tensordict = self.transform.inv(tensordict)
if not self.transform.disable_inv_on_reset:
with _set_missing_tolerance(self.transform, True):
tensordict = self.transform.inv(tensordict)
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
if tensordict is None:
# make sure all transforms see a source tensordict
Expand Down Expand Up @@ -1264,6 +1266,9 @@ def map_transform(trsf):
self.transforms = nn.ModuleList(transforms)
for t in transforms:
t.set_container(self)
self.disable_inv_on_reset = any(
[trsf.disable_inv_on_reset for trsf in transforms]
)

def to(self, *args, **kwargs):
# because Module.to(...) does not call to(...) on sub-modules, we have
Expand Down Expand Up @@ -10437,6 +10442,8 @@ class MultiAction(Transform):
"""

disable_inv_on_reset = True

def __init__(
self,
*,
Expand Down

0 comments on commit c1d894f

Please sign in to comment.