Closed
Description
In torchrl/objectives/value/utils.py#L287-L290, the code currently uses tensor.shape[-2] to choose between torch.int16 and torch.int32. This should use the time dimension (time_dim) instead, especially since it's used in _fast_td_lambda_return_estimate, where the inputs are transposed first. As a result, the condition may be incorrect for certain input shapes (when the size of time_dim exceeds the int16 range, but the F dimension of the _fast_td_lambda_return_estimate inputs is within this range).
Code Reference:
# int16 supports length up to 32767
dtype = (
    torch.int16 if tensor.shape[-2] < torch.iinfo(torch.int16).max else torch.int32
)
Proposed Fix:
dtype = (
    torch.int16 if tensor.size(time_dim) < torch.iinfo(torch.int16).max else torch.int32
)
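To illustrate the mismatch, here is a minimal sketch in plain Python that mirrors the two conditions on a shape tuple. It is not the torchrl code: the helper names, the example shape, and the layout (F at dim -2, time at dim -1 after the transpose) are assumptions made for illustration, and dtypes are represented by strings so the snippet runs without PyTorch.

```python
# Same value as torch.iinfo(torch.int16).max
INT16_MAX = 2**15 - 1  # 32767


def dtype_from_second_to_last(shape):
    """Mirrors the current condition: always inspects shape[-2]."""
    return "int16" if shape[-2] < INT16_MAX else "int32"


def dtype_from_time_dim(shape, time_dim):
    """Mirrors the proposed condition: inspects the size of the time dimension."""
    return "int16" if shape[time_dim] < INT16_MAX else "int32"


# Hypothetical transposed input: F = 4 at dim -2, T = 40_000 time steps at dim -1.
shape = (2, 4, 40_000)
time_dim = -1

current = dtype_from_second_to_last(shape)
proposed = dtype_from_time_dim(shape, time_dim)
print(current, proposed)  # int16 int32
```

With this shape, the current condition picks int16 even though the time dimension has more than 32767 steps, while the proposed condition picks int32.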