From bcb6034589a0360a4c914e5ed5d2f18807034e6c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 12:30:32 +0530 Subject: [PATCH] refactor: use Ops instead of direct stablehlo calls (#347) * refactor: use Ops instead of direct stablehlo calls * revert: restore Base.conj * fix: minor fixes to Ops * feat: add convert dispatches * revert: keep original `transpose_val` impl * revert: keep control flow needs Operation --- ext/ReactantAbstractFFTsExt.jl | 47 ++++------- src/Interpreter.jl | 36 +-------- src/Ops.jl | 62 +++++++++++++-- src/TracedRArray.jl | 141 +++++++-------------------------- src/TracedRNumber.jl | 123 +++++----------------------- 5 files changed, 124 insertions(+), 285 deletions(-) diff --git a/ext/ReactantAbstractFFTsExt.jl b/ext/ReactantAbstractFFTsExt.jl index 32a92fc1e..52f504f4a 100644 --- a/ext/ReactantAbstractFFTsExt.jl +++ b/ext/ReactantAbstractFFTsExt.jl @@ -1,7 +1,7 @@ module ReactantAbstractFFTsExt using AbstractFFTs: AbstractFFTs -using Reactant: Reactant, MLIR, TracedRArray +using Reactant: Reactant, MLIR, Ops, TracedRArray function check_contiguous_innermost_dims(dims, N) @assert sort([dims...]) == [dims...] "un-sorted dims are not supported" @@ -32,6 +32,7 @@ function compute_correct_pdims(x::AbstractArray, dims) end for op in (:rfft, :fft, :ifft) + mode = uppercase(string(op)) @eval function AbstractFFTs.$(op)(x::TracedRArray, dims) @assert maximum(dims) ≤ ndims(x) "dims out of range" if dims isa Integer @@ -41,7 +42,7 @@ for op in (:rfft, :fft, :ifft) AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims) ) end - return generalized_fft(x, $(Meta.quot(op)), nothing, 1) + return generalized_fft(x, $(mode), nothing, length(dims)) end if !check_contiguous_innermost_dims(dims, ndims(x)) pdims = compute_correct_pdims(x, dims) @@ -49,11 +50,12 @@ for op in (:rfft, :fft, :ifft) AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims) ) end - return generalized_fft(x, $(Meta.quot(op)), nothing, length(dims)) + return generalized_fft(x, $(mode), nothing, length(dims)) end end for op in (:irfft,) + mode = uppercase(string(op)) @eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims) @assert maximum(dims) ≤ ndims(x) "dims out of range" if dims isa Integer @@ -63,7 +65,7 @@ for op in (:irfft,) AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims) ) end - return generalized_fft(x, $(Meta.quot(op)), d, 1) + return generalized_fft(x, $(mode), d, length(dims)) end if !check_contiguous_innermost_dims(dims, ndims(x)) pdims = compute_correct_pdims(x, dims) @@ -71,41 +73,22 @@ for op in (:irfft,) AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims) ) end - return generalized_fft(x, $(Meta.quot(op)), d, length(dims)) + return generalized_fft(x, $(mode), d, length(dims)) end end -function generalized_fft(x::TracedRArray{T,N}, mode::Symbol, d, first_n::Int) where {T,N} - @assert mode ∈ (:rfft, :irfft, :fft, :ifft) - - x = permutedims(x, reverse(1:N)) - fft_type_str = uppercase(string(mode)) - fft_type = MLIR.API.stablehloFftTypeAttrGet(MLIR.IR.context(), fft_type_str) - +function generalized_fft(x::TracedRArray{T,N}, mode::String, d, first_n::Int) where {T,N} if d === nothing - @assert mode ∈ (:rfft, :fft, :ifft) - if mode == :rfft - @assert T <: Real - rT = Complex{T} - res_size = [size(x)[1:(end - 1)]..., size(x, N) ÷ 2 + 1] - else - @assert T <: Complex - rT = T - res_size = [size(x)...] 
-        end
-        fft_length = [size(x, i) for i in (ndims(x) - first_n + 1):ndims(x)]
+        @assert mode ∈ ("RFFT", "FFT", "IFFT")
+        fft_length = [size(x, i) for i in 1:first_n]
     else
-        @assert mode == :irfft
-        @assert T <: Complex
-        rT = real(T)
-        res_size = [size(x)[1:(end - 1)]..., d]
-        fft_length = [res_size[i] for i in (ndims(x) - first_n + 1):ndims(x)]
+        @assert mode == "IRFFT"
+        fft_length = [i == 1 ? d : size(x, i) for i in 1:first_n]
     end
-    @assert 1 ≤ length(fft_length) ≤ 3 "stablehlo.fft only supports up to rank 3"
-    mlir_type = MLIR.IR.TensorType(res_size, Reactant.MLIR.IR.Type(rT))
-    op = MLIR.Dialects.stablehlo.fft(x.mlir_data; fft_type, fft_length, result_0=mlir_type)
-    x = TracedRArray{rT,N}((), MLIR.IR.result(op, 1), Tuple(res_size))
+    x = permutedims(x, reverse(1:N))
+    reverse!(fft_length)
+    x = Ops.fft(x; type=mode, length=fft_length)
     return permutedims(x, reverse(1:N))
 end
diff --git a/src/Interpreter.jl b/src/Interpreter.jl
index 2efb53792..8a039e17a 100644
--- a/src/Interpreter.jl
+++ b/src/Interpreter.jl
@@ -233,17 +233,7 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
     predims = size(x.val)
     cval = MLIR.IR.result(
         MLIR.Dialects.stablehlo.concatenate(
-            [
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.reshape(
-                        v.mlir_data;
-                        result_0=MLIR.IR.TensorType(
-                            Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
-                        ),
-                    ),
-                ) for v in x.dval
-            ];
-            dimension=Int64(0),
+            [Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
         ),
     )
     tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
@@ -258,17 +248,7 @@ function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse)
     predims = size(x.val)
     cval = MLIR.IR.result(
         MLIR.Dialects.stablehlo.concatenate(
-            [
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.reshape(
-                        v.mlir_data;
-                        result_0=MLIR.IR.TensorType(
-                            Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
-                        ),
-                    ),
-                ) for v in x.dval
-            ];
-            dimension=Int64(0),
+            [Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
         ),
     )
     tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
@@ -502,22 +482,12 @@ function overload_autodiff(
                 for i in 1:width
                     sz = size(a)
                     starts = Int64[i]
-                    strides = Int64[1]
                     limits = Int64[i]
                     for v in sz
                         push!(starts, 0)
                         push!(limits, v)
-                        push!(strides, 1)
                     end
-                    sval = MLIR.IR.result(
-                        MLIR.Dialects.stablehlo.slice(
-                            sval;
-                            start_indices=starts,
-                            limit_indices=limits,
-                            stride_indices=strides,
-                        ),
-                        1,
-                    )
+                    sval = Ops.slice(sval, starts, limits)
                     set!(dresult[i], path[2:end], sval)
                 end
             end
diff --git a/src/Ops.jl b/src/Ops.jl
index 2f1c7f6eb..013e0dbc8 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -28,7 +28,7 @@ function constant(
 end
 
 function constant(x::ConcreteRArray; kwargs...)
-    return constant(convert(Array, x); kwargs...)
+    return constant(Base.convert(Array, x); kwargs...)
 end
 
 function constant(
@@ -42,7 +42,9 @@ function constant(
     x::ConcreteRNumber{T}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
 ) where {T}
     output = mlir_type(TracedRArray{T,0}, ())
-    value = MLIR.IR.DenseElementsAttribute(fill(MLIR.IR.Attribute(convert(T, x)), output))
+    value = MLIR.IR.DenseElementsAttribute(
+        fill(MLIR.IR.Attribute(Base.convert(T, x)), output)
+    )
     res = MLIR.IR.result(stablehlo.constant(; output, value, location))
     return TracedRNumber{T}((), res)
 end
@@ -458,10 +460,11 @@ function fft(
         Tout = Complex{T}
         rsize = let rsize = collect(size(x))
            rsize[end] = rsize[end] == 0 ? 
0 : rsize[end] ÷ 2 + 1 + Tuple(rsize) end elseif type == "IRFFT" @assert T <: Complex - Tout = real(T) + Tout = Base.real(T) rsize = let rsize = collect(size(x)) rsize[(end - Base.length(length) + 1):end] = length Tuple(rsize) @@ -514,7 +517,25 @@ function clamp( return TracedRArray{T,N}((), res, size(x)) end -function clamp(min::T, x::TracedRArray{T,N}, max::T) where {T,N} +function clamp( + min::TracedRNumber{T}, + x::TracedRNumber{T}, + max::TracedRNumber{T}; + location=mlir_stacktrace("clamp", @__FILE__, @__LINE__), +) where {T} + res = MLIR.IR.result( + stablehlo.clamp( + min.mlir_data, + x.mlir_data, + max.mlir_data; + result=mlir_type(TracedRArray{T,0}, ()), + location, + ), + ) + return TracedRNumber{T}((), res) +end + +function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N} return clamp(constant(min), x, constant(max)) end @@ -1033,7 +1054,7 @@ function compare( end res = MLIR.IR.result( - MLIR.Dialects.stablehlo.compare( + stablehlo.compare( lhs.mlir_data, rhs.mlir_data; comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( @@ -1048,6 +1069,37 @@ function compare( return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs)) end +# eltype conversion +function convert( + ::Type{TracedRArray{T,N}}, + x::TracedRArray; + location=mlir_stacktrace("convert", @__FILE__, @__LINE__), +) where {T,N} + @assert N == ndims(x) + return TracedRArray{T,N}( + (), + MLIR.IR.result( + stablehlo.convert( + x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location + ), + ), + size(x), + ) +end + +function convert( + ::Type{TracedRNumber{T}}, + x::TracedRNumber; + location=mlir_stacktrace("convert", @__FILE__, @__LINE__), +) where {T} + return TracedRNumber{T}( + (), + MLIR.IR.result( + stablehlo.convert(x.mlir_data; result=mlir_type(TracedRNumber{T}), location) + ), + ) +end + # Generate a unique name given a module hash and a function name. 
function _hlo_call_name(orig_name, module_suffix) return orig_name * "_hlo_call_" * module_suffix diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 0e9bf6f77..0d73b1093 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -36,35 +36,18 @@ ReactantCore.is_traced(::TracedRArray) = true new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) -function TracedRArray{T,N}(rhs::TracedRArray{T0,N}) where {T,T0,N} - if T == T0 - return rhs - else - return TracedRArray{T,N}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.convert( - rhs.mlir_data; result=mlir_type(TracedRArray{T,N}, size(rhs)) - ), - 1, - ), - size(rhs), - ) +function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} + @assert ndims(x) == N + if x isa TracedRArray + eltype(x) == T && return x + return Ops.convert(TracedRArray{T,N}, x) end + x isa WrappedTracedRArray && + return convert(TracedRArray{T,N}, materialize_traced_array(x)) + return convert(TracedRArray{T,N}, Ops.constant(collect(x))) end -function TracedRArray{T,N}(rhs::WrappedTracedRArray{T0,N}) where {T0,T,N} - return TracedRArray{T,N}(materialize_traced_array(rhs)) -end - -function TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N} - attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) - return TracedRArray{T,N}( - TracedRArray{T0,length(size(rhs))}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs) - ), - ) -end +TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x) materialize_traced_array(x::TracedRArray) = x materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] @@ -164,7 +147,6 @@ function Base.getindex( ), 1, ) - return TracedRNumber{T}((), res2) end @@ -254,9 +236,7 @@ Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, # TODO is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`? 
function Base.similar(::TracedRArray, ::Type{T}, dims::Dims{N}) where {T,N} - attr = MLIR.IR.DenseElementsAttribute(zeros(T, dims)) - res = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - return TracedRArray{T,N}((), res, dims) + return Ops.constant(zeros(T, dims)) end function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOContext}} @@ -266,69 +246,23 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC end function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N} - return TracedRArray{T,N}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.transpose( - get_mlir_data(A); - permutation=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in perm]), - ), - 1, - ), - Tuple(size(A, i) for i in perm), - ) + return Ops.transpose(materialize_traced_array(A), Int64[perm...]) end -Base.conj(A::TracedRArray) = A -function Base.conj(A::TracedRArray{T,N}) where {T<:Complex,N} - return TracedRArray{T,N}( - (), - MLIR.IR.result( - MLIR.Dialects.chlo.conj( - A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A)) - ), - 1, - ), - size(A), - ) -end +Base.conj(A::AnyTracedRArray) = A +Base.conj(A::AnyTracedRArray{<:Complex}) = Ops.conj(materialize_traced_array(A)) -Base.conj!(A::TracedRArray) = A -function Base.conj!(A::TracedRArray{T,N}) where {T<:Complex,N} - A.mlir_data = MLIR.IR.result( - MLIR.Dialects.chlo.conj(A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))), - 1, - ) +Base.conj!(A::AnyTracedRArray) = A +function Base.conj!(A::AnyTracedRArray{<:Complex}) + set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data) return A end -Base.real(A::TracedRArray) = A -function Base.real(A::TracedRArray{Complex{T},N}) where {T,N} - return TracedRArray{T,N}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.real( - A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A)) - ), - 1, - ), - size(A), - ) -end +Base.real(A::AnyTracedRArray) = A +Base.real(A::AnyTracedRArray{<:Complex}) = Ops.real(materialize_traced_array(A)) -Base.imag(A::TracedRArray) = zero(A) -function Base.imag(A::TracedRArray{Complex{T},N}) where {T,N} - return TracedRArray{T,N}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.imag( - A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A)) - ), - 1, - ), - size(A), - ) -end +Base.imag(A::AnyTracedRArray) = zero(A) +Base.imag(A::AnyTracedRArray{<:Complex}) = Ops.imag(materialize_traced_array(A)) promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) @@ -521,13 +455,7 @@ function Base.mapreduce( redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red))) if dims != (:) - red = MLIR.IR.result( - MLIR.Dialects.stablehlo.reshape( - red; result_0=MLIR.IR.TensorType(toonedims, eltype(MLIR.IR.type(red))) - ), - 1, - ) - red = TracedRArray{redT,length(toonedims)}((), red, (toonedims...,)) + red = Ops.reshape(TracedRArray(red), toonedims...) 
else if length(outdims) == 0 red = TracedRNumber{redT}((), red) @@ -633,27 +561,14 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T, return dest end -function broadcast_to_size(arg::AbstractArray, rsize) - attr = MLIR.IR.DenseElementsAttribute(arg) - len = ndims(arg) - @assert typeof(len) == Int - arg = TracedRArray{eltype(arg),len}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(arg) - ) - return broadcast_to_size(arg, rsize) -end +broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) function broadcast_to_size(arg::Base.RefValue, rsize) # XXX: don't we want to expand here to rsize? return arg end -function broadcast_to_size(arg::T, rsize) where {T<:Number} - attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize))) - return TracedRArray{T,length(rsize)}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize - ) -end +broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize))) function broadcast_to_size(arg::TracedRNumber, rsize) length(rsize) == 0 && return arg @@ -806,12 +721,12 @@ function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} end for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) - @eval function Base.clamp!(x::TracedRArray{T}, min::$(minT), max::$(maxT)) where {T} - y = clamp.(x, min, max) - x.mlir_data = y.mlir_data + @eval function Base.clamp!(x::AnyTracedRArray, min::$(minT), max::$(maxT)) + y = Ops.clamp(min, materialize_traced_array(x), max) + set_mlir_data!(x, y.mlir_data) return x end end -Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x) -Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x) +Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x) +Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index ebe733ce6..dc7a7ec2a 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -66,43 +66,15 @@ function TracedRNumber{T}(x::Number) where {T} end function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} - if isa(rhs, TracedRNumber) + if rhs isa TracedRNumber rhs isa TracedRNumber{T} && return rhs - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.convert( - rhs.mlir_data; result=mlir_type(TracedRNumber{T}) - ), - 1, - ), - ) + return Ops.convert(TracedRNumber{T}, rhs) end - if isa(rhs, TracedRArray{<:Any,0}) - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.convert( - rhs.mlir_data; result=mlir_type(TracedRNumber{T}) - ), - 1, - ), - ) + if rhs isa TracedRArray{<:Any,0} + return promote_to(TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data)) end - if isa(rhs, Number) - attr = MLIR.IR.DenseElementsAttribute(fill(T(rhs))) - return TracedRNumber{T}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - ) - end - T0 = eltype(rhs) - attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) - return promote_to( - TracedRNumber{T}, - TracedRNumber{T0}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - ), - ) + rhs isa Number && return promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) + return promote_to(TracedRNumber{T}, Ops.constant(collect(rhs))) end promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs) @@ -119,27 +91,14 @@ for (jlop, hloop) in ( @eval function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), 
@nospecialize(rhs::TracedRNumber{T}) ) where {T} - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1 - ), - ) + return Ops.$(hloop)(lhs, rhs) end end function Base.div( @nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown) ) where {T<:Integer} - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.divide( - lhs.mlir_data, promote_to(TracedRNumber{T}, rhs).mlir_data - ), - 1, - ), - ) + return Ops.divide(lhs, promote_to(TracedRNumber{T}, rhs)) end for (jlop, hloop, hlocomp) in ( @@ -207,22 +166,19 @@ function Base.ifelse( end for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) + T = promote_type(T1, T2) @eval begin function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) - return TracedRNumber{promote_type(eltype(x), eltype(y))}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) + return Ops.and( + promote_to(TracedRNumber{$(T)}, x), promote_to(TracedRNumber{$(T)}, y) ) end function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) - return TracedRNumber{promote_type(eltype(x), eltype(y))}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) - ) - end - function Base.:!(x::TracedRNumber{<:$(T1)}) - return TracedRNumber{eltype(x)}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) + return Ops.or( + promote_to(TracedRNumber{$(T)}, x), promote_to(TracedRNumber{$(T)}, y) ) end + Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x) end end @@ -241,64 +197,27 @@ for (jlop, hloop) in ( (:(Base.FastMath.tanh_fast), :tanh), (:(Base.exp), :exponential), (:(Base.FastMath.exp_fast), :exponential), + (:(Base.expm1), :exponential_minus_one), (:(Base.log), :log), + (:(Base.log1p), :log_plus_one), (:(Base.sqrt), :sqrt), (:(Base.ceil), :ceil), (:(Base.floor), :floor), ) - @eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} - OutTy = $(hloop === :abs) ? 
real(T) : T - return TracedRNumber{OutTy}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) - ) - end + @eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs) end Base.conj(x::TracedRNumber) = x -function Base.conj(x::TracedRNumber{T}) where {T<:Complex} - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.chlo.conj(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1 - ), - ) -end +Base.conj(x::TracedRNumber{<:Complex}) = Ops.conj(x) Base.real(x::TracedRNumber) = x -function Base.real(x::TracedRNumber{Complex{T}}) where {T} - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.real(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1 - ), - ) -end +Base.real(x::TracedRNumber{<:Complex}) = Ops.real(x) Base.imag(x::TracedRNumber) = zero(x) -function Base.imag(x::TracedRNumber{Complex{T}}) where {T} - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.imag(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1 - ), - ) -end - -Base.abs2(x::TracedRNumber) = abs(x)^2 - -Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T)) +Base.imag(x::TracedRNumber{<:Complex}) = Ops.imag(x) for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) - @eval function Base.clamp(x::TracedRNumber{T}, min::$(minT), max::$(maxT)) where {T} - min = promote_to(TracedRNumber{T}, min) - max = promote_to(TracedRNumber{T}, max) - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.clamp(min.mlir_data, x.mlir_data, max.mlir_data), 1 - ), - ) - end + @eval Base.clamp(x::TracedRNumber, min::$(minT), max::$(maxT)) = Ops.clamp(min, x, max) end struct TypeCast{T<:ReactantPrimitive} <: Function end
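-- 
Note (reviewer sketch, not part of the diff): after this refactor the Base
methods on traced types funnel through the `Reactant.Ops` wrappers instead of
hand-written `MLIR.Dialects.stablehlo.*` calls, so result types, shapes, and
locations are computed in one place. The AbstractFFTs extension likewise now
passes a mode string ("RFFT", "FFT", "IFFT", "IRFFT") plus the leading
transform lengths to `Ops.fft` and leaves the rank checks to it. Below is a
minimal sketch of the user-facing surface this touches, assuming a working
Reactant install; `f` and its input are hypothetical, chosen only to exercise
`Ops.conj`, `Ops.real`, and `Ops.clamp`:

    using Reactant

    x = Reactant.ConcreteRArray(rand(ComplexF32, 4, 4))

    function f(x)
        y = conj(x)                     # lowers via Ops.conj (chlo.conj)
        z = real(y)                     # lowers via Ops.real (stablehlo.real)
        return clamp.(z, 0.0f0, 1.0f0)  # lowers via Ops.clamp (stablehlo.clamp)
    end

    f_compiled = Reactant.@compile f(x)
    f_compiled(x)  # ConcreteRArray{Float32,2} with entries clamped to [0, 1]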