Skip to content

Commit cfaec7e

Browse files
committed
[Refactor] Use default device instead of CPU in losses
ghstack-source-id: dfcb987 Pull Request resolved: #2687
1 parent a09fada commit cfaec7e

File tree

7 files changed

+8
-8
lines changed

7 files changed

+8
-8
lines changed

torchrl/objectives/cql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def __init__(
323323
try:
324324
device = next(self.parameters()).device
325325
except AttributeError:
326-
device = torch.device("cpu")
326+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
327327
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
328328
if bool(min_alpha) ^ bool(max_alpha):
329329
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/crossq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def __init__(
306306
try:
307307
device = next(self.parameters()).device
308308
except AttributeError:
309-
device = torch.device("cpu")
309+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
310310
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
311311
if bool(min_alpha) ^ bool(max_alpha):
312312
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/decision_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
try:
104104
device = next(self.parameters()).device
105105
except AttributeError:
106-
device = torch.device("cpu")
106+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
107107

108108
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
109109
if bool(min_alpha) ^ bool(max_alpha):

torchrl/objectives/deprecated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __init__(
195195
try:
196196
device = next(self.parameters()).device
197197
except AttributeError:
198-
device = torch.device("cpu")
198+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
199199

200200
self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
201201
self.register_buffer(

torchrl/objectives/ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def __init__(
376376
try:
377377
device = next(self.parameters()).device
378378
except (AttributeError, StopIteration):
379-
device = torch.device("cpu")
379+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
380380

381381
self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
382382
if critic_coef is not None:

torchrl/objectives/redq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def __init__(
309309
try:
310310
device = next(self.parameters()).device
311311
except AttributeError:
312-
device = torch.device("cpu")
312+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
313313

314314
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
315315
self.register_buffer(

torchrl/objectives/sac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def __init__(
383383
try:
384384
device = next(self.parameters()).device
385385
except AttributeError:
386-
device = torch.device("cpu")
386+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
387387
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
388388
if bool(min_alpha) ^ bool(max_alpha):
389389
min_alpha = min_alpha if min_alpha else 0.0
@@ -1102,7 +1102,7 @@ def __init__(
11021102
try:
11031103
device = next(self.parameters()).device
11041104
except AttributeError:
1105-
device = torch.device("cpu")
1105+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
11061106

11071107
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
11081108
if bool(min_alpha) ^ bool(max_alpha):

0 commit comments

Comments
 (0)