gradient of keyword argument? #446
Comments
No, there isn't currently, at least not directly. The best way to do this is to create a normal-argument wrapper around the function and differentiate that:

    _foo(a, b) = a*b
    @adjoint _foo(a, b) = ...
    foo(; a = 1, b = 2) = _foo(a, b)

The setup might have to change a bit for a function built into Base, but it should be doable; if not, we can work out another option.
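Below is a runnable sketch of that wrapper pattern. The original comment elides the adjoint body, so the pullback here is an assumption: it is simply the product rule for `a*b`. The last line shows the usual trick for getting gradients with respect to the keywords. You wrap the keyword call in a closure whose positional arguments are the values you want to differentiate.

```julia
using Zygote

# Positional-argument worker; the custom adjoint is attached here.
_foo(a, b) = a * b

# Assumed adjoint body (elided in the thread): return the primal value and a
# pullback mapping the output sensitivity Δ to sensitivities for (a, b).
Zygote.@adjoint _foo(a, b) = _foo(a, b), Δ -> (Δ * b, Δ * a)

# Public keyword-argument API: a thin wrapper that forwards to the worker.
foo(; a = 1, b = 2) = _foo(a, b)

# Differentiate with respect to the keywords by passing them through a
# closure as ordinary positional arguments.
ga, gb = gradient((a, b) -> foo(a = a, b = b), 3.0, 4.0)
println((ga, gb))  # (4.0, 3.0)
```

Because `foo` only forwards its keywords, all derivative logic lives in `_foo`. The closure is what exposes the keyword values to `gradient`.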
Then we probably need to change the implementation of […]. Is it possible to support keyword arguments in the future?
Is the issue that the gradient for […]? If so, that's a Zygote bug we should certainly fix; those things really need to work well out of the box rather than relying on workarounds by users.
I think so. Because Zygote can't currently get the gradient of the keyword argument, that gradient is missing. We might need to use or define a different set of statistics APIs that don't require kwargs (or make kwargs work in Zygote).
We probably need a custom definition for keyword-argument support, something like:

    function _pullback(__context__::AContext, ::Core.kwftype(typeof(var)), kw, ::typeof(var), xs)
        # custom pullback body goes here
    end
Zygote can actually get gradients of kwargs, so it's not an issue there as such; the problem is in this definition for […]. Directly overloading […]
Modifying […]
Maybe I misunderstand, but I do think the best way right now is to just overload […].
@MikeInnes I think we still need a way to integrate gradient definitions for kwargs into […].
Coming back to this after JuliaDiff/ChainRules.jl#569 and (half of) FluxML/Flux.jl#2005, I wonder if we couldn't special-case […]. One idea would be to prune […].
Is there any way to compute the gradient with respect to a keyword argument? For example, the var function can take a precomputed mean value as a keyword argument, but right now it seems we can't get the gradient with respect to it. This affects functions such as Flux.normalise.
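One workaround raised in the thread is an API that takes the mean positionally instead of as a keyword. A minimal sketch of that idea follows. `var_with_mean` and `myvar` are hypothetical helpers, not functions from Statistics or Flux. They implement the sample variance directly, so Zygote differentiates through the arithmetic and does not rely on any existing rule for `var`.

```julia
using Zygote

# Hypothetical helper: sample variance with the mean as an ordinary
# positional argument, so it is differentiated like any other input.
var_with_mean(x, m) = sum(abs2, x .- m) / (length(x) - 1)

# Hypothetical keyword-facing wrapper that forwards to the positional worker.
myvar(x; mean = sum(x) / length(x)) = var_with_mean(x, mean)

x = [1.0, 2.0, 3.0]

# Gradient with respect to both the data and a precomputed mean of 0.0.
gx, gm = gradient(var_with_mean, x, 0.0)
println(gx)  # [1.0, 2.0, 3.0]
println(gm)  # -6.0

# The same mean gradient, reached through the keyword wrapper via a closure.
(gm2,) = gradient(m -> myvar(x; mean = m), 0.0)
println(gm2)  # -6.0
```

Analytically, the derivative with respect to the mean is -2·Σ(xᵢ - m)/(n - 1), which gives -2·6/2 = -6 here. This gradient is the one that goes missing when the mean enters only through a keyword that the rule ignores.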