9
9
10
10
from tensordict import TensorDict
11
11
from tensordict .nn import (
12
+ InteractionType ,
12
13
NormalParamExtractor ,
13
14
ProbabilisticTensorDictModule as ProbMod ,
14
15
ProbabilisticTensorDictSequential as ProbSeq ,
@@ -137,7 +138,10 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
137
138
)
138
139
139
140
140
- def test_dqn_speed (benchmark , n_obs = 8 , n_act = 4 , depth = 3 , ncells = 128 , batch = 128 ):
141
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
142
+ def test_dqn_speed (
143
+ benchmark , compile , n_obs = 8 , n_act = 4 , depth = 3 , ncells = 128 , batch = 128
144
+ ):
141
145
net = MLP (in_features = n_obs , out_features = n_act , depth = depth , num_cells = ncells )
142
146
action_space = "one-hot"
143
147
mod = QValueActor (net , in_keys = ["obs" ], action_space = action_space )
@@ -155,10 +159,23 @@ def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128):
155
159
[batch ],
156
160
)
157
161
loss (td )
162
+
163
+ if compile :
164
+ if isinstance (compile , str ):
165
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
166
+ else :
167
+ loss = torch .compile (loss , fullgraph = True )
168
+
169
+ loss (td )
170
+ loss (td )
171
+
158
172
benchmark (loss , td )
159
173
160
174
161
- def test_ddpg_speed (benchmark , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64 ):
175
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
176
+ def test_ddpg_speed (
177
+ benchmark , compile , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64
178
+ ):
162
179
common = MLP (
163
180
num_cells = ncells ,
164
181
in_features = n_obs ,
@@ -200,10 +217,23 @@ def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
200
217
loss = DDPGLoss (actor , value )
201
218
202
219
loss (td )
220
+
221
+ if compile :
222
+ if isinstance (compile , str ):
223
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
224
+ else :
225
+ loss = torch .compile (loss , fullgraph = True )
226
+
227
+ loss (td )
228
+ loss (td )
229
+
203
230
benchmark (loss , td )
204
231
205
232
206
- def test_sac_speed (benchmark , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64 ):
233
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
234
+ def test_sac_speed (
235
+ benchmark , compile , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64
236
+ ):
207
237
common = MLP (
208
238
num_cells = ncells ,
209
239
in_features = n_obs ,
@@ -245,6 +275,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
245
275
in_keys = ["loc" , "scale" ],
246
276
out_keys = ["action" ],
247
277
distribution_class = TanhNormal ,
278
+ distribution_kwargs = {"safe_tanh" : False },
248
279
),
249
280
)
250
281
value_head = Mod (
@@ -256,10 +287,23 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
256
287
loss = SACLoss (actor , value , action_spec = Unbounded (shape = (n_act ,)))
257
288
258
289
loss (td )
290
+
291
+ if compile :
292
+ if isinstance (compile , str ):
293
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
294
+ else :
295
+ loss = torch .compile (loss , fullgraph = True )
296
+
297
+ loss (td )
298
+ loss (td )
299
+
259
300
benchmark (loss , td )
260
301
261
302
262
- def test_redq_speed (benchmark , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64 ):
303
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
304
+ def test_redq_speed (
305
+ benchmark , compile , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64
306
+ ):
263
307
common = MLP (
264
308
num_cells = ncells ,
265
309
in_features = n_obs ,
@@ -313,11 +357,22 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
313
357
loss = REDQLoss (actor , value , action_spec = Unbounded (shape = (n_act ,)))
314
358
315
359
loss (td )
360
+
361
+ if compile :
362
+ if isinstance (compile , str ):
363
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
364
+ else :
365
+ loss = torch .compile (loss , fullgraph = True )
366
+
367
+ loss (td )
368
+ loss (td )
369
+
316
370
benchmark (loss , td )
317
371
318
372
373
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
319
374
def test_redq_deprec_speed (
320
- benchmark , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64
375
+ benchmark , compile , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64
321
376
):
322
377
common = MLP (
323
378
num_cells = ncells ,
@@ -372,10 +427,23 @@ def test_redq_deprec_speed(
372
427
loss = REDQLoss_deprecated (actor , value , action_spec = Unbounded (shape = (n_act ,)))
373
428
374
429
loss (td )
430
+
431
+ if compile :
432
+ if isinstance (compile , str ):
433
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
434
+ else :
435
+ loss = torch .compile (loss , fullgraph = True )
436
+
437
+ loss (td )
438
+ loss (td )
439
+
375
440
benchmark (loss , td )
376
441
377
442
378
- def test_td3_speed (benchmark , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64 ):
443
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
444
+ def test_td3_speed (
445
+ benchmark , compile , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64
446
+ ):
379
447
common = MLP (
380
448
num_cells = ncells ,
381
449
in_features = n_obs ,
@@ -417,14 +485,23 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
417
485
in_keys = ["loc" , "scale" ],
418
486
out_keys = ["action" ],
419
487
distribution_class = TanhNormal ,
488
+ distribution_kwargs = {"safe_tanh" : False },
420
489
return_log_prob = True ,
490
+ default_interaction_type = InteractionType .DETERMINISTIC ,
421
491
),
422
492
)
423
493
value_head = Mod (
424
494
value , in_keys = ["hidden" , "action" ], out_keys = ["state_action_value" ]
425
495
)
426
496
value = Seq (common , value_head )
427
- value (actor (td ))
497
+ value (actor (td .clone ()))
498
+ if compile :
499
+ actor_c = torch .compile (actor .get_dist , fullgraph = True )
500
+ actor_c (td )
501
+ actor_c = torch .compile (actor , fullgraph = True )
502
+ actor_c (td )
503
+ value_c = torch .compile (value , fullgraph = True )
504
+ value_c (td )
428
505
429
506
loss = TD3Loss (
430
507
actor ,
@@ -433,10 +510,23 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
433
510
)
434
511
435
512
loss (td )
513
+
514
+ if compile :
515
+ if isinstance (compile , str ):
516
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
517
+ else :
518
+ loss = torch .compile (loss , fullgraph = True )
519
+
520
+ loss (td )
521
+ loss (td )
522
+
436
523
benchmark .pedantic (loss , args = (td ,), rounds = 100 , iterations = 10 )
437
524
438
525
439
- def test_cql_speed (benchmark , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64 ):
526
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
527
+ def test_cql_speed (
528
+ benchmark , compile , n_obs = 8 , n_act = 4 , ncells = 128 , batch = 128 , n_hidden = 64
529
+ ):
440
530
common = MLP (
441
531
num_cells = ncells ,
442
532
in_features = n_obs ,
@@ -475,7 +565,10 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
475
565
Mod (actor_net , in_keys = ["hidden" ], out_keys = ["param" ]),
476
566
Mod (NormalParamExtractor (), in_keys = ["param" ], out_keys = ["loc" , "scale" ]),
477
567
ProbMod (
478
- in_keys = ["loc" , "scale" ], out_keys = ["action" ], distribution_class = TanhNormal
568
+ in_keys = ["loc" , "scale" ],
569
+ out_keys = ["action" ],
570
+ distribution_class = TanhNormal ,
571
+ distribution_kwargs = {"safe_tanh" : False },
479
572
),
480
573
)
481
574
value_head = Mod (
@@ -487,11 +580,22 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
487
580
loss = CQLLoss (actor , value , action_spec = Unbounded (shape = (n_act ,)))
488
581
489
582
loss (td )
583
+
584
+ if compile :
585
+ if isinstance (compile , str ):
586
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
587
+ else :
588
+ loss = torch .compile (loss , fullgraph = True )
589
+
590
+ loss (td )
591
+ loss (td )
592
+
490
593
benchmark (loss , td )
491
594
492
595
596
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
493
597
def test_a2c_speed (
494
- benchmark , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
598
+ benchmark , compile , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
495
599
):
496
600
common_net = MLP (
497
601
num_cells = ncells ,
@@ -533,7 +637,10 @@ def test_a2c_speed(
533
637
Mod (actor_net , in_keys = ["hidden" ], out_keys = ["param" ]),
534
638
Mod (NormalParamExtractor (), in_keys = ["param" ], out_keys = ["loc" , "scale" ]),
535
639
ProbMod (
536
- in_keys = ["loc" , "scale" ], out_keys = ["action" ], distribution_class = TanhNormal
640
+ in_keys = ["loc" , "scale" ],
641
+ out_keys = ["action" ],
642
+ distribution_class = TanhNormal ,
643
+ distribution_kwargs = {"safe_tanh" : False },
537
644
),
538
645
)
539
646
critic = Seq (common , Mod (value_net , in_keys = ["hidden" ], out_keys = ["state_value" ]))
@@ -544,11 +651,22 @@ def test_a2c_speed(
544
651
advantage = GAE (value_network = critic , gamma = 0.99 , lmbda = 0.95 , shifted = True )
545
652
advantage (td )
546
653
loss (td )
654
+
655
+ if compile :
656
+ if isinstance (compile , str ):
657
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
658
+ else :
659
+ loss = torch .compile (loss , fullgraph = True )
660
+
661
+ loss (td )
662
+ loss (td )
663
+
547
664
benchmark (loss , td )
548
665
549
666
667
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
550
668
def test_ppo_speed (
551
- benchmark , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
669
+ benchmark , compile , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
552
670
):
553
671
common_net = MLP (
554
672
num_cells = ncells ,
@@ -590,7 +708,10 @@ def test_ppo_speed(
590
708
Mod (actor_net , in_keys = ["hidden" ], out_keys = ["param" ]),
591
709
Mod (NormalParamExtractor (), in_keys = ["param" ], out_keys = ["loc" , "scale" ]),
592
710
ProbMod (
593
- in_keys = ["loc" , "scale" ], out_keys = ["action" ], distribution_class = TanhNormal
711
+ in_keys = ["loc" , "scale" ],
712
+ out_keys = ["action" ],
713
+ distribution_class = TanhNormal ,
714
+ distribution_kwargs = {"safe_tanh" : False },
594
715
),
595
716
)
596
717
critic = Seq (common , Mod (value_net , in_keys = ["hidden" ], out_keys = ["state_value" ]))
@@ -601,11 +722,22 @@ def test_ppo_speed(
601
722
advantage = GAE (value_network = critic , gamma = 0.99 , lmbda = 0.95 , shifted = True )
602
723
advantage (td )
603
724
loss (td )
725
+
726
+ if compile :
727
+ if isinstance (compile , str ):
728
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
729
+ else :
730
+ loss = torch .compile (loss , fullgraph = True )
731
+
732
+ loss (td )
733
+ loss (td )
734
+
604
735
benchmark (loss , td )
605
736
606
737
738
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
607
739
def test_reinforce_speed (
608
- benchmark , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
740
+ benchmark , compile , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
609
741
):
610
742
common_net = MLP (
611
743
num_cells = ncells ,
@@ -647,7 +779,10 @@ def test_reinforce_speed(
647
779
Mod (actor_net , in_keys = ["hidden" ], out_keys = ["param" ]),
648
780
Mod (NormalParamExtractor (), in_keys = ["param" ], out_keys = ["loc" , "scale" ]),
649
781
ProbMod (
650
- in_keys = ["loc" , "scale" ], out_keys = ["action" ], distribution_class = TanhNormal
782
+ in_keys = ["loc" , "scale" ],
783
+ out_keys = ["action" ],
784
+ distribution_class = TanhNormal ,
785
+ distribution_kwargs = {"safe_tanh" : False },
651
786
),
652
787
)
653
788
critic = Seq (common , Mod (value_net , in_keys = ["hidden" ], out_keys = ["state_value" ]))
@@ -658,11 +793,22 @@ def test_reinforce_speed(
658
793
advantage = GAE (value_network = critic , gamma = 0.99 , lmbda = 0.95 , shifted = True )
659
794
advantage (td )
660
795
loss (td )
796
+
797
+ if compile :
798
+ if isinstance (compile , str ):
799
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
800
+ else :
801
+ loss = torch .compile (loss , fullgraph = True )
802
+
803
+ loss (td )
804
+ loss (td )
805
+
661
806
benchmark (loss , td )
662
807
663
808
809
+ @pytest .mark .parametrize ("compile" , [False , True , "reduce-overhead" ])
664
810
def test_iql_speed (
665
- benchmark , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
811
+ benchmark , compile , n_obs = 8 , n_act = 4 , n_hidden = 64 , ncells = 128 , batch = 128 , T = 10
666
812
):
667
813
common_net = MLP (
668
814
num_cells = ncells ,
@@ -723,6 +869,16 @@ def test_iql_speed(
723
869
724
870
loss = IQLLoss (actor_network = actor , value_network = value , qvalue_network = qvalue )
725
871
loss (td )
872
+
873
+ if compile :
874
+ if isinstance (compile , str ):
875
+ loss = torch .compile (loss , mode = compile , fullgraph = True )
876
+ else :
877
+ loss = torch .compile (loss , fullgraph = True )
878
+
879
+ loss (td )
880
+ loss (td )
881
+
726
882
benchmark (loss , td )
727
883
728
884
0 commit comments