From 542a7521ed4054b072f02bf1034e84a468920e9d Mon Sep 17 00:00:00 2001 From: WT Date: Thu, 3 Jun 2021 09:31:17 +0100 Subject: [PATCH 1/3] Bump ChainRulesCore dep --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fada76e0c..321086107 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Adapt = "2, 3.2" -ChainRulesCore = "0.9.44" +ChainRulesCore = "0.9.45, 0.10" Compat = "3.14" Requires = "0.5, 1.0" julia = "1.6" From 36bff7136a3ce6ed32bd5b26e6eab237bdc5f500 Mon Sep 17 00:00:00 2001 From: WT Date: Thu, 3 Jun 2021 09:31:29 +0100 Subject: [PATCH 2/3] Bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 321086107..60cfc58db 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NNlib" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.20" +version = "0.7.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 5c7168962995b207775f2796c26c9164e0725d03 Mon Sep 17 00:00:00 2001 From: WT Date: Thu, 3 Jun 2021 09:31:49 +0100 Subject: [PATCH 3/3] Replace NO_FIELDS with NoTangent() --- src/activations.jl | 4 ++-- src/batched/batchedadjtrans.jl | 6 +++--- src/batched/batchedmul.jl | 2 +- src/conv.jl | 6 +++--- src/gather.jl | 2 +- src/padding.jl | 2 +- src/pooling.jl | 4 ++-- src/scatter.jl | 4 ++-- src/softmax.jl | 4 ++-- src/upsample.jl | 8 ++++---- 10 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/activations.jl b/src/activations.jl index 411209975..89ecc7b83 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -260,7 +260,7 @@ for (f, df) in UNARY_ACTS ::typeof($f), x::Numeric) Ω = $f.(x) function $pullback(Δ) - NO_FIELDS, NO_FIELDS, @.(Δ * $df) + NoTangent(), NoTangent(), @.(Δ * $df) end return Ω, $pullback end @@ -280,7 +280,7 @@ for (f, df1, df2) in BINARY_ACTS x1::Numeric, x2::Numeric) Ω = $f.(x1, x2) function $pullback(Δ) - NO_FIELDS, NO_FIELDS, @.(Δ * $df1), @.(Δ * $df2) + NoTangent(), NoTangent(), @.(Δ * $df1), @.(Δ * $df2) end return Ω, $pullback end diff --git a/src/batched/batchedadjtrans.jl b/src/batched/batchedadjtrans.jl index f9cc8773d..50505baf7 100644 --- a/src/batched/batchedadjtrans.jl +++ b/src/batched/batchedadjtrans.jl @@ -94,13 +94,13 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} = # Gradients function rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3}) - b_transpose_back(Δ) = (NO_FIELDS, batched_transpose(Δ)) + b_transpose_back(Δ) = (NoTangent(), batched_transpose(Δ)) batched_transpose(A), b_transpose_back end function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3}) - b_adjoint_back(Δ) = (NO_FIELDS, batched_adjoint(Δ)) + b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(Δ)) batched_adjoint(A), b_adjoint_back end adapt_structure(to, x::BatchedAdjoint) = BatchedAdjoint(adapt(to, parent(x))) -adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x))) \ No newline at end of file +adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x))) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 6e6f8e3fa..c217c801c 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -100,7 +100,7 @@ function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArra tmp = batched_mul(batched_adjoint(A), Δ) size(B,3) == 1 ? sum(tmp, dims=3) : tmp end - return (NO_FIELDS, Athunk, Bthunk) + return (NoTangent(), Athunk, Bthunk) end batched_mul(A, B), batched_mul_pullback end diff --git a/src/conv.jl b/src/conv.jl index 57465333a..bb736fe9d 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -224,7 +224,7 @@ for conv in [:conv, :depthwiseconv] function $conv_pullback(Δ) Δ = colmajor(Δ) return ( - NO_FIELDS, + NoTangent(), @thunk($∇conv_data(Δ, w, cdims, kw...)), @thunk($∇conv_filter(x, Δ, cdims, kw...)), NoTangent(), @@ -237,7 +237,7 @@ for conv in [:conv, :depthwiseconv] function $∇conv_data_pullback(Δ) Δ = colmajor(Δ) return ( - NO_FIELDS, + NoTangent(), @thunk($conv(Δ, w, cdims, kw...)), @thunk($∇conv_filter(Δ, x, cdims, kw...)), NoTangent(), @@ -268,4 +268,4 @@ end # return ∇conv_filter_nnpack(x, dy, cdims; kwargs...) # end # end -######################################################## \ No newline at end of file +######################################################## diff --git a/src/gather.jl b/src/gather.jl index 57b69a17b..2872ab9f1 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -79,6 +79,6 @@ end function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) y = gather!(dst, src, idx) src_size = size(src) - gather!_pullback(Δ) = (NO_FIELDS, NoTangent(), ∇gather_src(Δ, src_size, idx), NoTangent()) + gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(Δ, src_size, idx), NoTangent()) y, gather!_pullback end diff --git a/src/padding.jl b/src/padding.jl index 9258f9e38..b6a614279 100644 --- a/src/padding.jl +++ b/src/padding.jl @@ -155,7 +155,7 @@ function rrule(::typeof(pad_constant), x::AbstractArray{T,N}, function pad_constant_pullback(Δ) p = gen_pad(pad, dims, N) outsize, center = size_and_center(x, p) - (NO_FIELDS, @thunk(Δ[center...]), NoTangent(), NoTangent(),) + (NoTangent(), @thunk(Δ[center...]), NoTangent(), NoTangent(),) end return y, pad_constant_pullback end diff --git a/src/pooling.jl b/src/pooling.jl index 572436b79..d69ea1f16 100644 --- a/src/pooling.jl +++ b/src/pooling.jl @@ -173,7 +173,7 @@ for pool in [:maxpool, :meanpool] pullback = Symbol(pool, :_pullback) @eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...) Ω = $pool(x, pdims; kw...) - $pullback(Δ) = (NO_FIELDS, $∇pool(Δ, Ω, x, pdims; kw...), NoTangent()) + $pullback(Δ) = (NoTangent(), $∇pool(Δ, Ω, x, pdims; kw...), NoTangent()) return Ω, $pullback end -end \ No newline at end of file +end diff --git a/src/scatter.jl b/src/scatter.jl index 298187c6b..2c0f5d8eb 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -152,12 +152,12 @@ end function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) dst_old = copy(dst) scatter!(op, dst, src, idx) - scatter!_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter!_dst(op, Δ, dst_old, dst), ∇scatter!_src(op, Δ, dst, src, idx), NoTangent()) + scatter!_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter!_dst(op, Δ, dst_old, dst), ∇scatter!_src(op, Δ, dst, src, idx), NoTangent()) dst, scatter!_pullback end function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray) y = scatter(op, src, idx) - scatter_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_src(op, Δ, y, src, idx), NoTangent()) + scatter_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter_src(op, Δ, y, src, idx), NoTangent()) y, scatter_pullback end diff --git a/src/softmax.jl b/src/softmax.jl index 5b5a3f7e9..b51cbc789 100644 --- a/src/softmax.jl +++ b/src/softmax.jl @@ -78,7 +78,7 @@ end function rrule(::typeof(softmax), xs; dims=1) y = softmax(xs; dims=dims) - softmax_pullback(Δ) = (NO_FIELDS, ∇softmax(Δ, xs, y, dims=dims)) + softmax_pullback(Δ) = (NoTangent(), ∇softmax(Δ, xs, y, dims=dims)) return y, softmax_pullback end @@ -125,7 +125,7 @@ end function rrule(::typeof(logsoftmax), xs; dims=1) y = logsoftmax(xs; dims=dims) - logsoftmax_pullback(Δ) = (NO_FIELDS, ∇logsoftmax(Δ, xs, y, dims=dims)) + logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(Δ, xs, y, dims=dims)) return y, logsoftmax_pullback end diff --git a/src/upsample.jl b/src/upsample.jl index f2623c79a..f1c6f7ecf 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -77,7 +77,7 @@ end function rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple) Ω = upsample_nearest(x, s) - upsample_nearest_pullback(Δ) = (NO_FIELDS, ∇upsample_nearest(Δ, s), NoTangent()) + upsample_nearest_pullback(Δ) = (NoTangent(), ∇upsample_nearest(Δ, s), NoTangent()) return Ω, upsample_nearest_pullback end @@ -203,7 +203,7 @@ end function rrule(::typeof(upsample_linear), x; size) Ω = upsample_linear(x; size=size) function upsample_linear_pullback(Δ) - (NO_FIELDS, ∇upsample_linear(Δ; size=Base.size(x,1))) + (NoTangent(), ∇upsample_linear(Δ; size=Base.size(x,1))) end return Ω, upsample_linear_pullback end @@ -368,7 +368,7 @@ end function rrule(::typeof(upsample_bilinear), x; size) Ω = upsample_bilinear(x; size=size) function upsample_bilinear_pullback(Δ) - (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2)))) + (NoTangent(), ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2)))) end return Ω, upsample_bilinear_pullback end @@ -518,7 +518,7 @@ end function rrule(::typeof(upsample_trilinear), x; size) Ω = upsample_trilinear(x; size=size) function upsample_trilinear_pullback(Δ) - (NO_FIELDS, ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3)))) + (NoTangent(), ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3)))) end return Ω, upsample_trilinear_pullback end