
Sparsemax fails with torch.use_deterministic_algorithms(True) | deterministic CUDA cumsum #38

Open
weberBen opened this issue Dec 18, 2024 · 1 comment

Comments

@weberBen

Problem

When using torch.use_deterministic_algorithms(True), the sparsemax function fails due to the lack of deterministic support for cumsum in CUDA. This issue occurs specifically in the _sparsemax_threshold_and_support function, where the operation:

topk_cumsum = topk.cumsum(dim) - 1

triggers the following error:

RuntimeError: cumsum_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.

Without deterministic support, models trained on GPU using sparsemax behave unexpectedly when run on CPU for inference (with a drastic drop in prediction accuracy).

Steps to reproduce

import torch
from entmax import sparsemax

torch.use_deterministic_algorithms(True)

x = torch.tensor([-2, 0, 0.5]).to("cuda")
sparsemax(x, dim=0)
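
As an interim measure only (it does not make the CUDA cumsum deterministic), the error can be downgraded to a warning using the warn_only flag mentioned in the error message. A minimal sketch, assuming that behavior is acceptable for your application:

import torch
from entmax import sparsemax

# Interim workaround (assumption: nondeterministic cumsum is tolerable):
# downgrade the determinism error to a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

x = torch.tensor([-2, 0, 0.5]).to("cuda")
sparsemax(x, dim=0)  # runs, emitting a warning for the nondeterministic kernel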

Environment

  • Entmax version: 1.3
  • PyTorch version: 2.4.0
  • CUDA version: 12.1
  • Python version: 3.12.4
  • OS: Ubuntu 20.04 LTS

Dependencies

entmax==1.3
└── torch [required: >=1.3, installed: 2.4.0]
    ├── filelock [required: Any, installed: 3.15.4]
    ├── fsspec [required: Any, installed: 2024.6.1]
    ├── Jinja2 [required: Any, installed: 3.1.4]
    │   └── MarkupSafe [required: >=2.0, installed: 2.1.5]
    ├── networkx [required: Any, installed: 3.2.1]
    ├── nvidia-cublas-cu12 [required: ==12.1.3.1, installed: 12.1.3.1]
    ├── nvidia-cuda-cupti-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── nvidia-cuda-nvrtc-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── nvidia-cuda-runtime-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── nvidia-cudnn-cu12 [required: ==9.1.0.70, installed: 9.1.0.70]
    │   └── nvidia-cublas-cu12 [required: Any, installed: 12.1.3.1]
    ├── nvidia-cufft-cu12 [required: ==11.0.2.54, installed: 11.0.2.54]
    ├── nvidia-curand-cu12 [required: ==10.3.2.106, installed: 10.3.2.106]
    ├── nvidia-cusolver-cu12 [required: ==11.4.5.107, installed: 11.4.5.107]
    │   ├── nvidia-cublas-cu12 [required: Any, installed: 12.1.3.1]
    │   ├── nvidia-cusparse-cu12 [required: Any, installed: 12.1.0.106]
    │   │   └── nvidia-nvjitlink-cu12 [required: Any, installed: 12.5.82]
    │   └── nvidia-nvjitlink-cu12 [required: Any, installed: 12.5.82]
    ├── nvidia-cusparse-cu12 [required: ==12.1.0.106, installed: 12.1.0.106]
    │   └── nvidia-nvjitlink-cu12 [required: Any, installed: 12.5.82]
    ├── nvidia-nccl-cu12 [required: ==2.20.5, installed: 2.20.5]
    ├── nvidia-nvtx-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── setuptools [required: Any, installed: 69.5.1]
    ├── sympy [required: Any, installed: 1.13.2]
    │   └── mpmath [required: >=1.1.0,<1.4, installed: 1.3.0]
    ├── triton [required: ==3.0.0, installed: 3.0.0]
    │   └── filelock [required: Any, installed: 3.15.4]
    └── typing_extensions [required: >=4.8.0, installed: 4.12.2]

Solution

According to this issue, deterministic support for cumsum is added in PyTorch 2.6.0, but that version has not been released yet.

For older versions, the following workarounds could be considered:

  • Add a CPU fallback or an alternative deterministic algorithm for cumsum (see the sketch after this list).
  • Document this limitation in the Entmax README, or emit an explicit warning at runtime.
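
A minimal sketch of the first option, assuming a helper that routes cumsum through the CPU whenever deterministic mode is active on a CUDA tensor. The helper name and its placement are hypothetical, not part of entmax:

import torch

def _deterministic_cumsum(x, dim):
    # Hypothetical helper: cumsum_cuda_kernel has no deterministic
    # implementation, so fall back to CPU when deterministic algorithms
    # are enforced, then move the result back to the original device.
    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
        return x.cpu().cumsum(dim).to(x.device)
    return x.cumsum(dim)

# Inside _sparsemax_threshold_and_support, the failing line would become:
# topk_cumsum = _deterministic_cumsum(topk, dim) - 1

The CPU round-trip costs a device transfer, but it only triggers when torch.use_deterministic_algorithms(True) is set, so the default GPU path is unaffected.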
@bpopeters
Collaborator

Thank you for bringing this to our attention. We'll add a note to the readme about it.
