gradient of keyword argument? #446

Open
chengchingwen opened this issue Jan 6, 2020 · 10 comments

@chengchingwen
Member

Is there any way to compute the gradient with respect to a keyword argument? For example, the var function accepts a precomputed mean as a keyword argument, but right now it seems we can't get the gradient with respect to it. This affects functions such as Flux.normalise.

@MikeInnes
Member

No, there isn't currently, at least not directly. The best way to do this is instead to create a normal-argument wrapper around the function and differentiate that:

_foo(a, b) = a*b

@adjoint _foo(a, b) = ...

foo(; a = 1, b = 2) = _foo(a, b)

The setup might have to change a bit for a function built in to base, but should be doable; if not we can work out another option.
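As a minimal sketch of the wrapper pattern above, the elided adjoint can be filled in for the toy product function. The pullback body here is an assumption chosen for illustration (it is just the product rule), and the closure passed to `gradient` is one possible way to feed values into the keyword call:

```julia
using Zygote
using Zygote: @adjoint

# Positional "inner" function that carries the custom gradient.
_foo(a, b) = a * b

# Illustrative adjoint (assumption): return the value and a pullback
# that maps the output cotangent Δ to the cotangents of (a, b).
@adjoint _foo(a, b) = _foo(a, b), Δ -> (Δ * b, Δ * a)

# Keyword-argument front end that only forwards to the positional function.
foo(; a = 1, b = 2) = _foo(a, b)

# Differentiate through the keyword call by closing over the inputs.
ga, gb = gradient((a, b) -> foo(a = a, b = b), 3.0, 4.0)
```

Because the keyword method does nothing except forward its arguments, all of the differentiation logic lives in the positional method, which is the one `@adjoint` can target.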

@chengchingwen
Member Author

Then we probably need to change the implementation of Flux.normalise
https://github.com/FluxML/Flux.jl/blob/e92da0cf850a982c425b83c92d6274174e52b02c/src/layers/stateless.jl#L82-L86
or maybe add some other function for computing mean and variance.

Is it possible to support keyword arguments in the future?

@MikeInnes
Member

Is the issue that the gradient for normalise ends up incorrect because the gradient for std is incorrect (with respect to the kwarg mean)?

If so that's a Zygote bug we should certainly fix; those things really need to work well out of the box, rather than relying on workarounds by users.

@chengchingwen
Member Author

I think so. Because Zygote can't currently get the gradient of the keyword argument, that gradient is missing. We might need to use or define a different set of statistics functions that don't require kwargs (or make kwargs work in Zygote).

@chengchingwen
Member Author

We probably need a custom definition to support the mean kwarg in Zygote (though this won't generalize to other kwarg functions):

function _pullback(__context__::AContext, ::Core.kwftype(typeof(var)), kw, ::typeof(var), xs) end

@MikeInnes
Member

Zygote can actually get gradients of kwargs, so it's not an issue there as such; the problem is in this definition for var (along with std below) which assumes the default definition of mean will be used and calculates a gradient only for xs consistent with that. Actually telling Zygote about the gradient of mean is one issue, and modifying _backvar to calculate that gradient is another.

Directly overloading _pullback is probably the right way to solve this for now; it's ugly but it'll let you do whatever you want, and we can add a more convenient way to specify kwarg gradients later.

@chengchingwen
Member Author

Modifying _backvar shouldn't be too hard since the formulae are quite alike, but right now I have no idea how to tell Zygote about it without overloading _pullback directly. Do you have any idea in mind for that (maybe another macro or a prefix/suffix in @adjoint)?
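To make the similarity of the formulae concrete, here is a hedged sketch that uses a positional stand-in for `var(xs; mean = m)`. The name `_var` is hypothetical and is not Zygote's `_backvar`; it assumes the default bias-corrected variance, and the closed forms follow from differentiating sum((xs .- m).^2) / (n - 1) with `m` treated as an independent input:

```julia
using Zygote

# Hypothetical positional stand-in for var(xs; mean = m), corrected = true.
_var(xs, m) = sum(abs2, xs .- m) / (length(xs) - 1)

xs = [1.0, 2.0, 4.0]
m = 2.0   # deliberately not mean(xs), so the gradient w.r.t. m is nonzero

# Let Zygote differentiate the plain-Julia definition.
gxs, gm = gradient(_var, xs, m)

# Closed-form gradients for comparison.
n = length(xs)
expected_gxs = 2 .* (xs .- m) ./ (n - 1)      # ∂v/∂xs[i] =  2(xs[i] - m)/(n - 1)
expected_gm = -2 * sum(xs .- m) / (n - 1)     # ∂v/∂m     = -2Σ(xs[i] - m)/(n - 1)
```

The gradient with respect to `m` is the negated sum of the per-element gradients, so it vanishes exactly when `m` equals the true mean of `xs`.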

@MikeInnes
Member

Maybe I misunderstand, but I do think the best way right now is to just overload _pullback directly, at least in the short term. Happy to help with how to do that if needed.

@chengchingwen
Member Author

@MikeInnes I think we still need a way to integrate the gradient definition for kwargs into @adjoint, because simply overloading _pullback produces a warning:

WARNING: Method definition _pullback(ZygoteRules.AContext, getfield(Statistics, Symbol("#kw##var")), Any, typeof(Statistics.var), AbstractArray{T, N} where N where T) in module Zygote at /home/peter/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:53 overwritten at /home/peter/peter/fork/Zygote.jl/src/lib/array.jl:242.
  ** incremental compilation may be fatally broken for this module **

@ToucheSir
Member

Coming back to this after JuliaDiff/ChainRules.jl#569 and (half of) FluxML/Flux.jl#2005, I wonder if we couldn't special case _pullback(::Context, ::Core.kwftype, ::Tuple{NamedTuple, ...}).

One idea would be to prune haskey-related blocks we know are dead because the relevant keyword either exists or doesn't in the args NamedTuple. We can determine this at compile time based on the arg types, so this pruning pass should be possible on any of the IRs Zygote interacts with. At the present moment, I feel the IRTools level seems most promising.
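The compile-time claim can be illustrated in plain Julia, independent of Zygote or IRTools: the field names of a NamedTuple are part of its type, so whether a keyword is present can be answered from the type alone. The helper name `has_keyword` is hypothetical and only sketches the idea; it is not an existing Zygote function and does not perform any pruning itself:

```julia
# Hypothetical helper: the answer depends only on the NamedTuple *type*,
# because the field names are a type parameter.
has_keyword(::Type{<:NamedTuple{names}}, key::Symbol) where {names} = key in names

# Keyword arguments as they would arrive packed in a NamedTuple.
kw = (mean = 2.0, corrected = true)

mean_present = has_keyword(typeof(kw), :mean)   # keyword was passed
dims_present = has_keyword(typeof(kw), :dims)   # keyword was not passed
```

A pruning pass along the lines described above could use this kind of type-level information to decide which haskey branch is dead for a given call signature.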
