Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic function call insertion #523

Draft
wants to merge 32 commits into
base: main
Choose a base branch
from
Draft

Conversation

jumerckx
Copy link
Collaborator

@jumerckx jumerckx commented Jan 13, 2025

on top of: #366, currently has an error:

using Reactant

Reactant.DEBUG_INTERP[] = true
Reactant.TOGGLE_TRACECALLS[] = true # necessary to avoid precompilation from failing

@noinline _call1(a, b) = a
function call1(a, b)
    x = _call1(a, b)
    y = _call1(a, b)
    return _call1(x, y)
end

a = rand(2, 3)
b = rand(2, 3)
a_ra = Reactant.to_rarray(a)
b_ra = Reactant.to_rarray(b)

@compile(call1(a_ra, b_ra))
ERROR: StackOverflowError:
Stacktrace:
     [1] make_typealias(x::Type)
       @ Base ./show.jl:644
     [2] show_typealias(io::IOBuffer, x::Type)
       @ Base ./show.jl:805
     [3] _show_type(io::IOBuffer, x::Type)
       @ Base ./show.jl:970
     [4] show(io::IOBuffer, x::Type)
       @ Base ./show.jl:965
     [5] show_typeparams(io::IOBuffer, env::Core.SimpleVector, orig::Core.SimpleVector, wheres::Vector{TypeVar})
       @ Base ./show.jl:722
     [6] show_datatype(io::IOBuffer, x::DataType, wheres::Vector{TypeVar})
       @ Base ./show.jl:1181
     [7] show_datatype
       @ ./show.jl:1089 [inlined]
     [8] _show_type(io::IOBuffer, x::Type)
       @ Base ./show.jl:973
     [9] show(io::IOBuffer, x::Type)
       @ Base ./show.jl:965
    [10] _show_default(io::IOBuffer, x::Any)
       @ Base ./show.jl:486
    [11] show_default
       @ ./show.jl:482 [inlined]
    [12] show
       @ ./show.jl:477 [inlined]
    [13] print(io::IOBuffer, x::Base.Generator{Vector{Pair{Any, Any}}, Reactant.var"#1#3"})
       @ Base ./strings/io.jl:35
    [14] print_to_string(xs::Base.Generator{Vector{Pair{Any, Any}}, Reactant.var"#1#3"})
       @ Base ./strings/io.jl:148
    [15] string
       @ ./strings/io.jl:189 [inlined]
    [16] safe_print
       @ ~/Reactant.jl/src/utils.jl:448 [inlined]
    [17] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:27 [inlined]
    [18] call_with_reactant(::Type{OrderedCollections.OrderedDict{…}}, ::Base.Generator{Vector{…}, Reactant.var"#1#3"})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [19] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:8 [inlined]
    [20] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:16 [inlined]
    [21] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:15 [inlined]
    [22] traced_call
       @ ~/Reactant.jl/src/ControlFlow.jl:134 [inlined]
    [23] traced_call(none::typeof(memoryref), none::Tuple{Memory{UInt64}})
       @ Reactant ./<missing>:0
    [24] GenericMemory
       @ ./boot.jl:514 [inlined]
    [25] Array
       @ ./boot.jl:578 [inlined]
    [26] getindex
       @ ./array.jl:400 [inlined]
    [27] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:16 [inlined]
    [28] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:15 [inlined]
    [29] traced_call
       @ ~/Reactant.jl/src/ControlFlow.jl:134 [inlined]
    [30] call_with_reactant(::typeof(ReactantCore.traced_call), ::typeof(memoryref), ::Memory{UInt64})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [31] traced_call_with_reactant(f::Function, args::Memory{UInt64})
       @ Reactant ~/Reactant.jl/src/utils.jl:19
    [32] Array
       @ ./boot.jl:579 [inlined]
    [33] Array
       @ ./boot.jl:601 [inlined]
    [34] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:23 [inlined]
    [35] OrderedCollections.OrderedDict{UInt64, Any}()
       @ Reactant ./<missing>:0
    [36] GenericMemory
       @ ./boot.jl:516 [inlined]
    [37] Array
       @ ./boot.jl:578 [inlined]
    [38] Array
       @ ./boot.jl:591 [inlined]
    [39] zeros
       @ ./array.jl:589 [inlined]
    [40] zeros
       @ ./array.jl:585 [inlined]
    [41] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:23 [inlined]
    [42] call_with_reactant(redub_arguments#232::Type{OrderedCollections.OrderedDict{UInt64, Any}})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [43] traced_call
       @ ~/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:397 [inlined]
    [44] traced_call(none::Type{OrderedCollections.OrderedDict{UInt64, Any}}, none::Tuple{})
       @ Reactant ./<missing>:0
    [45] traced_call
       @ ~/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:397 [inlined]
    [46] call_with_reactant(::typeof(ReactantCore.traced_call), ::Type{OrderedCollections.OrderedDict{UInt64, Any}})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [47] traced_call_with_reactant(::Type)
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [48] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:27 [inlined]
    [49] OrderedCollections.OrderedDict{UInt64, Any}(none::Base.Generator{Vector{Pair{Any, Any}}, Reactant.var"#1#3"})
       @ Reactant ./<missing>:0
--- the above 33 lines are repeated 2284 more times ---
...

With debug printing eventually just repeating:

...
"fn arg[1] traced_call"
"fn arg[2] memoryref"
"fn arg[3] (UInt64[],)"
"fn arg[1] memoryref"
"fn arg[2] Pair{Any, Any}[]"
"fn arg[1] OrderedCollections.OrderedDict{UInt64, Any}"
"fn arg[2] Base.Generator{Vector{Pair{Any, Any}}, Reactant.var\"#1#3\"}(Reactant.var\"#1#3\"(), Pair{Any, Any}[])"
"fn arg[1] traced_call"
"fn arg[2] OrderedCollections.OrderedDict{UInt64, Any}"
"fn arg[3] ()"
"fn arg[1] OrderedCollections.OrderedDict{UInt64, Any}"
...

As if there's a cycle in the callgraph?

jumerckx and others added 25 commits January 2, 2025 14:53
…ed values to their corresponding mlir type. These transformed values can be used as keys in a dict (stored in ScopedValue for ease).

Cache hits are detected but the cache is not yet used because there is not yet a way to replace the mlir data recursively in a traced object.
Repurposes the path argument of `make_tracer` and builds a vector containing:
* MLIR type for traced values
* Julia type for objects
* actual value for primitive types
* `VisitedObject(id)` for objects that where already encountered ( == stored in `seen`).
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit

JuliaFormatter

[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call1(a_ra, b_ra)


[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call1(a_ra, c_ra)


[JuliaFormatter] reported by reviewdog 🐶

_call2(a) = a+a


[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call3(y_ra)


[JuliaFormatter] reported by reviewdog 🐶

_call4(foobar::Union{Foo, Bar}) = foobar.x


[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call4(foo, foo2, bar)


[JuliaFormatter] reported by reviewdog 🐶

@@ -352,7 +352,7 @@ const cuLaunch = Ref{UInt}(0)
const cuFunc = Ref{UInt}(0)
const cuModule = Ref{UInt}(0)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
function compile_mlir!(mod, f, args, callcache=Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{MLIR.IR.Type}, traced_result::Any}}(); optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function compile_mlir!(mod, f, args, callcache=Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{MLIR.IR.Type}, traced_result::Any}}(); optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
function compile_mlir!(
mod,
f,
args,
callcache=Dict{
Vector,
@NamedTuple{
f_name::String, mlir_result_types::Vector{MLIR.IR.Type}, traced_result::Any
}
}();
optimize::Union{Bool,Symbol}=true,
no_nan::Bool=false,
)

@@ -361,7 +361,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = try
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
callcache!(callcache) do # TODO: don't create a closure here either.
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)

end
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

v isa TracedType || continue
push!(linear_args, v.mlir_data)
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:end-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
v.paths = v.paths[1:end-1]
v.paths = v.paths[1:(end - 1)]

Comment on lines +161 to +167
f,
args,
(),
f_name,
false;
no_args_in_result=true,
do_transpose=false,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
f,
args,
(),
f_name,
false;
no_args_in_result=true,
do_transpose=false,
f, args, (), f_name, false; no_args_in_result=true, do_transpose=false

src/utils.jl Outdated
@@ -257,11 +274,12 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
end
end
elseif Base.invokelatest(should_rewrite_ft, ft)
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant
new_f = if (!allow_tracing || ft <: typeof(ReactantCore.traced_call))
call_with_reactant
else
traced_call_with_reactant
end

src/utils.jl Outdated
@@ -279,6 +297,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))

new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant
new_f = if (!allow_tracing || ft <: typeof(ReactantCore.traced_call))
call_with_reactant
else
traced_call_with_reactant
end

src/utils.jl Outdated
Comment on lines 326 to 312
sig2 = Tuple{
typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns...
typeof(new_f),
sig.parameters[1:(end - 1)]...,
ns...,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sig2 = Tuple{
typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns...
typeof(new_f),
sig.parameters[1:(end - 1)]...,
ns...,
}
sig2 = Tuple{typeof(new_f),sig.parameters[1:(end - 1)]...,ns...}

src/utils.jl Outdated
any_changed = false
for (i, inst) in enumerate(ir.stmts)
# Explicitly skip any code which returns Union{} so that we throw the error
# instead of risking a segfault
RT = inst[:type]
@static if VERSION < v"1.11"
changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error)
changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error, allow_tracing)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error, allow_tracing)
changed, next, RT = rewrite_inst(
inst[:inst], ir, interp, RT, guaranteed_error, allow_tracing
)

src/utils.jl Outdated
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error)
changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error, allow_tracing)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error, allow_tracing)
changed, next, RT = rewrite_inst(
inst[:stmt], ir, interp, RT, guaranteed_error, allow_tracing
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Reactant.jl/src/Tracing.jl

Lines 372 to 379 in d2ce359

seen,
xi,
newpath,
mode;
toscalar,
tobatch,
track_numbers,
kwargs...,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Reactant.jl/src/Tracing.jl

Lines 409 to 416 in d2ce359

seen,
xi,
newpath,
mode;
toscalar,
tobatch,
track_numbers,
kwargs...,

src/utils.jl Outdated
@@ -257,11 +274,12 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
end
end
elseif Base.invokelatest(should_rewrite_ft, ft)
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I see what's happening here, but I feel like we should do this a bit differently.

Presently you're rewriting all calls to be traced_call_with_reactant.

What if, instead, we modified the codeinfo and/or opaque closure within call_with_reactant. That way we don't have an extra level of indirection (that may cause issues)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have the methodinstance/codeinfo on the inside (and argtypes themselves) so we could even do the equivalent of make_Tracer into a compile-time recusion (aka generate the equivalent of

key = (arg1.x.y, arg2.z, arg3[4], ...)

during the generated function, so then the (relatively expensive) make_tracer isn't called every function call (which is already quite expensive).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also therefore if a function just is foo(TracedArray, ) we literally don't even need to do a key check!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding:

8 1 ─ %1 = invoke Main._call1(_2::Reactant.TracedRArray{Float64, 2}, _3::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
9 │   %2 = invoke Main._call1(%1::Reactant.TracedRArray{Float64, 2}, %1::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
  └──      return %2

would need to be rewritten to:

%1 = (call_with_reactant)(traced_call, _call1, _2, _3)
%2 = (call_with_reactant)(traced_call, _call1, %1, %1)
return %2

instead of (traced_call_with_reactant)(_call1, ...), or do you mean to remove more indirection still?

so we could even do the equivalent of make_Tracer into a compile-time recusion

Maybe I'm misunderstanding your point, but I don't think we can fully do the equivalent of make_tracer at compile-time?
e.g.

struct A
    x #untyped
end

For arguments of type A we can generate arg.x but we can't go deeper than that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it depends on the type, maybe we can talk about it tomorrow/over the weekend

@jumerckx jumerckx force-pushed the jm/funccal_insertion branch from 0dacb36 to a27294e Compare January 22, 2025 09:38
@jumerckx
Copy link
Collaborator Author

A problem when tracing through broadcasting:
Base.similar for broadcasted objects is implemented as

function Base.similar(
    ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:ReactantPrimitive,N}
    @assert N isa Int
    return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end

This creates an "temporarily invalid" TracedRArray since there's no mlir data.
With regular tracing, this was no problem because these calls are always paired with Base.copyto! which injects the correct MLIR data.
When trying to generate the calls separately, however, there is no actual mlir value which can be used to be passed to the mlir version of copyto!.

Possible solutions:

  • make similar for broadcasted objects return an actual mlir value (fill(0, ...))
  • automatically replace "invalid" objects with a tensor of zeros at the point where they would be used as arguments for a call
  • ...?

To me, the first approach seems more correct, in the sense that the implementation for similar of broadcasted objects can actually be used as a standalone function.

@wsmoses
Copy link
Member

wsmoses commented Jan 22, 2025

first seems cleaner to me

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants