Consider adopting a stateless PRNG API #509

Open
alxmrs opened this issue Jul 21, 2024 · 3 comments
@alxmrs
Contributor

alxmrs commented Jul 21, 2024

While I'm not familiar with the Philox pseudo-random number generator (PRNG) in NumPy (it does look well suited to generation in a distributed setting), I think adopting a stateless PRNG API will be useful long-term for Cubed. In addition to working in a parallel/distributed setting, Cubed also has to consider how it can best perform computation with vectorization and hardware acceleration (#304, #490).
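For context, what makes Philox attractive for distributed use is that it is counter-based: independent streams can be derived from one root seed with no shared mutable state. A minimal NumPy sketch (the worker count and shapes are illustrative):

```python
import numpy as np

# Philox is counter-based: each (key, counter) pair maps deterministically
# to an output block, so independent streams can be spawned from a single
# root seed without coordinating any shared state between workers.
root = np.random.SeedSequence(42)
child_seeds = root.spawn(3)  # one independent stream per worker/chunk
streams = [np.random.Generator(np.random.Philox(s)) for s in child_seeds]
samples = [g.random(4) for g in streams]  # reproducible per-stream draws
```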

I'm quite persuaded by the design of JAX's PRNG system, which is stateless (and also splittable). I believe this approach will prove useful in the long term.

https://github.com/google/jax/blob/main/docs/jep/263-prng.md
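The core of that design, for reference: keys are explicit immutable values rather than hidden global state, and splitting a key yields independent streams. A few lines of JAX illustrate it:

```python
import jax

key = jax.random.PRNGKey(0)      # explicit key instead of hidden global state
k1, k2 = jax.random.split(key)   # splitting derives two independent subkeys
a = jax.random.normal(k1, (2,))
b = jax.random.normal(k2, (2,))
# Purely functional: the same key always reproduces the same draws.
a_again = jax.random.normal(k1, (2,))
```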

FWIW, I believe any ML framework will have to have special cases for randomness, given the constraints of hardware (GPUs/TPUs).

https://pytorch-dev-podcast.simplecast.com/episodes/random-number-generators

@tomwhite
Member

FWIW, I believe any ML framework will have to have special cases for randomness, given the constraints of hardware (GPUs/TPUs).

I agree. That's why random number generation is not a part of the Array API, and almost certainly won't be: data-apis/array-api#431.

For Cubed I think this means that the random number functions are less fixed than the rest of the API, so I'd be open to changing them or adding new ones if we need to. The main use case is for generating test data, so they can be quite simple.

What do you think we need for JAX? Could we write an implementation of cubed.random.random that delegates to JAX (if the backend array API is JAX) - or do we need to have a different signature?
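One hypothetical shape such a wrapper could take — the name, parameters, and dispatch mechanism here are illustrative, not Cubed's actual API — is to keep the signature stateless by threading the seed explicitly, so the same call can delegate to either backend:

```python
import numpy as np

def random_sample(shape, seed=0, backend="numpy"):
    # Hypothetical sketch: because the seed is an explicit argument,
    # one stateless signature can serve both backends.
    if backend == "jax":
        import jax
        return jax.random.uniform(jax.random.PRNGKey(seed), shape)
    rng = np.random.Generator(np.random.Philox(seed))
    return rng.random(shape)
```

Under this sketch, the signature question reduces to whether callers are willing to pass a seed/key explicitly rather than rely on implicit state.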

@alxmrs
Contributor Author

alxmrs commented Jul 23, 2024

I'm not totally sure what's needed for JAX. For now (running on a single machine with a single device), random is working well enough, given we can convert the arrays. However, I suspect that when the hardware arrangement changes (e.g. multiple GPUs per machine), things could go wrong.

The Array API link points out something interesting to me: namely, that PyTorch uses the same RNG that you're using here. This leaves me a bit hopeful that the problem I'm anticipating could just work itself out. I bet, though, that any time random is used on JAX arrays, it will have to be functionalized, and thus will require a different signature.

@jakirkham

If you haven't already, I would encourage reading the proposed SPEC 7: scientific-python/specs#180
