diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ac85827a41..ef81c30872 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -193,7 +193,7 @@ function _size_check(layer, x::AbstractArray, (d, n)::Pair) d > 0 || throw(DimensionMismatch(string("layer ", layer, " expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x)))) size(x, d) == n || throw(DimensionMismatch(string("layer ", layer, - " expects size(input, $d) == $n, but got ", summary(x)))) + lazy" expects size(input, $d) == $n, but got ", summary(x)))) end ChainRulesCore.@non_differentiable _size_check(::Any...) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4e6044dcfb..fdf3c756e9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -195,7 +195,7 @@ conv_dims(c::Conv, x::AbstractArray) = ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any) function (c::Conv)(x::AbstractArray) - _size_check(c, x, ndims(x)-1 => _channels_in(c)) + _conv_size_check(c, x) σ = NNlib.fast_act(c.σ, x) cdims = conv_dims(c, x) xT = _match_eltype(c, x) @@ -331,7 +331,7 @@ end ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any) function (c::ConvTranspose)(x::AbstractArray) - _size_check(c, x, ndims(x)-1 => _channels_in(c)) + _conv_size_check(c, x) σ = NNlib.fast_act(c.σ, x) cdims = conv_transpose_dims(c, x) xT = _match_eltype(c, x) @@ -473,7 +473,7 @@ crosscor_dims(c::CrossCor, x::AbstractArray) = ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any) function (c::CrossCor)(x::AbstractArray) - _size_check(c, x, ndims(x)-1 => _channels_in(c)) + _conv_size_check(c, x) σ = NNlib.fast_act(c.σ, x) cdims = crosscor_dims(c, x) xT = _match_eltype(c, x) @@ -487,6 +487,15 @@ function Base.show(io::IO, l::CrossCor) print(io, ")") end +function _conv_size_check(layer, x::AbstractArray) + ndims(x) == ndims(layer.weight) || throw(DimensionMismatch(LazyString("layer ", layer, + " expects ndims(input) == ", ndims(layer.weight), ", but got ", summary(x)))) + d = ndims(x)-1 + n = _channels_in(layer) + size(x,d) == n || throw(DimensionMismatch(LazyString("layer ", layer, + lazy" expects size(input, $d) == $n, but got ", summary(x)))) +end +ChainRulesCore.@non_differentiable _conv_size_check(::Any, ::Any) """ AdaptiveMaxPool(out::NTuple) @@ -515,6 +524,7 @@ struct AdaptiveMaxPool{S, O} end function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}) where {S, T} + _pool_size_check(a, a.out, x) insize = size(x)[1:end-2] outsize = a.out stride = insize .÷ outsize @@ -556,6 +566,7 @@ struct AdaptiveMeanPool{S, O} end function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}) where {S, T} + _pool_size_check(a, a.out, x) insize = size(x)[1:end-2] outsize = a.out stride = insize .÷ outsize @@ -694,6 +705,7 @@ function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N end function (m::MaxPool)(x) + _pool_size_check(m, m.k, x) pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return maxpool(x, pdims) end @@ -753,6 +765,7 @@ function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N end function (m::MeanPool)(x) + _pool_size_check(m, m.k, x) pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return meanpool(x, pdims) end @@ -763,3 +776,11 @@ function Base.show(io::IO, m::MeanPool) m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) print(io, ")") end + +function _pool_size_check(layer, tup::Tuple, x::AbstractArray) + N = length(tup) + 2 + ndims(x) == N || throw(DimensionMismatch(LazyString("layer ", layer, + " expects ndims(input) == ", N, ", but got ", summary(x)))) +end +ChainRulesCore.@non_differentiable _pool_size_check(::Any, ::Any) +