Skip to content

Commit 94d44e5

Browse files
committed
[Feature] Make benchmarked losses compatible with torch.compile
ghstack-source-id: 699a6bb
Pull Request resolved: #2405
1 parent e82a69f commit 94d44e5

File tree

10 files changed, with 465 additions and 142 deletions.

benchmarks/test_objectives_benchmarks.py

Lines changed: 172 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
from tensordict import TensorDict
1111
from tensordict.nn import (
12+
InteractionType,
1213
NormalParamExtractor,
1314
ProbabilisticTensorDictModule as ProbMod,
1415
ProbabilisticTensorDictSequential as ProbSeq,
@@ -137,7 +138,10 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
137138
)
138139

139140

140-
def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128):
141+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
142+
def test_dqn_speed(
143+
benchmark, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128
144+
):
141145
net = MLP(in_features=n_obs, out_features=n_act, depth=depth, num_cells=ncells)
142146
action_space = "one-hot"
143147
mod = QValueActor(net, in_keys=["obs"], action_space=action_space)
@@ -155,10 +159,23 @@ def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128):
155159
[batch],
156160
)
157161
loss(td)
162+
163+
if compile:
164+
if isinstance(compile, str):
165+
loss = torch.compile(loss, mode=compile, fullgraph=True)
166+
else:
167+
loss = torch.compile(loss, fullgraph=True)
168+
169+
loss(td)
170+
loss(td)
171+
158172
benchmark(loss, td)
159173

160174

161-
def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
175+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
176+
def test_ddpg_speed(
177+
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
178+
):
162179
common = MLP(
163180
num_cells=ncells,
164181
in_features=n_obs,
@@ -200,10 +217,23 @@ def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
200217
loss = DDPGLoss(actor, value)
201218

202219
loss(td)
220+
221+
if compile:
222+
if isinstance(compile, str):
223+
loss = torch.compile(loss, mode=compile, fullgraph=True)
224+
else:
225+
loss = torch.compile(loss, fullgraph=True)
226+
227+
loss(td)
228+
loss(td)
229+
203230
benchmark(loss, td)
204231

205232

206-
def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
233+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
234+
def test_sac_speed(
235+
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
236+
):
207237
common = MLP(
208238
num_cells=ncells,
209239
in_features=n_obs,
@@ -245,6 +275,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
245275
in_keys=["loc", "scale"],
246276
out_keys=["action"],
247277
distribution_class=TanhNormal,
278+
distribution_kwargs={"safe_tanh": False},
248279
),
249280
)
250281
value_head = Mod(
@@ -256,10 +287,23 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
256287
loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
257288

258289
loss(td)
290+
291+
if compile:
292+
if isinstance(compile, str):
293+
loss = torch.compile(loss, mode=compile, fullgraph=True)
294+
else:
295+
loss = torch.compile(loss, fullgraph=True)
296+
297+
loss(td)
298+
loss(td)
299+
259300
benchmark(loss, td)
260301

261302

262-
def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
303+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
304+
def test_redq_speed(
305+
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
306+
):
263307
common = MLP(
264308
num_cells=ncells,
265309
in_features=n_obs,
@@ -313,11 +357,22 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
313357
loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
314358

315359
loss(td)
360+
361+
if compile:
362+
if isinstance(compile, str):
363+
loss = torch.compile(loss, mode=compile, fullgraph=True)
364+
else:
365+
loss = torch.compile(loss, fullgraph=True)
366+
367+
loss(td)
368+
loss(td)
369+
316370
benchmark(loss, td)
317371

318372

373+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
319374
def test_redq_deprec_speed(
320-
benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
375+
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
321376
):
322377
common = MLP(
323378
num_cells=ncells,
@@ -372,10 +427,23 @@ def test_redq_deprec_speed(
372427
loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))
373428

374429
loss(td)
430+
431+
if compile:
432+
if isinstance(compile, str):
433+
loss = torch.compile(loss, mode=compile, fullgraph=True)
434+
else:
435+
loss = torch.compile(loss, fullgraph=True)
436+
437+
loss(td)
438+
loss(td)
439+
375440
benchmark(loss, td)
376441

377442

378-
def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
443+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
444+
def test_td3_speed(
445+
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
446+
):
379447
common = MLP(
380448
num_cells=ncells,
381449
in_features=n_obs,
@@ -417,14 +485,23 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
417485
in_keys=["loc", "scale"],
418486
out_keys=["action"],
419487
distribution_class=TanhNormal,
488+
distribution_kwargs={"safe_tanh": False},
420489
return_log_prob=True,
490+
default_interaction_type=InteractionType.DETERMINISTIC,
421491
),
422492
)
423493
value_head = Mod(
424494
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
425495
)
426496
value = Seq(common, value_head)
427-
value(actor(td))
497+
value(actor(td.clone()))
498+
if compile:
499+
actor_c = torch.compile(actor.get_dist, fullgraph=True)
500+
actor_c(td)
501+
actor_c = torch.compile(actor, fullgraph=True)
502+
actor_c(td)
503+
value_c = torch.compile(value, fullgraph=True)
504+
value_c(td)
428505

429506
loss = TD3Loss(
430507
actor,
@@ -433,10 +510,23 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
433510
)
434511

435512
loss(td)
513+
514+
if compile:
515+
if isinstance(compile, str):
516+
loss = torch.compile(loss, mode=compile, fullgraph=True)
517+
else:
518+
loss = torch.compile(loss, fullgraph=True)
519+
520+
loss(td)
521+
loss(td)
522+
436523
benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10)
437524

438525

439-
def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
526+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
527+
def test_cql_speed(
528+
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
529+
):
440530
common = MLP(
441531
num_cells=ncells,
442532
in_features=n_obs,
@@ -475,7 +565,10 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
475565
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
476566
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
477567
ProbMod(
478-
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
568+
in_keys=["loc", "scale"],
569+
out_keys=["action"],
570+
distribution_class=TanhNormal,
571+
distribution_kwargs={"safe_tanh": False},
479572
),
480573
)
481574
value_head = Mod(
@@ -487,11 +580,22 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
487580
loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
488581

489582
loss(td)
583+
584+
if compile:
585+
if isinstance(compile, str):
586+
loss = torch.compile(loss, mode=compile, fullgraph=True)
587+
else:
588+
loss = torch.compile(loss, fullgraph=True)
589+
590+
loss(td)
591+
loss(td)
592+
490593
benchmark(loss, td)
491594

492595

596+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
493597
def test_a2c_speed(
494-
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
598+
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
495599
):
496600
common_net = MLP(
497601
num_cells=ncells,
@@ -533,7 +637,10 @@ def test_a2c_speed(
533637
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
534638
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
535639
ProbMod(
536-
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
640+
in_keys=["loc", "scale"],
641+
out_keys=["action"],
642+
distribution_class=TanhNormal,
643+
distribution_kwargs={"safe_tanh": False},
537644
),
538645
)
539646
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
@@ -544,11 +651,22 @@ def test_a2c_speed(
544651
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
545652
advantage(td)
546653
loss(td)
654+
655+
if compile:
656+
if isinstance(compile, str):
657+
loss = torch.compile(loss, mode=compile, fullgraph=True)
658+
else:
659+
loss = torch.compile(loss, fullgraph=True)
660+
661+
loss(td)
662+
loss(td)
663+
547664
benchmark(loss, td)
548665

549666

667+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
550668
def test_ppo_speed(
551-
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
669+
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
552670
):
553671
common_net = MLP(
554672
num_cells=ncells,
@@ -590,7 +708,10 @@ def test_ppo_speed(
590708
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
591709
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
592710
ProbMod(
593-
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
711+
in_keys=["loc", "scale"],
712+
out_keys=["action"],
713+
distribution_class=TanhNormal,
714+
distribution_kwargs={"safe_tanh": False},
594715
),
595716
)
596717
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
@@ -601,11 +722,22 @@ def test_ppo_speed(
601722
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
602723
advantage(td)
603724
loss(td)
725+
726+
if compile:
727+
if isinstance(compile, str):
728+
loss = torch.compile(loss, mode=compile, fullgraph=True)
729+
else:
730+
loss = torch.compile(loss, fullgraph=True)
731+
732+
loss(td)
733+
loss(td)
734+
604735
benchmark(loss, td)
605736

606737

738+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
607739
def test_reinforce_speed(
608-
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
740+
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
609741
):
610742
common_net = MLP(
611743
num_cells=ncells,
@@ -647,7 +779,10 @@ def test_reinforce_speed(
647779
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
648780
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
649781
ProbMod(
650-
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
782+
in_keys=["loc", "scale"],
783+
out_keys=["action"],
784+
distribution_class=TanhNormal,
785+
distribution_kwargs={"safe_tanh": False},
651786
),
652787
)
653788
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
@@ -658,11 +793,22 @@ def test_reinforce_speed(
658793
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
659794
advantage(td)
660795
loss(td)
796+
797+
if compile:
798+
if isinstance(compile, str):
799+
loss = torch.compile(loss, mode=compile, fullgraph=True)
800+
else:
801+
loss = torch.compile(loss, fullgraph=True)
802+
803+
loss(td)
804+
loss(td)
805+
661806
benchmark(loss, td)
662807

663808

809+
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
664810
def test_iql_speed(
665-
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
811+
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
666812
):
667813
common_net = MLP(
668814
num_cells=ncells,
@@ -723,6 +869,16 @@ def test_iql_speed(
723869

724870
loss = IQLLoss(actor_network=actor, value_network=value, qvalue_network=qvalue)
725871
loss(td)
872+
873+
if compile:
874+
if isinstance(compile, str):
875+
loss = torch.compile(loss, mode=compile, fullgraph=True)
876+
else:
877+
loss = torch.compile(loss, fullgraph=True)
878+
879+
loss(td)
880+
loss(td)
881+
726882
benchmark(loss, td)
727883

728884

torchrl/envs/transforms/transforms.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,13 @@
3939
unravel_key,
4040
unravel_key_list,
4141
)
42-
from tensordict._C import _unravel_key_to_tuple
4342
from tensordict.nn import dispatch, TensorDictModuleBase
44-
from tensordict.utils import expand_as_right, expand_right, NestedKey
43+
from tensordict.utils import (
44+
_unravel_key_to_tuple,
45+
expand_as_right,
46+
expand_right,
47+
NestedKey,
48+
)
4549
from torch import nn, Tensor
4650
from torch.utils._pytree import tree_map
4751

0 commit comments

Comments
 (0)