Skip to content

Commit

Permalink
fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 5, 2024
1 parent ca51a77 commit 9740c38
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ const FluxMetalAdaptor = MetalDevice
function gradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated!
Please see the docs for new explicit form.""", :gradient)
Zygote.gradient(f, args...)
Zygote.gradient(f, p)
end

function withgradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated!
Please see the docs for new explicit form.""", :withgradient)
Zygote.withgradient(f, args...)
Zygote.withgradient(f, p)
end

# v0.15 deprecations
Expand Down
10 changes: 5 additions & 5 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ See also [`withgradient`](@ref) to keep the value `f(args...)`.
# Examples
```jldoctest; setup=:(using Zygote)
julia> gradient(*, 2.0, 3.0, 5.0)
```
julia> Flux.gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)
julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
julia> Flux.gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)
julia> gradient([7, 11], 0, 1) do x, y, d
julia> Flux.gradient([7, 11], 0, 1) do x, y, d
p = size(x, d)
sum(x.^p .+ y)
end
Expand Down Expand Up @@ -110,7 +110,7 @@ By default, `Flux.withgradient` calls Zygote. If you load Enzyme, then other met
# Example
```jldoctest; setup=:(using Zygote)
```
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ using Flux: params
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle

using Zygote
const gradient = Flux.gradient # both Flux & Zygote export this on 0.15
const withgradient = Flux.withgradient

using Pkg
using FiniteDifferences: FiniteDifferences
using Functors: fmapstructure_with_path
Expand Down

0 comments on commit 9740c38

Please sign in to comment.