Skip to content

Commit

Permalink
Replace NO_FIELDS with NoTangent()
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Jun 3, 2021
1 parent 36bff71 commit 5c71689
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/activations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ for (f, df) in UNARY_ACTS
::typeof($f), x::Numeric)
Ω = $f.(x)
function $pullback(Δ)
NO_FIELDS, NO_FIELDS, @.* $df)
NoTangent(), NoTangent(), @.* $df)
end
return Ω, $pullback
end
Expand All @@ -280,7 +280,7 @@ for (f, df1, df2) in BINARY_ACTS
x1::Numeric, x2::Numeric)
Ω = $f.(x1, x2)
function $pullback(Δ)
NO_FIELDS, NO_FIELDS, @.* $df1), @.* $df2)
NoTangent(), NoTangent(), @.* $df1), @.* $df2)
end
return Ω, $pullback
end
Expand Down
6 changes: 3 additions & 3 deletions src/batched/batchedadjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =

# Gradients
function rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})
b_transpose_back(Δ) = (NO_FIELDS, batched_transpose(Δ))
b_transpose_back(Δ) = (NoTangent(), batched_transpose(Δ))
batched_transpose(A), b_transpose_back
end
function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})
b_adjoint_back(Δ) = (NO_FIELDS, batched_adjoint(Δ))
b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(Δ))
batched_adjoint(A), b_adjoint_back
end

adapt_structure(to, x::BatchedAdjoint) = BatchedAdjoint(adapt(to, parent(x)))
adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x)))
adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x)))
2 changes: 1 addition & 1 deletion src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArra
tmp = batched_mul(batched_adjoint(A), Δ)
size(B,3) == 1 ? sum(tmp, dims=3) : tmp
end
return (NO_FIELDS, Athunk, Bthunk)
return (NoTangent(), Athunk, Bthunk)
end
batched_mul(A, B), batched_mul_pullback
end
Expand Down
6 changes: 3 additions & 3 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ for conv in [:conv, :depthwiseconv]
function $conv_pullback(Δ)
Δ = colmajor(Δ)
return (
NO_FIELDS,
NoTangent(),
@thunk($∇conv_data(Δ, w, cdims, kw...)),
@thunk($∇conv_filter(x, Δ, cdims, kw...)),
NoTangent(),
Expand All @@ -237,7 +237,7 @@ for conv in [:conv, :depthwiseconv]
function $∇conv_data_pullback(Δ)
Δ = colmajor(Δ)
return (
NO_FIELDS,
NoTangent(),
@thunk($conv(Δ, w, cdims, kw...)),
@thunk($∇conv_filter(Δ, x, cdims, kw...)),
NoTangent(),
Expand Down Expand Up @@ -268,4 +268,4 @@ end
# return ∇conv_filter_nnpack(x, dy, cdims; kwargs...)
# end
# end
########################################################
########################################################
2 changes: 1 addition & 1 deletion src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ end
function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
y = gather!(dst, src, idx)
src_size = size(src)
gather!_pullback(Δ) = (NO_FIELDS, NoTangent(), ∇gather_src(Δ, src_size, idx), NoTangent())
gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(Δ, src_size, idx), NoTangent())
y, gather!_pullback
end
2 changes: 1 addition & 1 deletion src/padding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ function rrule(::typeof(pad_constant), x::AbstractArray{T,N},
function pad_constant_pullback(Δ)
p = gen_pad(pad, dims, N)
outsize, center = size_and_center(x, p)
(NO_FIELDS, @thunk(Δ[center...]), NoTangent(), NoTangent(),)
(NoTangent(), @thunk(Δ[center...]), NoTangent(), NoTangent(),)
end
return y, pad_constant_pullback
end
Expand Down
4 changes: 2 additions & 2 deletions src/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ for pool in [:maxpool, :meanpool]
pullback = Symbol(pool, :_pullback)
@eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
Ω = $pool(x, pdims; kw...)
$pullback(Δ) = (NO_FIELDS, $∇pool(Δ, Ω, x, pdims; kw...), NoTangent())
$pullback(Δ) = (NoTangent(), $∇pool(Δ, Ω, x, pdims; kw...), NoTangent())
return Ω, $pullback
end
end
end
4 changes: 2 additions & 2 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,12 @@ end
function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
dst_old = copy(dst)
scatter!(op, dst, src, idx)
scatter!_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter!_dst(op, Δ, dst_old, dst), ∇scatter!_src(op, Δ, dst, src, idx), NoTangent())
scatter!_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter!_dst(op, Δ, dst_old, dst), ∇scatter!_src(op, Δ, dst, src, idx), NoTangent())
dst, scatter!_pullback
end

function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray)
y = scatter(op, src, idx)
scatter_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_src(op, Δ, y, src, idx), NoTangent())
scatter_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter_src(op, Δ, y, src, idx), NoTangent())
y, scatter_pullback
end
4 changes: 2 additions & 2 deletions src/softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

function rrule(::typeof(softmax), xs; dims=1)
y = softmax(xs; dims=dims)
softmax_pullback(Δ) = (NO_FIELDS, ∇softmax(Δ, xs, y, dims=dims))
softmax_pullback(Δ) = (NoTangent(), ∇softmax(Δ, xs, y, dims=dims))
return y, softmax_pullback
end

Expand Down Expand Up @@ -125,7 +125,7 @@ end

function rrule(::typeof(logsoftmax), xs; dims=1)
y = logsoftmax(xs; dims=dims)
logsoftmax_pullback(Δ) = (NO_FIELDS, ∇logsoftmax(Δ, xs, y, dims=dims))
logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(Δ, xs, y, dims=dims))
return y, logsoftmax_pullback
end

Expand Down
8 changes: 4 additions & 4 deletions src/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ end

function rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)
Ω = upsample_nearest(x, s)
upsample_nearest_pullback(Δ) = (NO_FIELDS, ∇upsample_nearest(Δ, s), NoTangent())
upsample_nearest_pullback(Δ) = (NoTangent(), ∇upsample_nearest(Δ, s), NoTangent())
return Ω, upsample_nearest_pullback
end

Expand Down Expand Up @@ -203,7 +203,7 @@ end
function rrule(::typeof(upsample_linear), x; size)
Ω = upsample_linear(x; size=size)
function upsample_linear_pullback(Δ)
(NO_FIELDS, ∇upsample_linear(Δ; size=Base.size(x,1)))
(NoTangent(), ∇upsample_linear(Δ; size=Base.size(x,1)))
end
return Ω, upsample_linear_pullback
end
Expand Down Expand Up @@ -368,7 +368,7 @@ end
function rrule(::typeof(upsample_bilinear), x; size)
Ω = upsample_bilinear(x; size=size)
function upsample_bilinear_pullback(Δ)
(NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))
(NoTangent(), ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))
end
return Ω, upsample_bilinear_pullback
end
Expand Down Expand Up @@ -518,7 +518,7 @@ end
function rrule(::typeof(upsample_trilinear), x; size)
Ω = upsample_trilinear(x; size=size)
function upsample_trilinear_pullback(Δ)
(NO_FIELDS, ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3))))
(NoTangent(), ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3))))
end
return Ω, upsample_trilinear_pullback
end
Expand Down

0 comments on commit 5c71689

Please sign in to comment.