Skip to content

Commit

Permalink
Improve errors for conv layers (#2404)
Browse files Browse the repository at this point in the history
* better size check for conv layers

* similar for pooling layers

* change to DimensionMismatch
  • Loading branch information
mcabbott authored Mar 21, 2024
1 parent edc1d8c commit c4a0ee4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function _size_check(layer, x::AbstractArray, (d, n)::Pair)
d > 0 || throw(DimensionMismatch(string("layer ", layer,
" expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x))))
size(x, d) == n || throw(DimensionMismatch(string("layer ", layer,
" expects size(input, $d) == $n, but got ", summary(x))))
lazy" expects size(input, $d) == $n, but got ", summary(x))))
end
ChainRulesCore.@non_differentiable _size_check(::Any...)

Expand Down
27 changes: 24 additions & 3 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ conv_dims(c::Conv, x::AbstractArray) =
ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)

function (c::Conv)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_dims(c, x)
xT = _match_eltype(c, x)
Expand Down Expand Up @@ -331,7 +331,7 @@ end
ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
xT = _match_eltype(c, x)
Expand Down Expand Up @@ -473,7 +473,7 @@ crosscor_dims(c::CrossCor, x::AbstractArray) =
ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)

function (c::CrossCor)(x::AbstractArray)
_size_check(c, x, ndims(x)-1 => _channels_in(c))
_conv_size_check(c, x)
σ = NNlib.fast_act(c.σ, x)
cdims = crosscor_dims(c, x)
xT = _match_eltype(c, x)
Expand All @@ -487,6 +487,15 @@ function Base.show(io::IO, l::CrossCor)
print(io, ")")
end

function _conv_size_check(layer, x::AbstractArray)
ndims(x) == ndims(layer.weight) || throw(DimensionMismatch(LazyString("layer ", layer,
" expects ndims(input) == ", ndims(layer.weight), ", but got ", summary(x))))
d = ndims(x)-1
n = _channels_in(layer)
size(x,d) == n || throw(DimensionMismatch(LazyString("layer ", layer,
lazy" expects size(input, $d) == $n, but got ", summary(x))))
end
ChainRulesCore.@non_differentiable _conv_size_check(::Any, ::Any)
"""
AdaptiveMaxPool(out::NTuple)
Expand Down Expand Up @@ -515,6 +524,7 @@ struct AdaptiveMaxPool{S, O}
end

function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}) where {S, T}
_pool_size_check(a, a.out, x)
insize = size(x)[1:end-2]
outsize = a.out
stride = insize outsize
Expand Down Expand Up @@ -556,6 +566,7 @@ struct AdaptiveMeanPool{S, O}
end

function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}) where {S, T}
_pool_size_check(a, a.out, x)
insize = size(x)[1:end-2]
outsize = a.out
stride = insize outsize
Expand Down Expand Up @@ -694,6 +705,7 @@ function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
end

function (m::MaxPool)(x)
_pool_size_check(m, m.k, x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return maxpool(x, pdims)
end
Expand Down Expand Up @@ -753,6 +765,7 @@ function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
end

function (m::MeanPool)(x)
_pool_size_check(m, m.k, x)
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
return meanpool(x, pdims)
end
Expand All @@ -763,3 +776,11 @@ function Base.show(io::IO, m::MeanPool)
m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
print(io, ")")
end

function _pool_size_check(layer, tup::Tuple, x::AbstractArray)
N = length(tup) + 2
ndims(x) == N || throw(DimensionMismatch(LazyString("layer ", layer,
" expects ndims(input) == ", N, ", but got ", summary(x))))
end
ChainRulesCore.@non_differentiable _pool_size_check(::Any, ::Any)

0 comments on commit c4a0ee4

Please sign in to comment.