[BUG] dtype in _split_and_pad_sequence is established using incorrect dim #2758

Closed
@KubaMichalczyk

Description

In torchrl/objectives/value/utils.py#L287-L290, `_split_and_pad_sequence` uses `tensor.shape[-2]` to choose between `torch.int16` and `torch.int32` for the index dtype. It should use the time dimension (`time_dim`) instead, particularly because this helper is called from `_fast_td_lambda_return_estimate`, where the inputs are transposed first, so dim `-2` is no longer the time dimension. As a result, the check can pick the wrong dtype for certain input shapes: when the size of `time_dim` exceeds the int16 range but the feature dimension (which sits at dim `-2` after the transpose) is still within it.

Code Reference:

    # int16 supports length up to 32767
    dtype = (
        torch.int16 if tensor.shape[-2] < torch.iinfo(torch.int16).max else torch.int32
    )

Proposed Fix:

    dtype = (
        torch.int16 if tensor.size(time_dim) < torch.iinfo(torch.int16).max else torch.int32
    )
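
To illustrate the failure mode, here is a minimal, hypothetical sketch (not from the TorchRL test suite). It assumes a `[B, F, T]` layout like the one produced by the transpose in `_fast_td_lambda_return_estimate`, with `time_dim = -1`, so that dim `-2` holds the feature dimension rather than time:

    import torch

    # Hypothetical [B, F, T] tensor, as laid out after the transpose in
    # _fast_td_lambda_return_estimate; the time dimension is last.
    B, F, T = 2, 4, 40_000  # T exceeds torch.iinfo(torch.int16).max (32767)
    tensor = torch.zeros(B, F, T)
    time_dim = -1

    # Current check: inspects shape[-2] (= F = 4), so it picks int16 ...
    dtype_current = (
        torch.int16 if tensor.shape[-2] < torch.iinfo(torch.int16).max else torch.int32
    )

    # ... but indices along the time dimension (size 40000) overflow int16.
    dtype_fixed = (
        torch.int16 if tensor.size(time_dim) < torch.iinfo(torch.int16).max else torch.int32
    )

    print(dtype_current)  # torch.int16 -> indexing over T would overflow
    print(dtype_fixed)    # torch.int32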
