Commit 8edc29c

[Feature] Deactivate vmap in objectives
ghstack-source-id: f37922f
Pull-Request-resolved: #2957
1 parent d882ea2
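
The change threads a `deactivate_vmap` flag through the vmap-based loss modules so that `_vmap_func` (and the direct `torch.vmap` call sites in IQL) can fall back to an eager Python loop. The `_pseudo_vmap` helper that implements the fallback lives in `torchrl/objectives/utils.py` (see the import in the IQL diff below) and is not itself shown on this page; the following is a minimal sketch of the idea, assuming plain-tensor arguments, single-tensor outputs, and vmap's default `out_dims=0` (the real helper also has to handle TensorDict inputs):

    import torch

    def pseudo_vmap(fn, in_dims=0, randomness=None):
        # Loop-based stand-in for torch.vmap. `randomness` is accepted only for
        # signature parity: a plain loop naturally draws independent samples,
        # i.e. it behaves like randomness="different".
        def wrapped(*args):
            # Normalize in_dims to one entry per positional argument.
            dims = in_dims if isinstance(in_dims, tuple) else (in_dims,) * len(args)
            # The mapped size comes from the first argument that is actually mapped.
            n = next(a.shape[d] for a, d in zip(args, dims) if d is not None)
            outs = []
            for i in range(n):
                # Slice mapped args along their mapped dim; pass unmapped args as-is.
                outs.append(fn(*(a if d is None else a.select(d, i)
                                 for a, d in zip(args, dims))))
            return torch.stack(outs)  # mimic vmap's default out_dims=0
        return wrapped

With `in_dims=(None, 0)`, this reproduces what the `_vmap_func(qvalue_network, (None, 0), ...)` calls below compute: one shared input evaluated against a stack of Q-network parameter sets, one ensemble member at a time.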

File tree

11 files changed: +623 -29 lines changed

test/test_cost.py

Lines changed: 489 additions & 3 deletions
Large diffs are not rendered by default.

torchrl/objectives/cql.py

Lines changed: 11 additions & 2 deletions
@@ -92,6 +92,8 @@ class CQLLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -290,6 +292,7 @@ def __init__(
         with_lagrange: bool = False,
         lagrange_thresh: float = 0.0,
         reduction: str = None,
+        deactivate_vmap: bool = False,
     ) -> None:
         self._out_keys = None
         if reduction is None:
@@ -303,6 +306,7 @@ def __init__(
             "actor_network",
             create_target_params=self.delay_actor,
         )
+        self.deactivate_vmap = deactivate_vmap
 
         # Q value
         self.delay_qvalue = delay_qvalue
@@ -376,10 +380,15 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qvalue_networkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
         self._vmap_qvalue_network00 = _vmap_func(
-            self.qvalue_network, randomness=self.vmap_randomness
+            self.qvalue_network,
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
 
     @property
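
For intuition on the `(None, 0)` pattern, here is a self-contained toy check that the vectorized and looped evaluations agree (`q` is a stand-in, not TorchRL API):

    import torch

    def q(obs, w):                    # toy stand-in for a functional Q-network
        return obs @ w

    obs = torch.randn(4, 3)          # shared input: in_dim None
    ws = torch.randn(2, 3, 1)        # stack of 2 parameter sets: in_dim 0

    vmapped = torch.vmap(q, (None, 0))(obs, ws)      # shape (2, 4, 1)
    looped = torch.stack([q(obs, w) for w in ws])    # same result, eager loop
    assert torch.allclose(vmapped, looped)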

torchrl/objectives/crossq.py

Lines changed: 9 additions & 1 deletion
@@ -92,6 +92,8 @@ class CrossQLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -267,6 +269,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        deactivate_vmap: bool = False,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -275,6 +278,8 @@ def __init__(
         super().__init__()
         self._set_deprecated_ctor_keys(priority_key=priority_key)
 
+        self.deactivate_vmap = deactivate_vmap
+
         # Actor
         self.convert_to_functional(
             actor_network,
@@ -344,7 +349,10 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
 
     @property
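
A note on the `randomness=self.vmap_randomness` argument that recurs in these hunks: `torch.vmap` raises on random ops unless told how to batch them, while the loop fallback has no such restriction (it behaves like `randomness="different"`). A small illustration:

    import torch

    def noisy(x):
        return x + torch.randn(())

    xs = torch.randn(3)
    # torch.vmap(noisy)(xs)  # would raise: default randomness="error"
    same = torch.vmap(noisy, randomness="same")(xs)       # one shared draw
    diff = torch.vmap(noisy, randomness="different")(xs)  # independent draws
    loop = torch.stack([noisy(x) for x in xs])            # loop: independent by nature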

torchrl/objectives/deprecated.py

Lines changed: 8 additions & 1 deletion
@@ -86,6 +86,8 @@ class REDQLoss_deprecated(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
     """
 
     @dataclass
@@ -164,6 +166,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        deactivate_vmap: bool = False,
     ):
         self._in_keys = None
         self._out_keys = None
@@ -172,6 +175,8 @@ def __init__(
         super().__init__()
         self._set_deprecated_ctor_keys(priority_key=priority_key)
 
+        self.deactivate_vmap = deactivate_vmap
+
         self.convert_to_functional(
             actor_network,
             "actor_network",
@@ -234,7 +239,9 @@ def __init__(
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
 
     def _make_vmap(self):
-        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
+        self._vmap_qvalue_networkN0 = _vmap_func(
+            self.qvalue_network, (None, 0), pseudo_vmap=self.deactivate_vmap
+        )
 
     @property
     def target_entropy(self):

torchrl/objectives/iql.py

Lines changed: 25 additions & 4 deletions
@@ -19,6 +19,7 @@
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_ERROR,
+    _pseudo_vmap,
     _reduce,
     _vmap_func,
     default_value_kwargs,
@@ -68,6 +69,8 @@ class IQLLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -266,6 +269,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        deactivate_vmap: bool = False,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -274,6 +278,8 @@ def __init__(
         super().__init__()
         self._set_deprecated_ctor_keys(priority=priority_key)
 
+        self.deactivate_vmap = deactivate_vmap
+
         # IQL parameter
         self.temperature = temperature
         self.expectile = expectile
@@ -323,7 +329,10 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qvalue_networkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
 
     @property
@@ -824,7 +833,11 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
         if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
             # unsqueeze the action if it lacks on trailing singleton dim
             action = action.unsqueeze(-1)
-        chosen_state_action_value = torch.vmap(
+        if self.deactivate_vmap:
+            vmap = _pseudo_vmap
+        else:
+            vmap = torch.vmap
+        chosen_state_action_value = vmap(
             lambda state_action_value, action: torch.gather(
                 state_action_value, -1, index=action
             ).squeeze(-1),
@@ -883,7 +896,11 @@ def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
         ):
             # unsqueeze the action if it lacks on trailing singleton dim
             action = action.unsqueeze(-1)
-        chosen_state_action_value = torch.vmap(
+        if self.deactivate_vmap:
+            vmap = _pseudo_vmap
+        else:
+            vmap = torch.vmap
+        chosen_state_action_value = vmap(
             lambda state_action_value, action: torch.gather(
                 state_action_value, -1, index=action
             ).squeeze(-1),
@@ -932,7 +949,11 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
         if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
             # unsqueeze the action if it lacks on trailing singleton dim
             action = action.unsqueeze(-1)
-        pred_val = torch.vmap(
+        if self.deactivate_vmap:
+            vmap = _pseudo_vmap
+        else:
+            vmap = torch.vmap
+        pred_val = vmap(
             lambda state_action_value, action: torch.gather(
                 state_action_value, -1, index=action
             ).squeeze(-1),
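
The three IQL call sites all wrap the same `gather` in `vmap` to pick each sampled action's value out of the per-ensemble state-action values, and the new branches swap in `_pseudo_vmap` when the flag is set. A toy version of that selection (shapes and in_dims here are illustrative assumptions):

    import torch

    state_action_value = torch.randn(2, 5, 4)  # (ensemble, batch, n_actions)
    action = torch.randint(4, (5, 1))          # (batch, 1) indices, shared across ensemble

    def pick(sav, act):
        return torch.gather(sav, -1, index=act).squeeze(-1)

    vmapped = torch.vmap(pick, (0, None))(state_action_value, action)
    looped = torch.stack([pick(sav, action) for sav in state_action_value])
    assert torch.allclose(vmapped, looped)     # both: (ensemble, batch)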

torchrl/objectives/redq.py

Lines changed: 11 additions & 2 deletions
@@ -86,6 +86,8 @@ class REDQLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -280,6 +282,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        deactivate_vmap: bool = False,
     ):
         if reduction is None:
             reduction = "mean"
@@ -295,6 +298,7 @@ def __init__(
 
         # let's make sure that actor_network has `return_log_prob` set to True
         self.actor_network.return_log_prob = True
+        self.deactivate_vmap = deactivate_vmap
         if separate_losses:
             # we want to make sure there are no duplicates in the params: the
             # params of critic must be refs to actor if they're shared
@@ -351,10 +355,15 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qvalue_network00 = _vmap_func(
-            self.qvalue_network, randomness=self.vmap_randomness
+            self.qvalue_network,
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
         self._vmap_getdist = _vmap_func(
-            self.actor_network, func="get_dist_params", randomness=self.vmap_randomness
+            self.actor_network,
+            func="get_dist_params",
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
 
     @property
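
REDQ also vmaps a named method (`func="get_dist_params"`) rather than the module's `forward`. A hedged sketch of how a vmap-or-loop wrapper can support that dispatch, reusing the `pseudo_vmap` sketch near the top of this page (stand-in names; only part of the real `_vmap_func` signature is visible in this diff):

    import torch

    def vmap_or_loop(module, in_dims=0, func=None, randomness="error",
                     deactivate_vmap=False):
        # Dispatch to a named method when requested, e.g. func="get_dist_params".
        call = module if func is None else getattr(module, func)
        if deactivate_vmap:
            return pseudo_vmap(call, in_dims)  # eager fallback sketched earlier
        return torch.vmap(call, in_dims, randomness=randomness)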

torchrl/objectives/sac.py

Lines changed: 25 additions & 4 deletions
@@ -130,6 +130,8 @@ class SACLoss(LossModule):
             valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
             shape of the data and that masking the data results in a valid data structure. Among other things, this may
             not be true in MARL settings or when using RNNs. Defaults to ``False``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -334,6 +336,7 @@ def __init__(
         separate_losses: bool = False,
         reduction: str = None,
         skip_done_states: bool = False,
+        deactivate_vmap: bool = False,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -344,6 +347,7 @@ def __init__(
 
         # Actor
         self.delay_actor = delay_actor
+        self.deactivate_vmap = deactivate_vmap
         self.convert_to_functional(
             actor_network,
             "actor_network",
@@ -445,11 +449,16 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
         if self._version == 1:
             self._vmap_qnetwork00 = _vmap_func(
-                self.qvalue_network, randomness=self.vmap_randomness
+                self.qvalue_network,
+                randomness=self.vmap_randomness,
+                pseudo_vmap=self.deactivate_vmap,
             )
 
     @property
@@ -527,11 +536,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TD1Estimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.TD0:
             self._value_estimator = TD0Estimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.GAE:
             raise NotImplementedError(
@@ -541,6 +552,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TDLambdaEstimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
@@ -673,7 +685,6 @@ def _actor_loss(
             raise RuntimeError(
                 f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
             )
-
         return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()}
 
     @property
@@ -922,6 +933,8 @@ class DiscreteSACLoss(LossModule):
             valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
             shape of the data and that masking the data results in a valid data structure. Among other things, this may
             not be true in MARL settings or when using RNNs. Defaults to ``False``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -1098,6 +1111,7 @@ def __init__(
         separate_losses: bool = False,
         reduction: str = None,
         skip_done_states: bool = False,
+        deactivate_vmap: bool = False,
     ):
         if reduction is None:
             reduction = "mean"
@@ -1110,6 +1124,7 @@ def __init__(
             "actor_network",
             create_target_params=self.delay_actor,
         )
+        self.deactivate_vmap = deactivate_vmap
         if separate_losses:
             # we want to make sure there are no duplicates in the params: the
             # params of critic must be refs to actor if they're shared
@@ -1184,7 +1199,10 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
@@ -1436,11 +1454,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TD1Estimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.TD0:
             self._value_estimator = TD0Estimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.GAE:
             raise NotImplementedError(
@@ -1450,6 +1470,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TDLambdaEstimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
            )
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
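
Because `SACLoss` and `DiscreteSACLoss` build their value estimators lazily, the flag also has to be forwarded in `make_value_estimator`, which is what the estimator hunks above do. A hedged usage sketch (`actor` and `qvalue` are assumed to be built as in the `SACLoss` docstring examples, which this page truncates):

    from torchrl.objectives import SACLoss, ValueEstimators

    # actor and qvalue assumed defined as in the SACLoss docstring examples
    loss = SACLoss(actor, qvalue, deactivate_vmap=True)  # Q-ensemble runs as a loop
    loss.make_value_estimator(ValueEstimators.TD0)       # estimator inherits the flag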
