Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support short-circuiting expressions via trace macro #262

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ if no traced value is found inside the expression, then there is no overhead.
- `if` conditions (with `elseif` and other niceties) (`@trace if ...`)
- `if` statements with a preceeding assignment (`@trace a = if ...`) (note the positioning
of the macro needs to be before the assignment and not before the `if`)
- Short circuiting `@trace a && b` or `@trace a || b`

## Special Considerations

Expand Down Expand Up @@ -100,15 +101,51 @@ end
"""
macro trace(expr)
expr = macroexpand(__module__, expr)
if expr.head == :(=)
if expr.args[2] isa Expr && expr.args[2].head == :if
if Meta.isexpr(expr, :(&&), 2) || Meta.isexpr(expr, :(||), 2)
return esc(trace_short_circuit(__module__, expr))
end
if Meta.isexpr(expr, :(=))
if Meta.isexpr(expr.args[2], :if)
return esc(trace_if_with_returns(__module__, expr))
end
end
expr.head == :if && return esc(trace_if(__module__, expr))
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
end

function trace_short_circuit(mod, expr)
if_expr, lhs, varname = generate_if_from_short_circuit(mod, expr)
new_expr = trace_if(mod, if_expr)
return quote
$(varname) = $(lhs)
$(new_expr)
end
end

function generate_if_from_short_circuit(mod, expr; depth=0)
varname = gensym(:short_circuit_result)
lhs = expr.args[1]
rhs = expr.args[2]
if Meta.isexpr(rhs, :(&&), 2) || Meta.isexpr(rhs, :(||), 2)
rhs = generate_if_from_short_circuit(mod, rhs; depth=depth + 1)
end
if Meta.isexpr(expr, :(&&), 2)
expr = :(
if $(varname)
$(rhs)
end
)
else
expr = :(
if !$(varname)
$(rhs)
end
)
end
depth == 0 && return expr, lhs, varname
return :($(varname) = $(lhs); $(expr))
end

# ... = if ... style expressions
function trace_if_with_returns(mod, expr)
new_expr, _, all_check_vars = trace_if(
Expand All @@ -133,7 +170,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
counter = 0
expr = MacroTools.prewalk(expr) do x
counter += 1
if x isa Expr && x.head == :if && counter > 1
if Meta.isexpr(x, :if) && counter > 1
ex_new, dv, _ = trace_if(mod, x; store_last_line, depth=depth + 1)
append!(discard_vars_from_expansion, dv)
return ex_new
Expand Down Expand Up @@ -164,18 +201,24 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
all_true_branch_vars = true_branch_input_list ∪ true_branch_assignments
true_branch_fn_name = gensym(:true_branch)

else_block, discard_vars, _ = if length(expr.args) == 3
if expr.args[3].head != :elseif
expr.args[3], [], nothing
else_block, discard_vars, _, fake_assignments = if length(expr.args) == 3
if Meta.isexpr(expr.args[3], :elseif)
expr.args[3], [], nothing, :()
else
trace_if(mod, expr.args[3]; store_last_line, depth=depth + 1)
(trace_if(mod, expr.args[3]; store_last_line, depth=depth + 1)..., :())
end
elseif length(expr.args) == 2
tmp_expr = []
extra_assignments = []
for var in true_branch_assignments
push!(tmp_expr, :($(var) = $(var)))
push!(extra_assignments, :(
if !isdefined($(mod), $(Meta.quot(var)))
$(var) = nothing
end
))
end
Expr(:block, tmp_expr...), [], nothing
Expr(:block, tmp_expr...), [], nothing, Expr(:block, extra_assignments...)
else
dump(expr)
error("This shouldn't happen")
Expand Down Expand Up @@ -243,6 +286,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
false_branch_fn = :($(false_branch_fn_name) = $(false_branch_fn))

reactant_code_block = quote
$(fake_assignments)
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
Expand Down Expand Up @@ -296,10 +340,8 @@ end

function error_if_return(expr)
return MacroTools.postwalk(expr) do x
if x isa Expr && x.head == :return
error("Cannot use @trace on a block that contains a return statement")
end
return x
Meta.isexpr(x, :return) || return x
error("Cannot use @trace on a block that contains a return statement")
end
end

Expand Down
Loading