Skip to content

Commit

Permalink
Embedding and autosize
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 17, 2022
1 parent fa9279c commit 86dc920
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,11 @@ is needed to make `@autosize (2,3,4) Dense(_ => 5)` return
"""
autosizefor(::Type, x::AbstractArray) = size(x, max(1, ndims(x)-1))
autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1)
autosizefor(::Type{<:Embedding}, x::AbstractArray) = size(x, 1)
autosizefor(::Type{<:LayerNorm}, x::AbstractArray) = size(x, 1)

autosizefor(::Type{<:Embedding}, x::AbstractArray) = error(
"@autosize Embeeding(_ => n) cannot work, as this _ is the size of the vocabulary, not an array size")

_replaceunderscore(e, s) = e === :_ ? s : e
_replaceunderscore(ex::Expr, s) = Expr(ex.head, map(a -> _replaceunderscore(a, s), ex.args)...)

Expand Down
10 changes: 5 additions & 5 deletions test/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,6 @@ end
m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)

@test_broken begin # outputsize fails on Embedding
m = @autosize (2, 3, 4, 5) Embedding(_ => 10) # goes by first dim, not 2nd-last
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)
end

m = @autosize (9,) Dense(_ => div(_,2))
@test randn(9) |> m |> size == (4,)

Expand Down Expand Up @@ -249,6 +244,11 @@ end
# https://github.com/FluxML/Flux.jl/issues/2086
m = @autosize (3, 1) Chain(; c = Dense(_ => 2, sigmoid), b = BatchNorm(_, affine=false))
@test randn(Float32, 3, 32) |> m |> size == (2, 32)

# Embedding takes a vocab size, not an array size
@test_throws ErrorException @autosize (2, 3) Embedding(_ => 10)
m = @autosize (3,) Chain(Embedding(26 => 10), Dense(_, 4))
@test rand(1:26, 3) |> m |> size == (4, 3)
end

@testset "LazyLayer" begin
Expand Down

0 comments on commit 86dc920

Please sign in to comment.