Skip to content

Commit

Permalink
Support conv_direct! on custom datatypes (#592)
Browse files Browse the repository at this point in the history
* Add test for unusual input datatypes

* Add fix to `conv_direct!`

* Only set y to zero if beta is false or zero

* Test output eltype
  • Loading branch information
adrhill authored Jun 19, 2024
1 parent 85b17cf commit d423576
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/impl/conv_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ function conv_direct!(
# Use `calc_padding_regions` to determine where we do or don't need to worry about padding
padded_regions, central_region = calc_padding_regions(cdims)

# Set outputs to zero to support custom datatypes (https://github.com/FluxML/NNlib.jl/issues/490)
if iszero(beta)
y = fill!(y, zero(yT))
end

# Start with the central region
w_region, h_region, d_region = central_region
@inbounds for batch in 1:size(x, 5),
Expand Down
39 changes: 39 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using NNlib, Test
using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier,
stride, padding, dilation, flipkernel, output_size,
groupcount
using Random: AbstractRNG, SamplerType

@testset "ConvDims" begin
for T in (DenseConvDims, DepthwiseConvDims)
Expand Down Expand Up @@ -865,6 +866,44 @@ end
@test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size
end

# https://github.com/FluxML/NNlib.jl/issues/490
# https://github.com/FluxML/NNlib.jl/issues/405
@testset "conv_direct! - Unusual input types" begin
    # Wrapper type whose array slots throw `UndefRefError` when read before
    # being written: a non-isbits struct. This simulates the worst-case
    # scenario for custom element types passed to `conv_direct!`.
    struct MyFloat <: Real
        set::Set{Float32}
    end

    # Sanity check: reading an undef slot of a non-isbits array must fail,
    # which is exactly what `conv_direct!` has to cope with.
    v = Array{MyFloat}(undef, 3)
    @test_throws UndefRefError v[1]

    # Minimal arithmetic surface required by `conv_direct!`.
    MyFloat(x::MyFloat) = x
    # `Set(Float32(x))` relies on numbers being iterable (single-element).
    MyFloat(x::Real) = MyFloat(Set(Float32(x)))

    Base.:+(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) + only(y.set))
    Base.:*(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) * only(y.set))
    Base.promote_rule(::Type{MyFloat}, ::Type{Float32}) = MyFloat
    # Forward the caller-supplied RNG so `rand(rng, MyFloat)` is reproducible
    # with a seeded generator; the previous definition discarded `rng` and
    # drew from the task-local global RNG instead.
    Base.rand(rng::AbstractRNG, ::SamplerType{MyFloat}) = MyFloat(rand(rng, Float32))
    Base.zero(::MyFloat) = MyFloat(zero(Float32))
    Base.zero(::Type{MyFloat}) = MyFloat(zero(Float32))

    # Exercise conv_direct! with an uninitialized (undef) output buffer:
    # this only succeeds if conv_direct! zero-fills `y` before accumulating.
    x_size = (6, 7, 8, 5, 3)
    y_size = (5, 6, 7, 4, 3)
    w_size = (2, 2, 2, 5, 4)
    x = rand(MyFloat, x_size);
    w = randn(Float32, w_size);
    y = Array{MyFloat}(undef, y_size...);
    cdims = DenseConvDims(x_size, w_size)
    y_out = NNlib.conv_direct!(y, x, w, cdims)

    @test eltype(y_out) == MyFloat
    @test size(y_out) == y_size
end

@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
Expand Down

0 comments on commit d423576

Please sign in to comment.