-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Automatic function call insertion #523
base: main
Are you sure you want to change the base?
Conversation
…ed values to their corresponding mlir type. These transformed values can be used as keys in a dict (stored in ScopedValue for ease). Cache hits are detected but the cache is not yet used because there is not yet a way to replace the mlir data recursively in a traced object.
Repurposes the path argument of `make_tracer` and builds a vector containing: * MLIR type for traced values * Julia type for objects * actual value for primitive types * `VisitedObject(id)` for objects that where already encountered ( == stored in `seen`).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit
JuliaFormatter
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 585 in d2ce359
ir = @code_hlo optimize=false call1(a_ra, b_ra) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 594 in d2ce359
ir = @code_hlo optimize=false call1(a_ra, c_ra) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 599 in d2ce359
_call2(a) = a+a |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 629 in d2ce359
ir = @code_hlo optimize=false call3(y_ra) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 641 in d2ce359
_call4(foobar::Union{Foo, Bar}) = foobar.x |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 654 in d2ce359
ir = @code_hlo optimize=false call4(foo, foo2, bar) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 658 in d2ce359
@@ -352,7 +352,7 @@ const cuLaunch = Ref{UInt}(0) | |||
const cuFunc = Ref{UInt}(0) | |||
const cuModule = Ref{UInt}(0) | |||
|
|||
function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) | |||
function compile_mlir!(mod, f, args, callcache=Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{MLIR.IR.Type}, traced_result::Any}}(); optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function compile_mlir!(mod, f, args, callcache=Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{MLIR.IR.Type}, traced_result::Any}}(); optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) | |
function compile_mlir!( | |
mod, | |
f, | |
args, | |
callcache=Dict{ | |
Vector, | |
@NamedTuple{ | |
f_name::String, mlir_result_types::Vector{MLIR.IR.Type}, traced_result::Any | |
} | |
}(); | |
optimize::Union{Bool,Symbol}=true, | |
no_nan::Bool=false, | |
) |
@@ -361,7 +361,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||
fnwrapped, | |||
func2, traced_result, result, seen_args, ret, linear_args, in_tys, | |||
linear_results = try | |||
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) | |||
callcache!(callcache) do # TODO: don't create a closure here either. | |||
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) | |
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) |
end | ||
end | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
v isa TracedType || continue | ||
push!(linear_args, v.mlir_data) | ||
# make tracer inserted `()` into the path, here we remove it: | ||
v.paths = v.paths[1:end-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
v.paths = v.paths[1:end-1] | |
v.paths = v.paths[1:(end - 1)] |
f, | ||
args, | ||
(), | ||
f_name, | ||
false; | ||
no_args_in_result=true, | ||
do_transpose=false, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
f, | |
args, | |
(), | |
f_name, | |
false; | |
no_args_in_result=true, | |
do_transpose=false, | |
f, args, (), f_name, false; no_args_in_result=true, do_transpose=false |
src/utils.jl
Outdated
@@ -257,11 +274,12 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) | |||
end | |||
end | |||
elseif Base.invokelatest(should_rewrite_ft, ft) | |||
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant | |
new_f = if (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) | |
call_with_reactant | |
else | |
traced_call_with_reactant | |
end |
src/utils.jl
Outdated
@@ -279,6 +297,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) | |||
min_world = Ref{UInt}(typemin(UInt)) | |||
max_world = Ref{UInt}(typemax(UInt)) | |||
|
|||
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant | |
new_f = if (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) | |
call_with_reactant | |
else | |
traced_call_with_reactant | |
end |
src/utils.jl
Outdated
sig2 = Tuple{ | ||
typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... | ||
typeof(new_f), | ||
sig.parameters[1:(end - 1)]..., | ||
ns..., | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
sig2 = Tuple{ | |
typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... | |
typeof(new_f), | |
sig.parameters[1:(end - 1)]..., | |
ns..., | |
} | |
sig2 = Tuple{typeof(new_f),sig.parameters[1:(end - 1)]...,ns...} |
src/utils.jl
Outdated
any_changed = false | ||
for (i, inst) in enumerate(ir.stmts) | ||
# Explicitly skip any code which returns Union{} so that we throw the error | ||
# instead of risking a segfault | ||
RT = inst[:type] | ||
@static if VERSION < v"1.11" | ||
changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error) | ||
changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error, allow_tracing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error, allow_tracing) | |
changed, next, RT = rewrite_inst( | |
inst[:inst], ir, interp, RT, guaranteed_error, allow_tracing | |
) |
src/utils.jl
Outdated
Core.Compiler.setindex!(ir.stmts[i], next, :inst) | ||
else | ||
changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error) | ||
changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error, allow_tracing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error, allow_tracing) | |
changed, next, RT = rewrite_inst( | |
inst[:stmt], ir, interp, RT, guaranteed_error, allow_tracing | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/utils.jl
Outdated
@@ -257,11 +274,12 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) | |||
end | |||
end | |||
elseif Base.invokelatest(should_rewrite_ft, ft) | |||
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I see what's happening here, but I feel like we should do this a bit differently.
Presently you're rewriting all calls to be traced_call_with_reactant.
What if, instead, we modified the codeinfo and/or opaque closure within call_with_reactant. That way we don't have an extra level of indirection (that may cause issues)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have the methodinstance/codeinfo on the inside (and argtypes themselves) so we could even do the equivalent of make_Tracer into a compile-time recusion (aka generate the equivalent of
key = (arg1.x.y, arg2.z, arg3[4], ...)
during the generated function, so then the (relatively expensive) make_tracer isn't called every function call (which is already quite expensive).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also therefore if a function just is foo(TracedArray, ) we literally don't even need to do a key check!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For my understanding:
8 1 ─ %1 = invoke Main._call1(_2::Reactant.TracedRArray{Float64, 2}, _3::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
9 │ %2 = invoke Main._call1(%1::Reactant.TracedRArray{Float64, 2}, %1::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
└── return %2
would need to be rewritten to:
%1 = (call_with_reactant)(traced_call, _call1, _2, _3)
%2 = (call_with_reactant)(traced_call, _call1, %1, %1)
return %2
instead of (traced_call_with_reactant)(_call1, ...)
, or do you mean to remove more indirection still?
so we could even do the equivalent of make_Tracer into a compile-time recusion
Maybe I'm misunderstanding your point, but I don't think we can fully do the equivalent of make_tracer
at compile-time?
e.g.
struct A
x #untyped
end
For arguments of type A
we can generate arg.x
but we can't go deeper than that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah it depends on the type, maybe we can talk about it tomorrow/over the weekend
To make it easier to use those parts in `call_with_reactant_generator`.
Caching is not yet enabled, and arguments aren't passed correctly yet.
0dacb36
to
a27294e
Compare
A problem when tracing through broadcasting: function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end This creates an "temporarily invalid" Possible solutions:
To me, the first approach seems more correct, in the sense that the implementation for |
first seems cleaner to me |
on top of: #366, currently has an error:
With debug printing eventually just repeating:
As if there's a cycle in the callgraph?