Skip to content

Commit

Permalink
feat: add convert dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 9, 2024
1 parent ef46a92 commit 2c69cf6
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,18 @@ ReactantCore.is_traced(::TracedRArray) = true

new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A))

TracedRArray{T,N}(rhs::TracedRArray{T,N}) where {T,N} = rhs
function TracedRArray{T,N}(rhs::TracedRArray{T0,N}) where {T,T0,N}
return Ops.convert(TracedRArray{T,N}, rhs)
end

function TracedRArray{T,N}(rhs::WrappedTracedRArray{T0,N}) where {T0,T,N}
return TracedRArray{T,N}(materialize_traced_array(rhs))
function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N}
@assert ndims(x) == N
if x isa TracedRArray
eltype(x) == T && return x
return Ops.convert(TracedRArray{T,N}, x)
end
x isa WrappedTracedRArray &&
return convert(TracedRArray{T,N}, materialize_traced_array(x))
return convert(TracedRArray{T,N}, Ops.constant(collect(x)))
end

TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N} = Ops.constant(collect(rhs))
TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x)

materialize_traced_array(x::TracedRArray) = x
materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
Expand Down

0 comments on commit 2c69cf6

Please sign in to comment.