-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: tracing Random.jl functionality correctly #363
Draft
avik-pal
wants to merge
6
commits into
main
Choose a base branch
from
ap/random_numbers
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
ae7ce38
refactor: move stdlib overloads to a different directory
avik-pal 3950175
fix: Ops.rng_bit_generator
avik-pal aef05e7
feat: initial prototype for random number generation
avik-pal b9411d7
feat: add support for scalar sampling
avik-pal 8673389
feat: efficient sampling for non-native RNGs
avik-pal 7d17faf
fix: handling floating point sampling
avik-pal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Implementation based on the following: | ||
# 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl | ||
# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl#L125 | ||
|
||
mutable struct TracedRNG <: Random.AbstractRNG | ||
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} | ||
const algorithm::String | ||
end | ||
|
||
function Random.seed!(rng::TracedRNG, seed::Number) | ||
seed = reinterpret(UInt64, Random.hash_seed(seed)) | ||
# TODO: Using `seed!` inside tracing should generate a TracedRArray | ||
return Random.seed!(rng, ConcreteRArray(seed[1:length(rng.seed)])) | ||
end | ||
|
||
function Random.seed!( | ||
rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} | ||
) | ||
rng.seed = seed | ||
return rng | ||
end | ||
|
||
make_seed() = rand(Random.RandomDevice(), UInt64, 2) | ||
|
||
TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) | ||
TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") | ||
|
||
default_rng() = TracedRNG() | ||
function default_rng_inside_interpreter() | ||
return TracedRNG(promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") | ||
end | ||
|
||
# XXX: Currently we get an illegal instruction if we don't call Random.default_rng() | ||
|
||
function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} | ||
length(A) == 0 && return A | ||
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) | ||
rng.seed = res.output_state | ||
set_mlir_data!(A, res.output.mlir_data) | ||
return A | ||
end | ||
|
||
function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} | ||
length(A) == 0 && return A | ||
Random.rand!(rng, A) | ||
scaled_uniform = Ops.subtract( | ||
Ops.multiply(A, Ops.constant(fill(T(2), size(A)))), | ||
Ops.constant(fill(T(1), size(A))), | ||
) | ||
probit = Ops.erf_inv(scaled_uniform) | ||
rand_normal = Ops.multiply(probit, Ops.constant(fill(sqrt(T(2)), size(A)))) | ||
set_mlir_data!(A, rand_normal.mlir_data) | ||
return A | ||
end | ||
|
||
for randfun in (:rand, :randn) | ||
randfun! = Symbol(randfun, :!) | ||
@eval begin | ||
function Random.$(randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T} | ||
return Random.$(randfun!)(rng, TracedRArray{T,length(dims)}((), nothing, dims)) | ||
end | ||
|
||
function Random.$(randfun)(rng::TracedRNG, dims::Dims) | ||
return Random.$(randfun)(rng, Float64, dims) | ||
end | ||
|
||
function Random.$(randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...) | ||
return Random.$(randfun)(rng, Dims((dim1, dims...))) | ||
end | ||
|
||
function Random.$(randfun)( | ||
rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer... | ||
) where {T} | ||
return Random.$(randfun)(rng, T, Dims((dim1, dims...))) | ||
end | ||
|
||
Random.$(randfun!)(A::AnyTracedRArray) = Random.$(randfun!)(default_rng(), A) | ||
|
||
# scalars | ||
function Random.$(randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} | ||
A = promote_to(TracedRArray{T,0}, fill(T(0))) | ||
Random.$(randfun!)(rng, A) | ||
return A[] | ||
end | ||
|
||
# Non-Traced RNGs if used will lead to disastrous performance. We attempt to fix | ||
# that but with a warning | ||
function Random.$(randfun!)(rng::Random.AbstractRNG, A::AnyTracedRArray) | ||
@warn "`rng` is not a `TracedRNG`. We will use this to seed the `TracedRNG` \ | ||
instead of generating samples from this RNG type." maxlog = 1 | ||
seed = promote_to(TracedRArray{UInt64,1}, rand(rng, UInt64, 2)) | ||
trng = TracedRNG(seed, "DEFAULT") | ||
return Random.$(randfun!)(trng, A) | ||
end | ||
end | ||
end | ||
|
||
# resolve ambiguities | ||
function Random.randn(rng::TracedRNG, T::Random.BitFloatType) | ||
A = promote_to(TracedRArray{T,0}, fill(T(0))) | ||
Random.randn!(rng, A) | ||
return A[] | ||
end | ||
|
||
# TODO: At some later point we might want to implement the sampler API as well since it | ||
# makes all RNG implementation work by default. From the post-optimize IR we need to | ||
# confirm that the dynamic_update_slice calls are optimized away into a single | ||
# `stablehlo.rng_bit_generator` call -- confirm that this should be the case based on | ||
# how the seeding should work? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will have to be updated once the new CUDA interpreter stuff lands