add var/std gradient wrt kw mean #478

Open
wants to merge 1 commit into master
Conversation

chengchingwen
Member

Fix the problem mentioned in #446.

@CarloLucibello
Member

looks fine and tests are exhaustive, merge?

@CarloLucibello
Member

@chengchingwen is this giving warnings as you seem to mention in #446 (comment)?

@chengchingwen
Member Author

@CarloLucibello yes, the warnings are still there.

@chengchingwen
Member Author

If we want to avoid those warnings, we need to avoid using @adjoint for var and std and instead manually define the _pullback functions.

@DhairyaLGandhi
Member

Defining internal methods like that seems sketchy. The adjoint definitions from ChainRules, as well as the ones within Zygote, can handle it. The adjoint can accept kwargs.

@chengchingwen
Member Author

I don't really get what you mean. The current @adjoint returns nothing for all the keyword arguments.

@DhairyaLGandhi
Member

Let me rephrase that: we don't want to differentiate kwargs here; rather, we want them to be passed along to the functions appropriately.

@chengchingwen
Member Author

But the default std/var API only accepts mean as a keyword argument. If we want mean to be differentiable as a positional argument, we would end up with a std/var API that only works when Zygote is imported. Otherwise, we would need to make a PR to Statistics so that std/var accept mean as a positional argument.
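To illustrate the Zygote-only option, a rough sketch (the name var_about and its adjoint are hypothetical, not part of this PR or of Statistics): it takes the precomputed mean positionally, so an ordinary @adjoint can return a gradient for it.

using Statistics, Zygote

# Hypothetical wrapper: pass the precomputed mean as a positional argument,
# so it is an ordinary differentiable argument rather than a kwarg.
var_about(x::AbstractArray, μ; corrected::Bool=true) =
    var(x; mean=μ, corrected=corrected)

Zygote.@adjoint function var_about(x::AbstractArray, μ; corrected::Bool=true)
    y = var_about(x, μ; corrected=corrected)
    function var_about_pullback(Δ)
        # var(x; mean=μ) = sum(abs2, x .- μ) / (length(x) - corrected)
        dx = Δ .* 2 .* (x .- μ) ./ (length(x) - corrected)   # gradient wrt x
        dμ = -sum(dx)                                         # gradient wrt μ
        return (dx, dμ)
    end
    return y, var_about_pullback
end

With something like this, gradient(μ -> var_about([1,2,3.0], μ), 0.0) should give a nonzero gradient for μ whenever μ != mean(x), without touching keyword handling at all.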

@racinmat
Contributor

Hi, is there any update on this? Is there a plan for that to be supported?

@mcabbott
Member

mcabbott commented Feb 4, 2022

What's an example where this would matter and would give correct results? On simple things, and on the example from #446 (comment), it gives zero. With @show inserted:

julia> gradient([1,2,3]) do x
         m = mean(x)
         std(x; mean=m)
       end
backmean = -0.0
([-0.5, 0.0, 0.5],)

julia> function normalise(x::AbstractArray; dims=1)
         μ′ = mean(x, dims = dims)
         σ′ = std(x, dims = dims, mean = μ′, corrected=false)
         return (x .- μ′) ./ σ′
       end
normalise (generic function with 1 method)

julia> gradient(x -> sum(sin, normalise(x)), [1,2,3.0])
backmean = [-0.0]
([-0.2697761903106471, 0.5395523806212941, -0.2697761903106471],)

Stumbled on this while thinking about JuliaDiff/ChainRules.jl#567

@chengchingwen
Member Author

@mcabbott Yes, normalise was not a good example for this, as the gradient of std(x, mean=μ) wrt μ is 0 if μ == mean(x). It would only matter for std(x, mean=μ) with μ != mean(x), and it's probably arguable whether you would ever do that.
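For reference, just spelling that claim out (corrected=true, no dims): var(x; mean=μ) = sum(abs2, x .- μ) / (n - 1), so the keyword gradient is ∂var/∂μ = -2 * sum(x .- μ) / (n - 1), which vanishes exactly when μ == mean(x). The std case follows from std = sqrt(var) by the chain rule.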
