Explicit derivatives for complex analytic functions #727

Open
wants to merge 3 commits into base: master


clguillot

Hi everyone,

As documented in #726, #514 and #653, ForwardDiff has trouble differentiating complex functions. The issue is similar to the one in #481. This is a serious problem that several people, including myself, have run into, and it can be hard to identify the first time, especially when ForwardDiff is used as a "black box". Julia and ForwardDiff are widely used by mathematicians working in quantum chemistry, who manipulate complex numbers daily, and hitting this issue can cost a lot of time.
With a few exceptions (mainly the functions abs and abs2), the rules defined in DiffRules extend to the complex domain without modification. I worked on the issue a bit and came up with the beginning of a fix by modifying unary_dual_definition as follows:
https://github.com/clguillot/ForwardDiff.jl/blob/114cfe90755df8591488c7d71bd3109be5325fb9/src/dual.jl#L234-L265
For Complex{Dual} arguments, I simply define a method that evaluates the derivative expression found in DiffRules and returns a new Complex{Dual}.
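The underlying rule can be illustrated outside ForwardDiff. Writing a complex dual number as z = z0 + ε z1 (with z0 and z1 ordinary complex numbers and ε² = 0), any complex-analytic f satisfies f(z) = f(z0) + ε f'(z0) z1, so a derivative expression can be evaluated directly on the plain complex primal. The sketch below is a minimal illustration under that assumption, not the code in the linked commit: `MiniDual` is a hypothetical single-partial stand-in for `ForwardDiff.Dual`, and `primal`, `tangent` and `analytic_rule` are illustrative helper names.

```julia
# Minimal forward-mode dual number with a single partial.
# Hypothetical stand-in for ForwardDiff.Dual, for illustration only.
struct MiniDual{T<:Real} <: Real
    val::T  # primal value
    der::T  # derivative (single partial)
end

# Split a Complex{MiniDual} into its complex primal z0 and complex tangent z1,
# so that z = z0 + ε*z1.
primal(z::Complex{<:MiniDual}) = complex(real(z).val, imag(z).val)
tangent(z::Complex{<:MiniDual}) = complex(real(z).der, imag(z).der)

# Push an analytic function f with derivative df through a Complex{MiniDual}:
# f(z0 + ε*z1) = f(z0) + ε * df(z0) * z1.
# f and df only ever see ordinary complex numbers.
function analytic_rule(f, df, z::Complex{<:MiniDual})
    z0, z1 = primal(z), tangent(z)
    w0 = f(z0)          # value
    w1 = df(z0) * z1    # chain rule applied to the complex tangent
    return Complex(MiniDual(real(w0), real(w1)), MiniDual(imag(w0), imag(w1)))
end

# Derivatives written out by hand: exp' = exp, sqrt'(w) = 1 / (2 sqrt(w)).
Base.exp(z::Complex{<:MiniDual}) = analytic_rule(exp, exp, z)
Base.sqrt(z::Complex{<:MiniDual}) = analytic_rule(sqrt, w -> inv(2 * sqrt(w)), z)

# Usage: i*x at x = 0.7 with dx = 1 has re = 0 + 0ε and im = 0.7 + 1ε.
z = Complex(MiniDual(0.0, 0.0), MiniDual(0.7, 1.0))
dz = tangent(exp(z))   # d/dx exp(i*x) at 0.7, i.e. im * cis(0.7) up to rounding
```

Because the rule never inspects how exp or sqrt are implemented for complex arguments, internal branches in those implementations cannot drop the derivative information.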
With this fix, I get the correct result when computing the second-order derivative of exp(ix):

julia> f(x) = exp(1im*x);
julia> df(x) = ForwardDiff.derivative(f, x);
julia> ForwardDiff.derivative(df, 0.0)
-1.0 + 0.0im

Without the fix, the same computation (unless the modification in #481 is implemented) returns 0.0 + 0.0im, which is obviously wrong.
I also implemented sin and cos for Complex{Dual} by hand:
https://github.com/clguillot/ForwardDiff.jl/blob/114cfe90755df8591488c7d71bd3109be5325fb9/src/dual.jl#L747-L771
but I did not implement sincos, since I was lazy.
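A sincos method for complex duals could share work between the two functions: since sin' = cos and cos' = -sin, one evaluation of sin and cos on the complex primal gives both the values and the derivatives. The sketch below assumes the same kind of hypothetical single-partial `MiniDual` type (not `ForwardDiff.Dual`); `lift` is an illustrative helper name, not ForwardDiff API.

```julia
# Hypothetical single-partial dual number, for illustration only.
struct MiniDual{T<:Real} <: Real
    val::T
    der::T
end

primal(z::Complex{<:MiniDual}) = complex(real(z).val, imag(z).val)
tangent(z::Complex{<:MiniDual}) = complex(real(z).der, imag(z).der)

# Build a Complex{MiniDual} from a complex value w0 and a complex tangent w1.
lift(w0::Complex, w1::Complex) =
    Complex(MiniDual(real(w0), real(w1)), MiniDual(imag(w0), imag(w1)))

# sin and cos are each evaluated once on the primal, then reused:
# sin(z0 + ε z1) = sin(z0) + ε cos(z0) z1
# cos(z0 + ε z1) = cos(z0) - ε sin(z0) z1
function Base.sincos(z::Complex{<:MiniDual})
    z0, z1 = primal(z), tangent(z)
    s, c = sin(z0), cos(z0)
    return lift(s, c * z1), lift(c, -s * z1)
end

Base.sin(z::Complex{<:MiniDual}) = sincos(z)[1]
Base.cos(z::Complex{<:MiniDual}) = sincos(z)[2]
```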

I believe that having explicit derivatives in these cases will largely free ForwardDiff from depending on how the functions are implemented in the libraries, since Dual numbers never need to pass through that code. Moreover, I don't expect it to hurt performance much either.

One issue I see with this pull request is the manual exclusion of some functions. It would probably be more elegant to modify DiffRules to indicate which functions have derivatives that extend to the complex domain, for example by defining a macro @define_analytic_diffrule that would call @define_diffrule and add the function to some kind of list marking it as analytic. Until something like this is in place, the code above should at least provide a basis that returns the right answer in most cases.
It would also be nice to provide a similar fix for functions of several variables, but all in good time.
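The "list of analytic functions" idea could look roughly like the sketch below: a whitelist of (function, derivative expression) pairs from which the complex methods are generated with `@eval`, so non-analytic functions such as abs and abs2 are never registered rather than excluded by hand. This is an assumption-laden illustration, not DiffRules or ForwardDiff code: `@define_analytic_diffrule` does not exist today, and `ANALYTIC_RULES`, `MiniDual`, `primal`, `tangent` and `lift` are hypothetical names.

```julia
# Hypothetical single-partial dual number (stand-in for ForwardDiff.Dual).
struct MiniDual{T<:Real} <: Real
    val::T
    der::T
end

primal(z::Complex{<:MiniDual}) = complex(real(z).val, imag(z).val)
tangent(z::Complex{<:MiniDual}) = complex(real(z).der, imag(z).der)
lift(w0::Complex, w1::Complex) =
    Complex(MiniDual(real(w0), real(w1)), MiniDual(imag(w0), imag(w1)))

# Hypothetical registry: (function name, derivative as an expression in `w`).
# abs and abs2 are deliberately absent: they are not complex-differentiable.
const ANALYTIC_RULES = [
    (:exp,  :(exp(w))),
    (:log,  :(inv(w))),
    (:sqrt, :(inv(2 * sqrt(w)))),
    (:sinh, :(cosh(w))),
    (:cosh, :(sinh(w))),
]

# Generate one Complex{MiniDual} method per registered function:
# f(z0 + ε z1) = f(z0) + ε f'(z0) z1, with w bound to the primal z0.
for (f, df) in ANALYTIC_RULES
    @eval Base.$f(z::Complex{<:MiniDual}) = let w = primal(z)
        lift($f(w), $df * tangent(z))
    end
end
```

Storing derivatives as expressions mirrors how DiffRules records its rules, so in principle such a list could be filled from the same source rather than maintained separately.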


codecov bot commented Dec 4, 2024

Codecov Report

Attention: Patch coverage is 21.42857% with 33 lines in your changes missing coverage. Please review.

Project coverage is 83.90%. Comparing base (c310fb5) to head (e4aa3cf).
Report is 12 commits behind head on master.

Files with missing lines | Patch % | Lines
src/dual.jl              | 21.42%  | 33 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #727      +/-   ##
==========================================
- Coverage   89.57%   83.90%   -5.68%     
==========================================
  Files          11       10       -1     
  Lines         969      963       -6     
==========================================
- Hits          868      808      -60     
- Misses        101      155      +54     


@mcabbott
Member

mcabbott commented Dec 4, 2024

Am I correct to think that #514 and #653 are fixed on master, as is #486 (similar but not mentioned above)?

And are there other cases where Complex{Dual} goes wrong, which aren't fixed?

Certainly the tests in this PR pass:

julia> @testset "analytic functions" begin
           dexp(x) = ForwardDiff.derivative(y -> exp(complex(0, y)), x)
           @test ForwardDiff.derivative(dexp, 0.0) ≈ -1
           @test ForwardDiff.derivative(x -> exp(1im*x), 0.7) ≈ im * cis(0.7)
           @test ForwardDiff.derivative(x -> sqrt(im + (1+im) * x), 1.23) ≈ (1+im) / (2 * sqrt(im + (1+im)*1.23))
       end
Test Summary:      | Pass  Total  Time
analytic functions |    3      3  0.5s
Test.DefaultTestSet("analytic functions", Any[], 3, false, false, true, 1.733352822760275e9, 1.733352823277592e9, false, "REPL[29]")

That's not to say we shouldn't add complex methods, but if they don't have correctness justifications, then they need something else. Are these more accurate, or faster, or do they handle some edge cases better?

antoine-levitt mentioned this pull request Dec 5, 2024
@clguillot
Author

The problem is indeed fixed on master, but not in the release. From a user's point of view, it is therefore equivalent to an unresolved issue.

I see several benefits from handling the complex derivatives this way:

  • Whether the code inside complex elementary functions is ForwardDiff-compatible would no longer matter, as long as a correct diffrule is defined.
  • This would fix the issue in a way that seems less breaking than #481 (Change == to ignore measure-zero branches), since I don't see why anyone would rely on Dual numbers actually going through the code of elementary functions.
  • Performance is a good point; it would probably require careful benchmarking. I don't expect much improvement on that front anyway.
  • Another benefit worth mentioning concerns support for functions defined in SpecialFunctions. Since many of them are defined only for ComplexF64, the current version of ForwardDiff throws an error when they are called with Complex{Dual}. Using the diffrule instead, we get the correct answer:
julia> (1+im) * digamma((1+im)*1.2)
-0.7029212687502131 + 1.3421243356860844im
julia> ForwardDiff.derivative(x -> loggamma((1+im)*x), 1.2) # With the diffrule
-0.7029212687502131 + 1.3421243356860844im
julia> ForwardDiff.derivative(x -> loggamma((1+im)*x), 1.2) # Without the diffrule
ERROR: MethodError: no method matching _loggamma(::Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1}})
The function `_loggamma` exists, but no method is defined for this combination of argument types.
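The SpecialFunctions point rests on the fact that, with a rule, only the plain complex primal is passed to the underlying implementation. The self-contained sketch below mimics that situation without SpecialFunctions: `special_f` is a hypothetical function that, like `_loggamma` in the error above, has a method only for ComplexF64, and `special_df` plays the role of the digamma-based rule. None of these names belong to SpecialFunctions or ForwardDiff, and `MiniDual` is again a hypothetical stand-in for `ForwardDiff.Dual`.

```julia
# Hypothetical single-partial dual number (stand-in for ForwardDiff.Dual).
struct MiniDual{T<:Real} <: Real
    val::T
    der::T
end

primal(z::Complex{<:MiniDual}) = complex(real(z).val, imag(z).val)
tangent(z::Complex{<:MiniDual}) = complex(real(z).der, imag(z).der)

# Hypothetical "library" function implemented only for ComplexF64, so it
# cannot be traced through with dual numbers.
special_f(z::ComplexF64) = z * exp(z)
special_df(z::ComplexF64) = (1 + z) * exp(z)   # its derivative, supplied as a rule

# With a rule, special_f only ever receives the ComplexF64 primal.
function special_f(z::Complex{<:MiniDual})
    z0, z1 = primal(z), tangent(z)
    w0, w1 = special_f(z0), special_df(z0) * z1
    return Complex(MiniDual(real(w0), real(w1)), MiniDual(imag(w0), imag(w1)))
end
```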

In any case, if this modification does not seem appropriate, a corrected version of exp(Complex{Dual}) should be defined as a temporary fix in the release until the discussions about v1.0 are concluded. If you prefer this approach, I can open a pull request along those lines and leave the question of complex derivatives for another time.


2 participants