Skip to content

Commit

Permalink
refactor: rework traced_if to avoid polluting make_mlir_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 9d666f8 commit 3006ae5
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 92 deletions.
9 changes: 6 additions & 3 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ using MacroTools: MacroTools

export @trace, MissingTracedValue

const SPECIAL_SYMBOLS = [
:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
]

# Traits
is_traced(x) = false

Expand Down Expand Up @@ -79,7 +83,8 @@ You need to ensure that all branches have the same type.
### Certain Symbols are Reserved
Symbols like `nothing`, `missing` and `:` are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions.
While certain cases might work but these are not guaranteed to work. For
example, the following will not work:
```julia
Expand Down Expand Up @@ -299,6 +304,4 @@ function error_if_return(expr)
end
end

const SPECIAL_SYMBOLS = [:(:), :nothing, :missing]

end
99 changes: 55 additions & 44 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
@@ -1,83 +1,94 @@
function ReactantCore.traced_if(
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
) where {TFn,FFn}
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.make_mlir_fn(
true_fn,
args,
(),
string(gensym("true_branch")),
false;
return_dialect=:stablehlo,
no_args_in_result=true,
construct_function_without_args=true,
)
function ReactantCore.traced_if(cond::TracedRNumber{Bool}, true_fn, false_fn, args)
# NOTE: This behavior is different from how we compile other functions, i.e., we keep
# things as constants if possible, but from a block we do need to return a
# traced value, so we force a conversion to a TracedType.
# XXX: Eventually we would want to support nested structures as block arguments
# but we will have to do a flatten/unflatten pass to make this work.
args_traced = map(args) do arg
arg isa TracedType && return arg
arg isa Number && return promote_to(TracedRNumber{eltype(arg)}, arg)
arg isa AbstractArray &&
return promote_to(TracedRArray{eltype(arg),ndims(arg)}, arg)
@warn "Argument $(arg) is not a TracedType, TracedRNumber, or TracedRArray. It \
will be promoted to a TracedRNumber. Please open an issue in Reactant.jl \
with an example of this behavior."
return arg
end

(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.make_mlir_fn(
false_fn,
args,
(),
string(gensym("false_branch")),
false;
return_dialect=:stablehlo,
no_args_in_result=true,
construct_function_without_args=true,
)
true_block = MLIR.IR.Block()
true_res = MLIR.IR.block!(true_block) do
results = map(true_fn(args_traced...)) do r
r isa TracedType && return r
r isa Number && return promote_to(TracedRNumber{typeof(r)}, r)
r isa AbstractArray &&
return promote_to(TracedRArray{eltype(r),ndims(r)}, r)
error("Unsupported return type $(typeof(r))")
end
MLIR.Dialects.stablehlo.return_([x.mlir_data for x in results])
return results
end

@assert length(true_branch_results) == length(false_branch_results) "true branch returned $(length(true_branch_results)) results, false branch returned $(length(false_branch_results)). This shouldn't happen."
false_block = MLIR.IR.Block()
false_res = MLIR.IR.block!(false_block) do
results = map(false_fn(args_traced...)) do r
r isa TracedType && return r
r isa Number && return promote_to(TracedRNumber{typeof(r)}, r)
r isa AbstractArray &&
return promote_to(TracedRArray{eltype(r),ndims(r)}, r)
error("Unsupported return type $(typeof(r))")
end
MLIR.Dialects.stablehlo.return_([x.mlir_data for x in results])
return results
end

@assert length(true_res) == length(false_res) "true branch returned $(length(true_res)) results, false branch returned $(length(false_res)). This shouldn't happen."

result_types = MLIR.IR.Type[]
linear_results = []
true_block_insertions = []
false_block_insertions = []
for (i, (tr, fr)) in enumerate(zip(true_branch_results, false_branch_results))
for (i, (tr, fr)) in enumerate(zip(true_res, false_res))
if typeof(tr) != typeof(fr)
if !(tr isa MissingTracedValue) && !(fr isa MissingTracedValue)
error("Result #$(i) for the branches have different types: true branch \
returned `$(typeof(tr))`, false branch returned `$(typeof(fr))`.")
elseif tr isa MissingTracedValue
push!(result_types, MLIR.IR.type(fr.mlir_data))
push!(linear_results, new_traced_value(false_linear_results[i]))
push!(true_block_insertions, (i => linear_results[end]))
push!(true_block_insertions, (i => new_traced_value(false_res[i])))
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, new_traced_value(true_linear_results[i]))
push!(false_block_insertions, (i => linear_results[end]))
push!(false_block_insertions, (i => new_traced_value(true_res[i])))
end
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, new_traced_value(tr))
end
end

# Replace all uses of missing values with the correct values
true_branch_region = get_region_removing_missing_values(
true_branch_compiled, true_block_insertions
true_block, true_block_insertions
)

false_branch_region = get_region_removing_missing_values(
false_branch_compiled, false_block_insertions
false_block, false_block_insertions
)

MLIR.IR.rmfromparent!(true_branch_compiled)
MLIR.IR.rmfromparent!(false_branch_compiled)

if_compiled = MLIR.Dialects.stablehlo.if_(
cond.mlir_data;
true_branch=true_branch_region,
false_branch=false_branch_region,
result_0=result_types,
)

return map(enumerate(linear_results)) do (i, res)
res.mlir_data = MLIR.IR.result(if_compiled, i)
return res
return map(1:MLIR.IR.nresults(if_compiled)) do i
res = MLIR.IR.result(if_compiled, i)
sz = size(MLIR.IR.type(res))
T = MLIR.IR.julia_type(eltype(MLIR.IR.type(res)))
isempty(sz) && return TracedRNumber{T}((), res)
return TracedRArray{T,length(sz)}((), res, sz)
end
end

function get_region_removing_missing_values(compiled_fn, insertions)
function get_region_removing_missing_values(block, insertions)
region = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0))
block = MLIR.IR.Block(MLIR.API.mlirRegionGetFirstBlock(region), false)
push!(region, block)
return_op = MLIR.IR.terminator(block)
for (i, rt) in insertions
if rt isa TracedRNumber
Expand Down
5 changes: 5 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ include("ConcreteRArray.jl")
include("TracedRNumber.jl")
include("TracedRArray.jl")

function Base.getproperty(x::MissingTracedValue, s::Symbol)
s === :mlir_data && return broadcast_to_size(false, ()).mlir_data
return getfield(x, s)
end

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

include("ControlFlow.jl")
Expand Down
54 changes: 9 additions & 45 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,13 @@ function apply(f, args...; kwargs...)
end

function make_mlir_fn(
f,
args,
kwargs,
name="main",
concretein=true;
toscalar=false,
return_dialect=:func,
no_args_in_result::Bool=false,
construct_function_without_args::Bool=false,
f, args, kwargs, name="main", concretein=true; toscalar=false, return_dialect=:func
)
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
return (
true,
make_mlir_fn(
apply,
(f, args...),
kwargs,
name,
concretein;
toscalar,
return_dialect,
no_args_in_result,
construct_function_without_args,
apply, (f, args...), kwargs, name, concretein; toscalar, return_dialect
)[2:end]...,
)
end
Expand All @@ -70,7 +54,6 @@ function make_mlir_fn(
(:args, i),
concretein ? ConcreteToTraced : TracedSetPath;
toscalar,
track_numbers=construct_function_without_args ? (Number,) : (),
)
end

Expand Down Expand Up @@ -100,24 +83,16 @@ function make_mlir_fn(
)
end

if construct_function_without_args
fnbody = MLIR.IR.Block()
else
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
end
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
push!(MLIR.IR.region(func, 1), fnbody)

@assert MLIR.IR._has_block()

result = MLIR.IR.block!(fnbody) do
for (i, arg) in enumerate(linear_args)
if construct_function_without_args
arg.mlir_data = args[i].mlir_data
else
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = transpose_val(raw_arg)
arg.mlir_data = row_maj_arg
end
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = transpose_val(raw_arg)
arg.mlir_data = row_maj_arg
end

# NOTE an `AbstractInterpreter` cannot process methods with more recent world-ages than it
Expand Down Expand Up @@ -153,11 +128,7 @@ function make_mlir_fn(
seen_results = OrderedIdDict()

traced_result = make_tracer(
seen_results,
result,
(:result,),
concretein ? TracedTrack : TracedSetPath;
track_numbers=construct_function_without_args ? (Number,) : (),
seen_results, result, (:result,), concretein ? TracedTrack : TracedSetPath
)

# marks buffers to be donated
Expand All @@ -171,7 +142,6 @@ function make_mlir_fn(

for (k, v) in seen_results
v isa TracedType || continue
(no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue
push!(linear_results, v)
end

Expand All @@ -180,16 +150,10 @@ function make_mlir_fn(
ret = MLIR.IR.block!(fnbody) do
vals = MLIR.IR.Value[]
for res in linear_results
if res isa MissingTracedValue
col_maj = broadcast_to_size(false, ()).mlir_data
elseif construct_function_without_args
col_maj = res.mlir_data
else
col_maj = transpose_val(res.mlir_data)
end
col_maj = transpose_val(res.mlir_data)
push!(vals, col_maj)
end
!no_args_in_result && @assert length(vals) == length(linear_results)
@assert length(vals) == length(linear_results)

dialect = getfield(MLIR.Dialects, return_dialect)
return dialect.return_(vals)
Expand Down

0 comments on commit 3006ae5

Please sign in to comment.