
feat: tracing Random.jl functionality correctly #363

Draft: avik-pal wants to merge 6 commits into main.

Conversation

avik-pal (Collaborator)

No description provided.

avik-pal (Collaborator, Author) commented Dec 11, 2024:

```julia
julia> using Reactant, Random

julia> fn() = randn(Random.default_rng(), 2, 3)
fn (generic function with 1 method)

julia> @code_hlo optimize = false fn()
module {
  func.func @main() -> tensor<3x2xf64> {
    %c = stablehlo.constant dense<[9454987348304227925, 11257230962712577529]> : tensor<2xui64>
    %output_state, %output = stablehlo.rng_bit_generator %c, algorithm =  DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x3xui64>)
    %0 = stablehlo.convert %output : (tensor<2x3xui64>) -> tensor<2x3xf64>
    %cst = stablehlo.constant dense<1.8446744073709552E+19> : tensor<2x3xf64>
    %1 = stablehlo.divide %0, %cst : tensor<2x3xf64>
    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf64>
    %2 = stablehlo.multiply %1, %cst_0 : tensor<2x3xf64>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf64>
    %3 = stablehlo.subtract %2, %cst_1 : tensor<2x3xf64>
    %4 = chlo.erf_inv %3 : tensor<2x3xf64> -> tensor<2x3xf64>
    %cst_2 = stablehlo.constant dense<1.4142135623730951> : tensor<2x3xf64>
    %5 = stablehlo.multiply %4, %cst_2 : tensor<2x3xf64>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %6 : tensor<3x2xf64>
  }
}

julia> @code_hlo fn()
module {
  func.func @main() -> tensor<3x2xf64> {
    %cst = stablehlo.constant dense<1.4142135623730951> : tensor<2x3xf64>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf64>
    %cst_1 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf64>
    %cst_2 = stablehlo.constant dense<1.8446744073709552E+19> : tensor<2x3xf64>
    %c = stablehlo.constant dense<[17523564455668573441, 5342821220909967229]> : tensor<2xui64>
    %output_state, %output = stablehlo.rng_bit_generator %c, algorithm =  DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x3xui64>)
    %0 = stablehlo.convert %output : (tensor<2x3xui64>) -> tensor<2x3xf64>
    %1 = stablehlo.divide %0, %cst_2 : tensor<2x3xf64>
    %2 = stablehlo.multiply %1, %cst_1 : tensor<2x3xf64>
    %3 = stablehlo.subtract %2, %cst_0 : tensor<2x3xf64>
    %4 = chlo.erf_inv %3 : tensor<2x3xf64> -> tensor<2x3xf64>
    %5 = stablehlo.multiply %4, %cst : tensor<2x3xf64>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %6 : tensor<3x2xf64>
  }
}
```
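For reference, the lowering above samples normals by the inverse-CDF method: rng_bit_generator produces raw ui64 bits, dividing by 2^64 (the 1.8446744073709552E+19 constant) maps them to a uniform value in [0, 1], and sqrt(2) * erf_inv(2u - 1) converts that uniform value to a standard normal. A minimal host-side sketch of the same math (an illustration of what the emitted HLO computes, not Reactant's implementation):

```julia
# Host-side sketch of the transform in the modules above (illustration only):
# raw bits -> uniform in [0, 1] -> standard normal via the inverse CDF.
using SpecialFunctions: erfinv

function randn_from_bits(bits::UInt64)
    u = Float64(bits) / 1.8446744073709552e19  # divide by 2^64, as in the HLO
    return sqrt(2.0) * erfinv(2.0 * u - 1.0)   # chlo.erf_inv step, scaled by sqrt(2)
end
```

Note also that the seed constant %c differs between the two traces, so each trace appears to capture a fresh seed from the host RNG.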

Review thread on src/stdlibs/Random.jl (outdated; resolved).
avik-pal (Collaborator, Author) commented:
This is kind of working now. Can I get an initial review?

Review comment on lines +103 to +111:

```julia
if (f === Random.default_rng || f === default_rng) && length(argtypes) == 1
    arginfo2 = ArgInfo(
        fargs isa Nothing ? nothing : Any[:($(default_rng_inside_interpreter))],
        Any[Core.Const(default_rng_inside_interpreter)],
    )
    return abstract_call_known(
        interp, default_rng_inside_interpreter, arginfo2, si, sv, max_methods
    )
end
```
avik-pal (Collaborator, Author) commented:
This will have to be updated once the new CUDA interpreter stuff lands
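For context on why this interception is needed: without it, Random.default_rng() inside a traced function would return the host's task-local RNG, so randn would presumably execute eagerly during tracing rather than lowering to stablehlo.rng_bit_generator. A quick way to see the contrast, reusing fn from the first comment (a hedged sketch; the eager call runs on the host, the traced one goes through this hook):

```julia
using Reactant, Random

fn() = randn(Random.default_rng(), 2, 3)

fn()            # eager host execution: an ordinary 2x3 Matrix{Float64}
@code_hlo fn()  # traced: default_rng() is redirected, emitting rng_bit_generator
```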
