-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP @trace
function calls
#366
base: main
Are you sure you want to change the base?
Changes from 5 commits
e816056
88d613a
17c6c7b
998ef8a
e01fae9
c0214eb
2ec7283
316abad
274bc4f
83b3ffd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,9 @@ module ReactantCore | |
using ExpressionExplorer: ExpressionExplorer | ||
using MacroTools: MacroTools | ||
|
||
using Base.ScopedValues | ||
const enable_tracing = ScopedValue{Bool}(false) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we simply want to trace function calls everytime in |
||
|
||
export @trace, MissingTracedValue | ||
|
||
# Traits | ||
|
@@ -115,6 +118,7 @@ macro trace(expr) | |
return esc(trace_if_with_returns(__module__, expr)) | ||
end | ||
end | ||
Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr)) | ||
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr)) | ||
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr))) | ||
return error("Only `if-elseif-else` blocks are currently supported by `@trace`") | ||
|
@@ -180,7 +184,7 @@ function trace_for(mod, expr) | |
end | ||
|
||
return quote | ||
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...))) | ||
if $(enable_tracing)[] && $(any)($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...))) | ||
$(reactant_code_block) | ||
else | ||
$(expr) | ||
|
@@ -194,7 +198,7 @@ function trace_if_with_returns(mod, expr) | |
mod, expr.args[2]; store_last_line=expr.args[1], depth=1 | ||
) | ||
return quote | ||
if any($(is_traced), ($(all_check_vars...),)) | ||
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),)) | ||
$(new_expr) | ||
else | ||
$(expr) | ||
|
@@ -340,14 +344,26 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) | |
) | ||
|
||
return quote | ||
if any($(is_traced), ($(all_check_vars...),)) | ||
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),)) | ||
$(reactant_code_block) | ||
else | ||
$(original_expr) | ||
end | ||
end | ||
end | ||
|
||
function trace_call(mod, expr) | ||
f = expr.args[1] | ||
args = expr.args[2:end] | ||
return quote | ||
if $(enable_tracing)[] | ||
$(traced_call)($f, $(args...)) | ||
else | ||
$(expr) | ||
end | ||
end | ||
end | ||
|
||
function remove_shortcircuiting(expr) | ||
return MacroTools.prewalk(expr) do x | ||
if MacroTools.@capture(x, a_ && b_) | ||
|
@@ -371,6 +387,8 @@ function traced_while(cond_fn, body_fn, args) | |
return args | ||
end | ||
|
||
traced_call(f, args...; kwargs...) = f(args...; kwargs...) | ||
|
||
function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars) | ||
return MacroTools.postwalk(expr) do x | ||
if x isa Symbol && x ∈ all_vars | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ import ..Reactant: | |
TracedToConcrete, | ||
append_path, | ||
TracedType | ||
using Base.ScopedValues | ||
import ReactantCore: enable_tracing | ||
|
||
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) | ||
@inline traced_getfield( | ||
|
@@ -287,12 +289,16 @@ function compile_mlir(f, args; kwargs...) | |
end | ||
end | ||
|
||
const callcache = ScopedValue{Dict}() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like a simple way to store the cache. Fancier solutions (i.e. integrating with absint?) are probably possible as well but not sure whether it's worth it to investigate as of now. |
||
|
||
function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) | ||
fnwrapped, | ||
func2, traced_result, result, seen_args, ret, linear_args, in_tys, | ||
linear_results = MLIR.IR.mmodule!(mod) do | ||
MLIR.IR.block!(MLIR.IR.body(mod)) do | ||
return Reactant.make_mlir_fn(f, args, (), "main", true) | ||
with(enable_tracing=>true, callcache=>Dict()) do | ||
return Reactant.make_mlir_fn(f, args, (), "main", true) | ||
end | ||
end | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,6 +130,82 @@ function ReactantCore.traced_while( | |
end | ||
end | ||
|
||
function ReactantCore.traced_call(f, args...) | ||
# TODO: caching! | ||
cache_key = make_tracer( | ||
Reactant.OrderedIdDict(), | ||
(f, args...), | ||
(), | ||
CallCache; | ||
toscalar=false, | ||
track_numbers=(), # TODO: track_numbers? | ||
) | ||
|
||
if haskey(Reactant.Compiler.callcache[], cache_key) && false | ||
@info "Cache hit" | ||
else | ||
Reactant.Compiler.callcache[][cache_key] = nothing | ||
@warn Reactant.Compiler.callcache[] | ||
N = length(args) | ||
seen_args = Reactant.OrderedIdDict() | ||
traced_args = ntuple(N) do i | ||
return make_tracer( | ||
seen_args, | ||
args[i], | ||
(), | ||
TracedTrack; | ||
toscalar=false, | ||
track_numbers=(), | ||
) | ||
end | ||
linear_args = Reactant.MLIR.IR.Value[] | ||
for (k, v) in seen_args | ||
v isa TracedType || continue | ||
push!(linear_args, v.mlir_data) | ||
end | ||
end | ||
|
||
f_name = String(gensym(Symbol(f))) | ||
temp = Reactant.make_mlir_fn( | ||
f, | ||
args, | ||
(), | ||
f_name, | ||
false; | ||
no_args_in_result=true, | ||
) | ||
|
||
@warn temp | ||
|
||
|
||
traced_result, ret, linear_result = temp[[3, 6, 9]] | ||
|
||
call_op = MLIR.Dialects.func.call( | ||
linear_args; | ||
result_0=[MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)], | ||
callee=MLIR.IR.FlatSymbolRefAttribute(f_name), | ||
) | ||
|
||
seen_results = Reactant.OrderedIdDict() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I first copy a traced structure using |
||
traced_result = make_tracer( | ||
seen_results, | ||
traced_result, | ||
(), | ||
TracedSetPath; | ||
toscalar=false, | ||
track_numbers=(), | ||
) | ||
linear_results = TracedType[] | ||
i = 1 | ||
for (k, v) in seen_results | ||
v isa TracedType || continue | ||
v.mlir_data = MLIR.IR.result(call_op, i) | ||
i += 1 | ||
end | ||
|
||
return traced_result | ||
end | ||
|
||
function take_region(compiled_fn) | ||
region = MLIR.IR.Region() | ||
MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ScopedValues
where introduced in Julia 1.11, so how about adding the package? i think Base reimplements it and doesn't just reexport, so maybe put an@static if
based on Julia version?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're a step ahead of me, i just added the same comment myself. Is it fine for ReactantCore to depend on ScopedValues?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
upstream MLIR already depends on ScopedValues and it's a lightweight dependency, so it's fine from my side
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think a static if is needed as the ScopedValues package just reexports Base.ScopedValues