Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Dec 11, 2024
1 parent 17c2f72 commit 4807a79
Showing 1 changed file with 83 additions and 58 deletions.
141 changes: 83 additions & 58 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,17 @@ function call_with_reactant_generator(

match = matches[1]::Core.MethodMatch
# look up the method and code instance
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any), match.method, match.spec_types, match.sparams)

mi = ccall(
:jl_specializations_get_linfo,
Ref{Core.MethodInstance},
(Any, Any, Any),
match.method,
match.spec_types,
match.sparams,
)

result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp))
frame = Core.Compiler.InferenceState(result, #=cache_mode=#:local, interp)
frame = Core.Compiler.InferenceState(result, :local, interp) #=cache_mode=#
@assert frame !== nothing
Core.Compiler.typeinf(interp, frame)
@static if VERSION >= v"1.11"
Expand All @@ -400,45 +406,45 @@ function call_with_reactant_generator(
# rt = frame.result.result::Core.Compiler.Const
# src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
#else
opt = Core.Compiler.OptimizationState(frame, interp)

caller = frame.result
@static if VERSION < v"1.11-"
ir = Core.Compiler.run_passes(opt.src, opt, caller)
else
ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller)
Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller)
end

# Rewrite type unstable calls to recurse into call_with_reactant to ensure
# they continue to use our interpreter. Reset the derived return type
# to Any if our interpreter would change the return type of any result.
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
# screws up type inference after this (TODO this should be fixed).
any_changed = false
for (i, inst) in enumerate(ir.stmts)
@static if VERSION < v"1.11"
changed, next = rewrite_inst(inst[:inst], ir)
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next = rewrite_inst(inst[:stmt], ir)
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
end
if changed
any_changed = true
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
end
end
Core.Compiler.finish(interp, opt, ir, caller)
src = Core.Compiler.ir_to_codeinf!(opt)
# Julia hits various internal errors trying to re-perform type inference
# on type infered code (that we undo inference of), if there is no type unstable
# code to be rewritten, just use the default methodinstance (still using our methodtable),
# to improve compatibility as these bugs are fixed upstream.
if !any_changed
src = Core.Compiler.retrieve_code_info(mi, world)
end
opt = Core.Compiler.OptimizationState(frame, interp)

caller = frame.result
@static if VERSION < v"1.11-"
ir = Core.Compiler.run_passes(opt.src, opt, caller)
else
ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller)
Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller)
end

# Rewrite type unstable calls to recurse into call_with_reactant to ensure
# they continue to use our interpreter. Reset the derived return type
# to Any if our interpreter would change the return type of any result.
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
# screws up type inference after this (TODO this should be fixed).
any_changed = false
for (i, inst) in enumerate(ir.stmts)
@static if VERSION < v"1.11"
changed, next = rewrite_inst(inst[:inst], ir)
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next = rewrite_inst(inst[:stmt], ir)
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
end
if changed
any_changed = true
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
end
end
Core.Compiler.finish(interp, opt, ir, caller)
src = Core.Compiler.ir_to_codeinf!(opt)

# Julia hits various internal errors trying to re-perform type inference
# on type infered code (that we undo inference of), if there is no type unstable
# code to be rewritten, just use the default methodinstance (still using our methodtable),
# to improve compatibility as these bugs are fixed upstream.
if !any_changed
src = Core.Compiler.retrieve_code_info(mi, world)
end

# prepare a new code info
code_info = copy(src)
Expand All @@ -454,7 +460,9 @@ function call_with_reactant_generator(

# Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant,
# and the REDUB_ARGUMENTS_NAME tuple of input arguments
code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...]
code_info.slotnames = Any[
:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...
]
code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...]
n_prepended_slots = 2
overdub_args_slot = Core.SlotNumber(n_prepended_slots)
Expand All @@ -464,7 +472,6 @@ function call_with_reactant_generator(
# the end of the pass, we'll reset `code_info` fields accordingly.
overdubbed_code = Any[]
overdubbed_codelocs = Int32[]

# Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention
# required by the base method.

Expand All @@ -481,14 +488,16 @@ function call_with_reactant_generator(
offset += 1
end
slot = i + n_prepended_slots
actual_argument = Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset)
actual_argument = Expr(
:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset
)
push!(overdubbed_code, :($(Core.SlotNumber(slot)) = $actual_argument))
push!(overdubbed_codelocs, code_info.codelocs[1])
code_info.slotflags[slot] |= 0x02 # ensure this slotflag has the "assigned" bit set
offset += 1
#push!(overdubbed_code, actual_argument)
push!(fn_args, Core.SSAValue(length(overdubbed_code)))

#push!(overdubbed_code, actual_argument)
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
end

# If `method` is a varargs method, we have to restructure the original method call's
Expand All @@ -497,26 +506,42 @@ function call_with_reactant_generator(
if !isempty(overdubbed_code)
# remove the final slot reassignment leftover from the previous destructuring
pop!(overdubbed_code)
pop!(overdubbed_codelocs)
pop!(fn_args)
pop!(overdubbed_codelocs)
pop!(fn_args)
end
trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple))
for i in n_method_args:n_actual_args
push!(overdubbed_code, Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1))
push!(
overdubbed_code,
Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1),
)
push!(overdubbed_codelocs, code_info.codelocs[1])
push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code)))
offset += 1
end
push!(overdubbed_code, Expr(:(=), Core.SlotNumber(n_method_args + n_prepended_slots), trailing_arguments))
push!(overdubbed_codelocs, code_info.codelocs[1])
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
push!(
overdubbed_code,
Expr(
:(=), Core.SlotNumber(n_method_args + n_prepended_slots), trailing_arguments
),
)
push!(overdubbed_codelocs, code_info.codelocs[1])
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
end

# substitute static parameters, offset slot numbers by number of added slots, and
# offset statement indices by the number of additional statements

arg_partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...],
n_prepended_slots, n_prepended_slots, length(overdubbed_code), :propagate)
arg_partially_inline!(
code_info.code,
fn_args,
method.sig,
Any[static_params...],
n_prepended_slots,
n_prepended_slots,
length(overdubbed_code),
:propagate,
)

append!(overdubbed_code, code_info.code)
append!(overdubbed_codelocs, code_info.codelocs)
Expand All @@ -537,7 +562,7 @@ end

@eval function call_with_reactant($OVERDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, call_with_reactant_generator))
return $(Expr(:meta, :generated, call_with_reactant_generator))
end

function make_mlir_fn(
Expand Down

0 comments on commit 4807a79

Please sign in to comment.