Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: tracing Random.jl functionality correctly #363

Merged
merged 22 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ jobs:
version: '1.10'
assertions: true
test_group: neural_networks
- os: ubuntu-20.04
arch: x64
libReactant: packaged
version: '1.10'
assertions: true
test_group: integration
- os: ubuntu-20.04
arch: x86
libReactant: packaged
Expand Down
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Expand All @@ -23,17 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources.ReactantCore]
path = "lib/ReactantCore"
[sources]
ReactantCore = {path = "lib/ReactantCore"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantRandom123Ext = "Random123"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"

Expand All @@ -50,6 +53,8 @@ LinearAlgebra = "1.10"
NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.26"
Scratch = "1.2"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pages = [
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
"Internal API" => "api/internal.md",
],
]

Expand Down
4 changes: 3 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ export default defineConfig({
{ text: "MLIR API", link: "/api/mlirc" },
{ text: "XLA", link: "/api/xla" },
],
}
},
{ text: "Internal API", link: "/api/internal" },
],
},
{
Expand Down Expand Up @@ -132,6 +133,7 @@ export default defineConfig({
{ text: "XLA", link: "/api/xla" },
],
},
{ text: "Internal API", link: "/api/internal" },
],
},
},
Expand Down
12 changes: 12 additions & 0 deletions docs/src/api/internal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# Internal API

These functions are not part of the public API and are subject to change at any time.

```@docs
Reactant.REDUB_ARGUMENTS_NAME
Reactant.within_reactant_interpreter
```
11 changes: 11 additions & 0 deletions ext/ReactantRandom123Ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module ReactantRandom123Ext

using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
using Reactant: TracedRandom

TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"
TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"

end
141 changes: 136 additions & 5 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1016,19 +1016,150 @@ end
end

# random ops
"""
rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type `T` with the given shape and seed from a uniform random
distribution between 0 and 1. Returns a NamedTuple with the following fields:

- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.

# Arguments

- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
)
output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape)
) where {T<:Integer}
@assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
if algorithm == "PHILOX"
@assert length(seed) ∈ (2, 3)
elseif algorithm == "THREE_FRY"
@assert length(seed) == 2
end

output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64))
rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm)
op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location)
op = stablehlo.rng_bit_generator(
seed.mlir_data; output, output_state, rng_algorithm, location
)
return (;
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)),
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape),
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)),
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)),
)
end

@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
) where {T<:AbstractFloat}
nbits = sizeof(T) * 8
uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
output = divide(
convert(TracedRArray{T,ndims(output)}, output),
constant(fill(T(typemax(uT)), Tuple(shape)); location),
)
return (; output_state, output)
end

"""
randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type `T` with the given shape and seed from a standard normal
distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
fields:

- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.

# Arguments

- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
scaled_uniform = subtract(
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
constant(fill(T(1), size(rand_uniform))),
)
probit = erf_inv(scaled_uniform)
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
return (; output_state=seed, output=rand_normal)
end

"""
randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type `T` with the given shape and seed from an exponential
distribution with rate 1. Returns a NamedTuple with the following fields:

- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.

# Arguments

- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
rand_exp = negate(log_plus_one(negate(rand_uniform)))
return (; output_state=seed, output=rand_exp)
end

# functional ops
Expand Down
95 changes: 94 additions & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Helper Function to determine if we are inside the ReactantInterpreter
"""
within_reactant_interpreter()

Returns `true` if we are currently inside the ReactantInterpreter.
"""
@noinline within_reactant_interpreter() = false
@reactant_overlay @noinline within_reactant_interpreter() = true

# Compiling within a compile should return simply the original function
@reactant_overlay function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme overrides
# Enzyme.jl overlays
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
Expand All @@ -22,3 +31,87 @@ end
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

# Random.jl overlays
@reactant_overlay @noinline function Random.default_rng()
return call_with_reactant(TracedRandom.default_rng)
end

## Only problematic edge case here is the direct `<randfun!>(rng, A::AbstractArray)` call
## We can't directly overlay that call without breaking the semantics of inplace update
for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
overload_randfun = Symbol(:overload_, randfun)
overload_randfun! = Symbol(:overload_, randfun!)

@eval begin
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dims::Dims
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, dim1::Integer, dims::Integer...
)
return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dim1, dims...)
end

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T)
end

# inplace
@reactant_overlay @noinline function Random.$(randfun!)(
rng::AbstractRNG, A::AnyTracedRArray
)
return TracedRandom.$(overload_randfun!)(rng, A)
end

# XXX: Uncomment once AbsInt issues with recursive calls are resolved
# @reactant_overlay @noinline function Random.$(randfun!)(
# rng::AbstractRNG, A::AbstractArray
# )
# @warn "Directly writing to an array using Random.jl functions inside \
# ReactantInterpreter will generate a constant array in the IR. Use with \
# caution." maxlog = 1
# return Random.$(randfun!)(rng, A)
# end
end
end
11 changes: 10 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module Reactant
using ReactantCore: ReactantCore, @trace, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Random: Random, AbstractRNG

using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

Expand Down Expand Up @@ -122,7 +124,14 @@ include("TracedRArray.jl")

include("ConcreteRArray.jl")

include("linear_algebra.jl")
mutable struct TracedRNG <: Random.AbstractRNG
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
const algorithm::String
end

# StdLib Overloads
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

Expand Down
File renamed without changes.
Loading
Loading