Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve errors for conv layers #2404

Merged
merged 3 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function _size_check(layer, x::AbstractArray, (d, n)::Pair)
d > 0 || throw(DimensionMismatch(string("layer ", layer,
" expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x))))
size(x, d) == n || throw(DimensionMismatch(string("layer ", layer,
" expects size(input, $d) == $n, but got ", summary(x))))
lazy" expects size(input, $d) == $n, but got ", summary(x))))
end
ChainRulesCore.@non_differentiable _size_check(::Any...)

Expand Down
27 changes: 24 additions & 3 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ conv_dims(c::Conv, x::AbstractArray) =
ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)

function (c::Conv)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_dims(c, x)
xT = _match_eltype(c, x)
Expand Down Expand Up @@ -331,7 +331,7 @@ end
ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
xT = _match_eltype(c, x)
Expand Down Expand Up @@ -473,7 +473,7 @@ crosscor_dims(c::CrossCor, x::AbstractArray) =
ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)

function (c::CrossCor)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = crosscor_dims(c, x)
xT = _match_eltype(c, x)
Expand All @@ -487,6 +487,15 @@ function Base.show(io::IO, l::CrossCor)
print(io, ")")
end

# Validate the input of a convolution-style layer before it is applied.
#
# Throws `DimensionMismatch` when `x` does not have as many dimensions as
# `layer.weight`, or when its second-to-last dimension differs from
# `_channels_in(layer)`. Returns `true` when both checks pass.
function _conv_size_check(layer, x::AbstractArray)
    wrank = ndims(layer.weight)
    if ndims(x) != wrank
        throw(DimensionMismatch(LazyString("layer ", layer,
            " expects ndims(input) == ", wrank, ", but got ", summary(x))))
    end
    # `d` and `n` keep their names because the lazy string below interpolates them.
    d = ndims(x) - 1
    n = _channels_in(layer)
    if size(x, d) != n
        throw(DimensionMismatch(LazyString("layer ", layer,
            lazy" expects size(input, $d) == $n, but got ", summary(x))))
    end
    return true
end
# Declare the two-argument size check non-differentiable for ChainRulesCore.
ChainRulesCore.@non_differentiable _conv_size_check(::Any, ::Any)
"""
AdaptiveMaxPool(out::NTuple)

Expand Down Expand Up @@ -515,6 +524,7 @@ struct AdaptiveMaxPool{S, O}
end

function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}) where {S, T}
_pool_size_check(a, a.out, x)
insize = size(x)[1:end-2]
outsize = a.out
stride = insize .÷ outsize
Expand Down Expand Up @@ -556,6 +566,7 @@ struct AdaptiveMeanPool{S, O}
end

function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}) where {S, T}
_pool_size_check(a, a.out, x)
insize = size(x)[1:end-2]
outsize = a.out
stride = insize .÷ outsize
Expand Down Expand Up @@ -694,6 +705,7 @@ function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
end

# Apply the max-pooling layer `m` to `x`.
# The input rank is validated against the window `m.k` first, then the pooling
# geometry is built from the layer's window, padding and stride.
function (m::MaxPool)(x)
    _pool_size_check(m, m.k, x)
    return maxpool(x, PoolDims(x, m.k; padding=m.pad, stride=m.stride))
end
Expand Down Expand Up @@ -753,6 +765,7 @@ function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
end

# Apply the mean-pooling layer `m` to `x`.
# The input rank is validated against the window `m.k` first, then the pooling
# geometry is built from the layer's window, padding and stride.
function (m::MeanPool)(x)
    _pool_size_check(m, m.k, x)
    return meanpool(x, PoolDims(x, m.k; padding=m.pad, stride=m.stride))
end
Expand All @@ -763,3 +776,11 @@ function Base.show(io::IO, m::MeanPool)
m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
print(io, ")")
end

# Check that `x` has the rank a pooling layer expects: the length of `tup`
# plus two further dimensions.
#
# Throws `DimensionMismatch` naming `layer` and `summary(x)` when the rank
# differs; returns `true` otherwise.
function _pool_size_check(layer, tup::Tuple, x::AbstractArray)
    expected = length(tup) + 2
    if ndims(x) != expected
        throw(DimensionMismatch(LazyString("layer ", layer,
            " expects ndims(input) == ", expected, ", but got ", summary(x))))
    end
    return true
end
# `_pool_size_check` takes three arguments (layer, tup, x), so the rule must be
# declared for three; a two-argument rule never matches an actual call and the
# check would not be marked non-differentiable.
ChainRulesCore.@non_differentiable _pool_size_check(::Any, ::Any, ::Any)

Loading