Skip to content

Commit 9ef0981

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 5ec5981 commit 9ef0981

File tree

4 files changed

+4
-8
lines changed

4 files changed

+4
-8
lines changed

test/test_transforms.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14590,7 +14590,6 @@ def test_trans_serial_env_check(self):
1459014590
base_env = SerialEnv(
1459114591
3,
1459214592
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
14593-
batch_locked=False,
1459414593
)
1459514594
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
1459614595
policy_odd = self._create_policy_odd(base_env)
@@ -14604,7 +14603,6 @@ def test_trans_parallel_env_check(self):
1460414603
base_env = ParallelEnv(
1460514604
3,
1460614605
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
14607-
batch_locked=False,
1460814606
mp_start_method=mp_ctx,
1460914607
)
1461014608
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)

torchrl/envs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
CenterCrop,
5959
ClipTransform,
6060
Compose,
61+
ConditionalPolicySwitch,
6162
ConditionalSkip,
6263
Crop,
6364
DeviceCastTransform,
@@ -137,6 +138,7 @@
137138
"AutoResetTransform",
138139
"AsyncEnvPool",
139140
"ProcessorAsyncEnvPool",
141+
"ConditionalPolicySwitch",
140142
"ThreadingAsyncEnvPool",
141143
"BatchSizeTransform",
142144
"BinarizeReward",

torchrl/envs/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
CenterCrop,
2121
ClipTransform,
2222
Compose,
23+
ConditionalPolicySwitch,
2324
ConditionalSkip,
2425
Crop,
2526
DeviceCastTransform,
@@ -83,6 +84,7 @@
8384
"CatFrames",
8485
"CatTensors",
8586
"CenterCrop",
87+
"ConditionalPolicySwitch",
8688
"ClipTransform",
8789
"Compose",
8890
"ConditionalSkip",

torchrl/envs/transforms/transforms.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11398,12 +11398,6 @@ def _reset(
1139811398
parent: TransformedEnv = self.parent
1139911399
reset_td_save = None
1140011400
if not cond.all():
11401-
if parent.base_env.batch_locked:
11402-
raise RuntimeError(
11403-
"Cannot run partial steps in a batched locked environment. "
11404-
"Hint: Parallel and Serial envs can be unlocked through a keyword argument in "
11405-
"the constructor."
11406-
)
1140711401
reset_td_save = tensordict_reset.copy()
1140811402
tensordict_reset = tensordict_reset[cond]
1140911403
tensordict = tensordict[cond]

0 commit comments

Comments
 (0)