From d423576c1639ad57fd179c42c6d3f9eb14f8e157 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Wed, 19 Jun 2024 21:41:46 +0200 Subject: [PATCH] Support `conv_direct!` on custom datatypes (#592) * Add test for unusual input datatypes * Add fix to `conv_direct!` * Only set y to zero if beta is false or zero * Test output eltype --- src/impl/conv_direct.jl | 5 +++++ test/conv.jl | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 9f12f1dc9..497f2e929 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -81,6 +81,11 @@ function conv_direct!( # Use `calc_padding_regions` to determine where we do or don't need to worry about padding padded_regions, central_region = calc_padding_regions(cdims) + # Set outputs to zero to support custom datatypes (https://github.com/FluxML/NNlib.jl/issues/490) + if iszero(beta) + y = fill!(y, zero(yT)) + end + # Start with the central region w_region, h_region, d_region = central_region @inbounds for batch in 1:size(x, 5), diff --git a/test/conv.jl b/test/conv.jl index dce01771a..492de2cc7 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -2,6 +2,7 @@ using NNlib, Test using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier, stride, padding, dilation, flipkernel, output_size, groupcount +using Random: AbstractRNG, SamplerType @testset "ConvDims" begin for T in (DenseConvDims, DepthwiseConvDims) @@ -865,6 +866,44 @@ end @test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size end +# https://github.com/FluxML/NNlib.jl/issues/490 +# https://github.com/FluxML/NNlib.jl/issues/405 +@testset "conv_direct! - Unusual input types" begin + # Create test type that can't be indexed when undefined. + # This simulates the worst-case scenario for custom types. + struct MyFloat <: Real + set::Set{Float32} + end + + # Test that direct indexing fails when undefined. + v = Array{MyFloat}(undef, 3) + @test_throws UndefRefError v[1] + + # Define minimal set of functions required for conv_direct! + MyFloat(x::MyFloat) = x + MyFloat(x::Real) = MyFloat(Set(Float32(x))) + + Base.:+(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) + only(y.set)) + Base.:*(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) * only(y.set)) + Base.promote_rule(::Type{MyFloat}, ::Type{Float32}) = MyFloat + Base.rand(::AbstractRNG, ::SamplerType{MyFloat}) = MyFloat(rand(Float32)) + Base.zero(::MyFloat) = MyFloat(zero(Float32)) + Base.zero(::Type{MyFloat}) = MyFloat(zero(Float32)) + + # Test conv_direct! + x_size = (6, 7, 8, 5, 3) + y_size = (5, 6, 7, 4, 3) + w_size = (2, 2, 2, 5, 4) + x = rand(MyFloat, x_size); + w = randn(Float32, w_size); + y = Array{MyFloat}(undef, y_size...); + cdims = DenseConvDims(x_size, w_size) + y_out = NNlib.conv_direct!(y, x, w, cdims) + + @test eltype(y_out) == MyFloat + @test size(y_out) == y_size +end + @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3)