Skip to content

Commit

Permalink
refactor: use Ops instead of direct stablehlo calls (#347)
Browse files Browse the repository at this point in the history
* refactor: use Ops instead of direct stablehlo calls

* revert: restore Base.conj

* fix: minor fixes to Ops

* feat: add convert dispatches

* revert: keep original `transpose_val` impl

* revert: keep control flow needs Operation
  • Loading branch information
avik-pal authored Dec 11, 2024
1 parent c4a9ae3 commit bcb6034
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 285 deletions.
47 changes: 15 additions & 32 deletions ext/ReactantAbstractFFTsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module ReactantAbstractFFTsExt

using AbstractFFTs: AbstractFFTs
using Reactant: Reactant, MLIR, TracedRArray
using Reactant: Reactant, MLIR, Ops, TracedRArray

function check_contiguous_innermost_dims(dims, N)
@assert sort([dims...]) == [dims...] "un-sorted dims are not supported"
Expand Down Expand Up @@ -32,6 +32,7 @@ function compute_correct_pdims(x::AbstractArray, dims)
end

for op in (:rfft, :fft, :ifft)
mode = uppercase(string(op))
@eval function AbstractFFTs.$(op)(x::TracedRArray, dims)
@assert maximum(dims) ndims(x) "dims out of range"
if dims isa Integer
Expand All @@ -41,19 +42,20 @@ for op in (:rfft, :fft, :ifft)
AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims)
)
end
return generalized_fft(x, $(Meta.quot(op)), nothing, 1)
return generalized_fft(x, $(mode), nothing, length(dims))
end
if !check_contiguous_innermost_dims(dims, ndims(x))
pdims = compute_correct_pdims(x, dims)
return permutedims(
AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims)
)
end
return generalized_fft(x, $(Meta.quot(op)), nothing, length(dims))
return generalized_fft(x, $(mode), nothing, length(dims))
end
end

for op in (:irfft,)
mode = uppercase(string(op))
@eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims)
@assert maximum(dims) ndims(x) "dims out of range"
if dims isa Integer
Expand All @@ -63,49 +65,30 @@ for op in (:irfft,)
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims)
)
end
return generalized_fft(x, $(Meta.quot(op)), d, 1)
return generalized_fft(x, $(mode), d, length(dims))
end
if !check_contiguous_innermost_dims(dims, ndims(x))
pdims = compute_correct_pdims(x, dims)
return permutedims(
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims)
)
end
return generalized_fft(x, $(Meta.quot(op)), d, length(dims))
return generalized_fft(x, $(mode), d, length(dims))
end
end

function generalized_fft(x::TracedRArray{T,N}, mode::Symbol, d, first_n::Int) where {T,N}
@assert mode (:rfft, :irfft, :fft, :ifft)

x = permutedims(x, reverse(1:N))
fft_type_str = uppercase(string(mode))
fft_type = MLIR.API.stablehloFftTypeAttrGet(MLIR.IR.context(), fft_type_str)

function generalized_fft(x::TracedRArray{T,N}, mode::String, d, first_n::Int) where {T,N}
if d === nothing
@assert mode (:rfft, :fft, :ifft)
if mode == :rfft
@assert T <: Real
rT = Complex{T}
res_size = [size(x)[1:(end - 1)]..., size(x, N) ÷ 2 + 1]
else
@assert T <: Complex
rT = T
res_size = [size(x)...]
end
fft_length = [size(x, i) for i in (ndims(x) - first_n + 1):ndims(x)]
@assert mode ("RFFT", "FFT", "IFFT")
fft_length = [size(x, i) for i in 1:first_n]
else
@assert mode == :irfft
@assert T <: Complex
rT = real(T)
res_size = [size(x)[1:(end - 1)]..., d]
fft_length = [res_size[i] for i in (ndims(x) - first_n + 1):ndims(x)]
@assert mode == "IRFFT"
fft_length = [i == 1 ? d : size(x, i) for i in 1:first_n]
end

@assert 1 length(fft_length) 3 "stablehlo.fft only supports up to rank 3"
mlir_type = MLIR.IR.TensorType(res_size, Reactant.MLIR.IR.Type(rT))
op = MLIR.Dialects.stablehlo.fft(x.mlir_data; fft_type, fft_length, result_0=mlir_type)
x = TracedRArray{rT,N}((), MLIR.IR.result(op, 1), Tuple(res_size))
x = permutedims(x, reverse(1:N))
reverse!(fft_length)
x = Ops.fft(x; type=mode, length=fft_length)
return permutedims(x, reverse(1:N))
end

Expand Down
36 changes: 3 additions & 33 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,7 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
predims = size(x.val)
cval = MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
[
MLIR.IR.result(
MLIR.Dialects.stablehlo.reshape(
v.mlir_data;
result_0=MLIR.IR.TensorType(
Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
),
),
) for v in x.dval
];
dimension=Int64(0),
[Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
),
)
tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
Expand All @@ -258,17 +248,7 @@ function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse)
predims = size(x.val)
cval = MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
[
MLIR.IR.result(
MLIR.Dialects.stablehlo.reshape(
v.mlir_data;
result_0=MLIR.IR.TensorType(
Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
),
),
) for v in x.dval
];
dimension=Int64(0),
[Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
),
)
tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
Expand Down Expand Up @@ -502,22 +482,12 @@ function overload_autodiff(
for i in 1:width
sz = size(a)
starts = Int64[i]
strides = Int64[1]
limits = Int64[i]
for v in sz
push!(starts, 0)
push!(limits, v)
push!(strides, 1)
end
sval = MLIR.IR.result(
MLIR.Dialects.stablehlo.slice(
sval;
start_indices=starts,
limit_indices=limits,
stride_indices=strides,
),
1,
)
sval = Ops.slice(sval, starts, limits)
set!(dresult[i], path[2:end], sval)
end
end
Expand Down
62 changes: 57 additions & 5 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function constant(
end

function constant(x::ConcreteRArray; kwargs...)
return stablehlo.constant(convert(Array, x); kwargs...)
return stablehlo.constant(Base.convert(Array, x); kwargs...)
end

function constant(
Expand All @@ -42,7 +42,9 @@ function constant(
x::ConcreteRNumber{T}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
) where {T}
output = mlir_type(TracedRArray{T,0}, ())
value = MLIR.IR.DenseElementsAttribute(fill(MLIR.IR.Attribute(convert(T, x)), output))
value = MLIR.IR.DenseElementsAttribute(
fill(MLIR.IR.Attribute(Base.convert(T, x)), output)
)
res = MLIR.IR.result(stablehlo.constant(; output, value, location))
return TracedRNumber{T,N}((), res)
end
Expand Down Expand Up @@ -458,10 +460,11 @@ function fft(
Tout = Complex{T}
rsize = let rsize = collect(size(x))
rsize[end] = rsize[end] == 0 ? 0 : rsize[end] ÷ 2 + 1
Tuple(rsize)
end
elseif type == "IRFFT"
@assert T <: Complex
Tout = real(T)
Tout = Base.real(T)
rsize = let rsize = collect(size(x))
rsize[(end - Base.length(length) + 1):end] = length
Tuple(rsize)
Expand Down Expand Up @@ -514,7 +517,25 @@ function clamp(
return TracedRArray{T,N}((), res, size(x))
end

function clamp(min::T, x::TracedRArray{T,N}, max::T) where {T,N}
function clamp(
min::TracedRNumber{T},
x::TracedRNumber{T},
max::TracedRNumber{T};
location=mlir_stacktrace("clamp", @__FILE__, @__LINE__),
) where {T}
res = MLIR.IR.result(
stablehlo.clamp(
min.mlir_data,
x.mlir_data,
max.mlir_data;
result=mlir_type(TracedRArray{T,0}, ()),
location,
),
)
return TracedRNumber{T}((), res)
end

function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N}
return clamp(constant(min), x, constant(max))
end

Expand Down Expand Up @@ -1033,7 +1054,7 @@ function compare(
end

res = MLIR.IR.result(
MLIR.Dialects.stablehlo.compare(
stablehlo.compare(
lhs.mlir_data,
rhs.mlir_data;
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
Expand All @@ -1048,6 +1069,37 @@ function compare(
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
end

# eltype conversion
function convert(
::Type{TracedRArray{T,N}},
x::TracedRArray;
location=mlir_stacktrace("convert", @__FILE__, @__LINE__),
) where {T,N}
@assert N == ndims(x)
return TracedRArray{T,N}(
(),
MLIR.IR.result(
stablehlo.convert(
x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location
),
),
size(x),
)
end

function convert(
::Type{TracedRNumber{T}},
x::TracedRNumber;
location=mlir_stacktrace("convert", @__FILE__, @__LINE__),
) where {T}
return TracedRNumber{T}(
(),
MLIR.IR.result(
stablehlo.convert(x.mlir_data; result=mlir_type(TracedRNumber{T}), location)
),
)
end

# Generate a unique name given a module hash and a function name.
function _hlo_call_name(orig_name, module_suffix)
return orig_name * "_hlo_call_" * module_suffix
Expand Down
Loading

0 comments on commit bcb6034

Please sign in to comment.