From 262d3bcfe63c21a1483001d57c00b317a13e4356 Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Thu, 5 May 2022 16:32:09 -0400 Subject: [PATCH 1/2] add complex partial arithmetic, retrieving partials of complex wrapping dual --- src/dual.jl | 5 +++++ src/partials.jl | 16 ++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index c3dc48a0..cbe38153 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -112,6 +112,9 @@ end end end +@inline partials(z::Complex{TD}) where {TD <: Dual} = real(z).partials + im*imag(z).partials +@inline Base.@propagate_inbounds partials(z::Complex{TD}, i) where {TD} = real(z).partials[i] + im*imag(z).partials[i] + @inline npartials(::Dual{T,V,N}) where {T,V,N} = N @inline npartials(::Type{Dual{T,V,N}}) where {T,V,N} = N @@ -123,6 +126,8 @@ end @inline valtype(::Type{V}) where {V} = V @inline valtype(::Dual{T,V,N}) where {T,V,N} = V @inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V +@inline valtype(::Complex{Dual{T,V,N}}) where {T,V,N} = complex(V) +@inline valtype(::Type{Complex{Dual{T,V,N}}}) where {T,V,N} = complex(V) @inline tagtype(::V) where {V} = Nothing @inline tagtype(::Type{V}) where {V} = Nothing diff --git a/src/partials.jl b/src/partials.jl index fce67b0a..12bf7ce1 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -81,7 +81,7 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti @inline Base.:+(a::Partials{N}, b::Partials{N}) where {N} = Partials(add_tuples(a.values, b.values)) @inline Base.:-(a::Partials{N}, b::Partials{N}) where {N} = Partials(sub_tuples(a.values, b.values)) @inline Base.:-(partials::Partials) = Partials(minus_tuple(partials.values)) -@inline Base.:*(x::Real, partials::Partials) = partials*x +@inline Base.:*(x::Union{Real, Complex}, partials::Partials) = partials*x @inline function _div_partials(a::Partials, b::Partials, aval, bval) return _mul_partials(a, b, inv(bval), -(aval / (bval*bval))) @@ -91,12 +91,12 @@ end #----------------------# if NANSAFE_MODE_ENABLED - @inline function Base.:*(partials::Partials, x::Real) + @inline function Base.:*(partials::Partials, x::Union{Real, Complex}) x = ifelse(!isfinite(x) && iszero(partials), one(x), x) return Partials(scale_tuple(partials.values, x)) end - @inline function Base.:/(partials::Partials, x::Real) + @inline function Base.:/(partials::Partials, x::Union{Real, Complex}) x = ifelse(x == zero(x) && iszero(partials), one(x), x) return Partials(div_tuple_by_scalar(partials.values, x)) end @@ -107,11 +107,11 @@ if NANSAFE_MODE_ENABLED return Partials(mul_tuples(a.values, b.values, x_a, x_b)) end else - @inline function Base.:*(partials::Partials, x::Real) + @inline function Base.:*(partials::Partials, x::Union{Real, Complex}) return Partials(scale_tuple(partials.values, x)) end - @inline function Base.:/(partials::Partials, x::Real) + @inline function Base.:/(partials::Partials, x::Union{Real, Complex}) return Partials(div_tuple_by_scalar(partials.values, x)) end @@ -132,10 +132,10 @@ end @inline Base.:-(a::Partials{N,A}, b::Partials{0,B}) where {N,A,B} = convert(Partials{N,promote_type(A,B)}, a) @inline Base.:-(partials::Partials{0,V}) where {V} = partials -@inline Base.:*(partials::Partials{0,V}, x::Real) where {V} = Partials{0,promote_type(V,typeof(x))}(tuple()) -@inline Base.:*(x::Real, partials::Partials{0,V}) where {V} = Partials{0,promote_type(V,typeof(x))}(tuple()) +@inline Base.:*(partials::Partials{0,V}, x::Union{Real, Complex}) where {V} = Partials{0,promote_type(V,typeof(x))}(tuple()) +@inline Base.:*(x::Union{Real, Complex}, partials::Partials{0,V}) where {V} = Partials{0,promote_type(V,typeof(x))}(tuple()) -@inline Base.:/(partials::Partials{0,V}, x::Real) where {V} = Partials{0,promote_type(V,typeof(x))}(tuple()) +@inline Base.:/(partials::Partials{0,V}, x::Union{Real, Complex}) where {V} = Partials{0,promote_type(V,typeof(x))}(tuple()) @inline _mul_partials(a::Partials{0,A}, b::Partials{0,B}, afactor, bfactor) where {A,B} = Partials{0,promote_type(A,B)}(tuple()) @inline _mul_partials(a::Partials{0,A}, b::Partials{N,B}, afactor, bfactor) where {N,A,B} = bfactor * b From e3c490664f02d8a1696d8d06bfa27d9f30fcabec Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Thu, 5 May 2022 16:32:33 -0400 Subject: [PATCH 2/2] add complex eltypes to Partials tests, verify R -> C Jacobian works --- test/JacobianTest.jl | 37 +++++++++++++++++++++++++++++++++++++ test/PartialsTest.jl | 2 +- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/test/JacobianTest.jl b/test/JacobianTest.jl index 9ea3b0a7..a7c83ca3 100644 --- a/test/JacobianTest.jl +++ b/test/JacobianTest.jl @@ -242,4 +242,41 @@ end @inferred ForwardDiff.jacobian(g!, [1.0], [0.0]) end +########################################## +# test specialized R^n -> C^m # +########################################## + +g(x) = [ + x[1]^2 + im*x[2], + im*x[1]^2 - x[2], + x[1] * x[2] * im + ] + +x_in = randn(2) + +J = [ + 2*x_in[1] im; + 2*im*x_in[1] -1; + x_in[2]*im x_in[1]*im + ] + +@testset "Real -> Complex Jacobian, No Chunking" begin + @test ForwardDiff.jacobian(g, x_in) == J + + J_out = zeros(complex(eltype(x_in)), (3, 2)) + ForwardDiff.jacobian!(J_out, g, x_in) + @test J_out == J +end + +for c in 1:2 + @testset "Chunked Real -> Complex Jacobian, Chunk = $c" begin + cfg = ForwardDiff.JacobianConfig(g, x_in, ForwardDiff.Chunk(c)) + @test ForwardDiff.jacobian(g, x_in, cfg) == J + + J_out = zeros(complex(eltype(x_in)), (3, 2)) + ForwardDiff.jacobian!(J_out, g, x_in, cfg) + @test J_out == J + end +end + end # module diff --git a/test/PartialsTest.jl b/test/PartialsTest.jl index 39fb05d7..157c226d 100644 --- a/test/PartialsTest.jl +++ b/test/PartialsTest.jl @@ -7,7 +7,7 @@ using ForwardDiff: Partials samerng() = MersenneTwister(1) -for N in (0, 3), T in (Int, Float32, Float64) +for N in (0, 3), T in (Int, Float32, Float64, ComplexF32, ComplexF64) println(" ...testing Partials{$N,$T}") VALUES = (rand(T,N)...,)