Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP @trace function calls #366

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
24 changes: 21 additions & 3 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

using Base.ScopedValues
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ScopedValues where introduced in Julia 1.11, so how about adding the package? i think Base reimplements it and doesn't just reexport, so maybe put an @static if based on Julia version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're a step ahead of me, i just added the same comment myself. Is it fine for ReactantCore to depend on ScopedValues?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

upstream MLIR already depends on ScopedValues and it's a lightweight dependency, so it's fine from my side

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think a static if is needed as the ScopedValues package just reexports Base.ScopedValues

const enable_tracing = ScopedValue{Bool}(false)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we simply want to trace function calls everytime in make_mlir_fn. I added this scopedvalue that is set in the beginning so it's just a simple check.
If it's okay to use ScopedValues in ReactantCore, ScopedValues probably needs to be added to project.toml for backwards compatibility (?)


export @trace, MissingTracedValue

# Traits
Expand Down Expand Up @@ -115,6 +118,7 @@ macro trace(expr)
return esc(trace_if_with_returns(__module__, expr))
end
end
Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr))
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr)))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
Expand Down Expand Up @@ -180,7 +184,7 @@ function trace_for(mod, expr)
end

return quote
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
if $(enable_tracing)[] && $(any)($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
$(reactant_code_block)
else
$(expr)
Expand All @@ -194,7 +198,7 @@ function trace_if_with_returns(mod, expr)
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
)
return quote
if any($(is_traced), ($(all_check_vars...),))
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),))
$(new_expr)
else
$(expr)
Expand Down Expand Up @@ -340,14 +344,26 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
)

return quote
if any($(is_traced), ($(all_check_vars...),))
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
else
$(original_expr)
end
end
end

function trace_call(mod, expr)
f = expr.args[1]
args = expr.args[2:end]
return quote
if $(enable_tracing)[]
$(traced_call)($f, $(args...))
else
$(expr)
end
end
end

function remove_shortcircuiting(expr)
return MacroTools.prewalk(expr) do x
if MacroTools.@capture(x, a_ && b_)
Expand All @@ -371,6 +387,8 @@ function traced_while(cond_fn, body_fn, args)
return args
end

traced_call(f, args...; kwargs...) = f(args...; kwargs...)

function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
return MacroTools.postwalk(expr) do x
if x isa Symbol && x ∈ all_vars
Expand Down
8 changes: 7 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import ..Reactant:
TracedToConcrete,
append_path,
TracedType
using Base.ScopedValues
import ReactantCore: enable_tracing

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
Expand Down Expand Up @@ -287,12 +289,16 @@ function compile_mlir(f, args; kwargs...)
end
end

const callcache = ScopedValue{Dict}()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a simple way to store the cache. Fancier solutions (i.e. integrating with absint?) are probably possible as well but not sure whether it's worth it to investigate as of now.


function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
return Reactant.make_mlir_fn(f, args, (), "main", true)
with(enable_tracing=>true, callcache=>Dict()) do
return Reactant.make_mlir_fn(f, args, (), "main", true)
end
end
end

Expand Down
76 changes: 76 additions & 0 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,82 @@ function ReactantCore.traced_while(
end
end

function ReactantCore.traced_call(f, args...)
# TODO: caching!
cache_key = make_tracer(
Reactant.OrderedIdDict(),
(f, args...),
(),
CallCache;
toscalar=false,
track_numbers=(), # TODO: track_numbers?
)

if haskey(Reactant.Compiler.callcache[], cache_key) && false
@info "Cache hit"
else
Reactant.Compiler.callcache[][cache_key] = nothing
@warn Reactant.Compiler.callcache[]
N = length(args)
seen_args = Reactant.OrderedIdDict()
traced_args = ntuple(N) do i
return make_tracer(
seen_args,
args[i],
(),
TracedTrack;
toscalar=false,
track_numbers=(),
)
end
linear_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_args
v isa TracedType || continue
push!(linear_args, v.mlir_data)
end
end

f_name = String(gensym(Symbol(f)))
temp = Reactant.make_mlir_fn(
f,
args,
(),
f_name,
false;
no_args_in_result=true,
)

@warn temp


traced_result, ret, linear_result = temp[[3, 6, 9]]

call_op = MLIR.Dialects.func.call(
linear_args;
result_0=[MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)],
callee=MLIR.IR.FlatSymbolRefAttribute(f_name),
)

seen_results = Reactant.OrderedIdDict()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I first copy a traced structure using make_tracer, TracedSetPath. This structure can I then fill in by going by mutating the objects in seen. This seems to work well but maybe someone has a better approach.

traced_result = make_tracer(
seen_results,
traced_result,
(),
TracedSetPath;
toscalar=false,
track_numbers=(),
)
linear_results = TracedType[]
i = 1
for (k, v) in seen_results
v isa TracedType || continue
v.mlir_data = MLIR.IR.result(call_op, i)
i += 1
end

return traced_result
end

function take_region(compiled_fn)
region = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0))
Expand Down
4 changes: 4 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TracedToConcrete = 3
ArrayToConcrete = 4
TracedSetPath = 5
CallCache = 6
end

for T in (
Expand Down Expand Up @@ -382,6 +383,9 @@ function make_tracer(
if mode == ConcreteToTraced
throw("Cannot trace existing trace type")
end
if mode == CallCache
return MLIR.IR.type(prev.mlir_data)
end
if mode == TracedTrack
prev.paths = (prev.paths..., path)
if !haskey(seen, prev)
Expand Down
Loading