Consider adopting a stateless PRNG API #509
I agree. That's why random number generation is not a part of the Array API, and almost certainly won't be: data-apis/array-api#431. For Cubed I think this means that the random number functions are less fixed than the rest of the API, so I'd be open to changing them or adding new ones if we need to. The main use case is for generating test data, so they can be quite simple. What do you think we need for JAX? Could we write an implementation of …
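To illustrate what "quite simple" could look like, here is a hypothetical sketch of a per-block random function using NumPy's `SeedSequence` spawn keys, so every block gets an independent, reproducible stream from one root seed. The `random_block` name and signature are assumptions for illustration, not an actual Cubed API:

```python
import numpy as np

def random_block(seed, block_index, shape):
    # Hypothetical helper: derive an independent stream for each block
    # from a single root seed, so any block can be (re)generated in
    # isolation and in any order.
    ss = np.random.SeedSequence(seed, spawn_key=(block_index,))
    rng = np.random.default_rng(ss)
    return rng.random(shape)

# Same (seed, block) pair always reproduces the same data;
# different blocks get statistically independent streams.
a = random_block(42, 0, (2, 2))
b = random_block(42, 1, (2, 2))
```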
I'm not totally sure what's needed for JAX. For now (running on a single machine with a single device), random is working well enough, given we can convert the arrays. However, I suspect that when the hardware arrangement changes (e.g. multiple GPUs per machine), things could go wrong. The Array API link points out something interesting to me: namely, that PyTorch uses the same RNG that you're using here. This leaves me a bit hopeful that the problem I'm anticipating could just work itself out. I bet, though, that any time random is used on JAX arrays, it will have to be functionalized, and thus require a different signature.
If you haven't already, I'd encourage reading the proposed SPEC 7: scientific-python/specs#180
While I'm not familiar with the Philox pseudo-random number generator (PRNG) in NumPy (it does look well suited to generation in a distributed setting), I think adopting a stateless PRNG API will be useful long-term for Cubed. In addition to working in a parallel/distributed setting, Cubed also has to consider how it can best perform computation with vectorization and hardware acceleration (#304, #490).
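For context on why Philox suits distributed generation: it is counter-based, so workers can position themselves at disjoint, far-apart points in the same keyed stream without any coordination. A rough sketch (the `block_rng` helper and the counter offset are assumptions for illustration):

```python
import numpy as np

def block_rng(seed, block_id):
    # Philox is counter-based: with a shared key, each block sets its
    # counter to a widely separated offset in the stream, so blocks can
    # be generated independently, in any order, with no communication.
    # (Hypothetical helper; offset choice is illustrative.)
    bg = np.random.Philox(key=seed, counter=block_id * 2**128)
    return np.random.Generator(bg)

# Deterministic per (seed, block); different blocks do not overlap.
x0 = block_rng(7, 0).random(3)
x1 = block_rng(7, 1).random(3)
```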
I'm quite persuaded by the design of JAX's PRNG system that statelessness (if not also splittability) is the right approach. I believe it will prove useful in the long term.
https://github.com/google/jax/blob/main/docs/jep/263-prng.md
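The core idea in that JEP is replacing hidden global RNG state with explicit keys that are split into independent subkeys. A rough NumPy-based imitation of that interface, to show the shape of the API (this is not JAX's actual implementation, which uses a Threefry-based counter PRNG; `key`, `split`, and `uniform` here are illustrative stand-ins for `jax.random.PRNGKey`, `jax.random.split`, and `jax.random.uniform`):

```python
import numpy as np

def key(seed):
    # An explicit PRNG key (analogous to jax.random.PRNGKey).
    return np.random.SeedSequence(seed)

def split(k, num=2):
    # Deterministically derive independent subkeys (like jax.random.split).
    # Built from spawn keys rather than SeedSequence.spawn() so that
    # splitting the same key twice yields the same subkeys (pure function).
    return [np.random.SeedSequence(k.entropy, spawn_key=k.spawn_key + (i,))
            for i in range(num)]

def uniform(k, shape):
    # Draw from a key without mutating it: same key, same values.
    return np.random.default_rng(k).random(shape)

k1, k2 = split(key(0))
```

Because every function is pure in the key, calls can be reordered, vectorized, or distributed without a hidden stream being advanced behind the caller's back, which is exactly the property that matters for Cubed's serverless execution model.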
FWIW, I believe any ML framework will have to have special cases for randomness, given the constraints of hardware (GPUs/TPUs).
https://pytorch-dev-podcast.simplecast.com/episodes/random-number-generators