Skip to content

Commit

Permalink
test: more recurrent testing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 23, 2024
1 parent f135943 commit 04ef36a
Showing 1 changed file with 147 additions and 100 deletions.
247 changes: 147 additions & 100 deletions test/layers/recurrent_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

using MLDataDevices

MLDataDevices.get_device_type(::Function) = Nothing # FIXME: upstream maybe?
MLDataDevices.get_device_type(_) = Nothing # FIXME: upstream maybe?
MLDataDevices.Internal.get_device_type(::Function) = Nothing # FIXME: upstream maybe?
MLDataDevices.Internal.get_device_type(_) = Nothing # FIXME: upstream maybe?

function loss_loop(cell, x, p, st)
(y, carry), st_ = cell(x, p, st)
Expand Down Expand Up @@ -43,9 +43,9 @@ end
@jet rnncell((x, carry), ps, st)

if train_state
@test hasproperty(ps, :train_state)
@test hasproperty(ps, :hidden_state)
else
@test !hasproperty(ps, :train_state)
@test !hasproperty(ps, :hidden_state)
end

@test_gradients(loss_loop, rnncell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
Expand Down Expand Up @@ -95,8 +95,8 @@ end
@jet lstmcell(x, ps, st)
@jet lstmcell((x, carry), ps, st)

@test !hasproperty(ps, :train_state)
@test !hasproperty(ps, :train_memory)
@test !hasproperty(ps, :hidden_state)
@test !hasproperty(ps, :memory)

@test_gradients(loss_loop, lstmcell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
end
Expand Down Expand Up @@ -198,7 +198,7 @@ end
@jet grucell(x, ps, st)
@jet grucell((x, carry), ps, st)

@test !hasproperty(ps, :train_state)
@test !hasproperty(ps, :hidden_state)

@test_gradients(loss_loop, grucell, x, ps, st; atol=1e-3, rtol=1e-3)
end
Expand Down Expand Up @@ -276,94 +276,138 @@ end
st__ = Lux.update_state(st, :carry, nothing)
@test st__.carry === nothing

@test_gradients(loss_loop_no_carry, rnn, x, ps, st; atol=1e-3, rtol=1e-3)
@test_gradients(loss_loop_no_carry, rnn, x, ps, st; atol=1e-3, rtol=1e-3,
soft_fail=[AutoFiniteDiff()])
end
end
end
end

@testitem "Recurrence" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
@testsetup module RecurrenceTestSetup

using LuxTestUtils, StableRNGs, Test, Lux

function test_recurrence_layer(
mode, aType, dev, ongpu, ordering, _cell, use_bias, train_state)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset for ordering in (BatchLastIndex(), TimeLastIndex())
@testset for _cell in (RNNCell, LSTMCell, GRUCell)
@testset for use_bias in (true, false), train_state in (true, false)
cell = _cell(3 => 5; use_bias, train_state)
rnn = Recurrence(cell; ordering)
rnn_seq = Recurrence(cell; ordering, return_sequence=true)
display(rnn)

# Batched Time Series
@testset "typeof(x): $(typeof(x))" for x in (
randn(rng, Float32, 3, 4, 2) |> aType,
Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType,
[randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType)
# Fix data ordering for testing
if ordering isa TimeLastIndex && x isa AbstractArray && ndims(x) 2
x = permutedims(x,
(ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1))
end

ps, st = Lux.setup(rng, rnn) |> dev
y, st_ = rnn(x, ps, st)
y_, st__ = rnn_seq(x, ps, st)

@jet rnn(x, ps, st)
@jet rnn_seq(x, ps, st)

@test size(y) == (5, 2)
@test length(y_) == 4
@test all(x -> size(x) == (5, 2), y_)

__f = p -> sum(first(rnn(x, p, st)))
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])

__f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st)))
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])
end

# Batched Time Series without data batches
@testset "typeof(x): $(typeof(x))" for x in (
randn(rng, Float32, 3, 4) |> aType,
Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType,
[randn(rng, Float32, 3) for _ in 1:4] .|> aType)
ps, st = Lux.setup(rng, rnn) |> dev
y, st_ = rnn(x, ps, st)
y_, st__ = rnn_seq(x, ps, st)

@jet rnn(x, ps, st)
@jet rnn_seq(x, ps, st)

@test size(y) == (5,)
@test length(y_) == 4
@test all(x -> size(x) == (5,), y_)

if x isa AbstractMatrix && ordering isa BatchLastIndex
x2 = reshape(x, Val(3))

y2, _ = rnn(x2, ps, st)
@test y == vec(y2)

y2_, _ = rnn_seq(x2, ps, st)
@test all(x -> x[1] == vec(x[2]), zip(y_, y2_))
end

__f = p -> sum(first(rnn(x, p, st)))
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])

__f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st)))
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])
end
end
end
cell = _cell(3 => 5; use_bias, train_state)
rnn = Recurrence(cell; ordering)
display(rnn)
rnn_seq = Recurrence(cell; ordering, return_sequence=true)
display(rnn_seq)

# Batched Time Series
@testset "typeof(x): $(typeof(x))" for x in (
randn(rng, Float32, 3, 4, 2) |> aType,
Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType,
[randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType)
# Fix data ordering for testing
if ordering isa TimeLastIndex && x isa AbstractArray && ndims(x) 2
x = permutedims(x,
(ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1))
end

ps, st = Lux.setup(rng, rnn) |> dev
y, st_ = rnn(x, ps, st)
y_, st__ = rnn_seq(x, ps, st)

@test size(y) == (5, 2)
@test length(y_) == 4
@test all(x -> size(x) == (5, 2), y_)

__f = ps -> sum(abs2, first(rnn(x, ps, st)))
@test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])

__f = ps -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, ps, st)))
@test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
end

# Batched Time Series without data batches
@testset "typeof(x): $(typeof(x))" for x in (
randn(rng, Float32, 3, 4) |> aType,
Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType,
[randn(rng, Float32, 3) for _ in 1:4] .|> aType)
ps, st = Lux.setup(rng, rnn) |> dev
y, st_ = rnn(x, ps, st)
y_, st__ = rnn_seq(x, ps, st)

@test size(y) == (5,)
@test length(y_) == 4
@test all(x -> size(x) == (5,), y_)

if x isa AbstractMatrix && ordering isa BatchLastIndex
x2 = reshape(x, Val(3))
y2, _ = rnn(x2, ps, st)
@test y == vec(y2)
y2_, _ = rnn_seq(x2, ps, st)
@test all(x -> x[1] == vec(x[2]), zip(y_, y2_))
end

__f = ps -> sum(abs2, first(rnn(x, ps, st)))
@test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])

__f = ps -> sum(Base.Fix1(sum, abs2), first(rnn(x, ps, st)))
@test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
end
end

const ALL_TEST_CONFIGS = Iterators.product(
(BatchLastIndex(), TimeLastIndex()),
(RNNCell, LSTMCell, GRUCell),
(true, false),
(true, false))

const TEST_BLOCKS = collect(Iterators.partition(
ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 4)))

export TEST_BLOCKS, test_recurrence_layer

end

@testitem "Recurrence: Group 1" setup=[
RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
@testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[1]
test_recurrence_layer(
mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
end
end
end

@testitem "Recurrence: Group 2" setup=[
RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
@testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[2]
test_recurrence_layer(
mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
end
end
end

@testitem "Recurrence: Group 3" setup=[
RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
@testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[3]
test_recurrence_layer(
mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
end
end
end

@testitem "Recurrence: Group 4" setup=[
RecurrenceTestSetup, SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
@testset for (ordering, cell, use_bias, train_state) in TEST_BLOCKS[4]
test_recurrence_layer(
mode, aType, dev, ongpu, ordering, cell, use_bias, train_state)
end
end
end

# Ordering Check: https://github.com/LuxDL/Lux.jl/issues/302
@testitem "Recurrence Ordering Check #302" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
rng = StableRNG(12345)
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
encoder = Recurrence(
RNNCell(1 => 1, identity;
init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...),
Expand All @@ -378,7 +422,7 @@ end
end
end

@testitem "Bidirectional" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
@testitem "Bidirectional" setup=[SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
Expand All @@ -405,17 +449,18 @@ end
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()])
__f = (bi_rnn, x, ps, st) -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, ps, st)))
@test_gradients(__f, bi_rnn, x, ps, st; atol=1e-3, rtol=1e-3,
broken_backends=[AutoEnzyme()])

__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
__f = (bi_rnn_no_merge, x, ps, st) -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, ps, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()])
@test_gradients(__f, bi_rnn_no_merge, x, ps, st; atol=1e-3,
rtol=1e-3, broken_backends=[AutoEnzyme()])

@testset "backward_cell: $_backward_cell" for _backward_cell in (
RNNCell, LSTMCell, GRUCell)
@testset for _backward_cell in (RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
backward_cell = _backward_cell(3 => 5)
bi_rnn = BidirectionalRNN(cell, backward_cell)
Expand All @@ -439,16 +484,18 @@ end
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
__f = (bi_rnn, x, ps, st) -> sum(
Base.Fix1(sum, abs2), first(bi_rnn(x, ps, st)))
@test_gradients(__f, bi_rnn, x, ps, st; atol=1e-3,
rtol=1e-3,
broken_backends=[AutoEnzyme()])

__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
__f = (bi_rnn_no_merge, x, ps, st) -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, ps, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
broken_backends=[AutoEnzyme()])
@test_gradients(__f, bi_rnn_no_merge, x, ps, st; atol=1e-3,
rtol=1e-3, broken_backends=[AutoEnzyme()])
end
end
end
Expand Down

0 comments on commit 04ef36a

Please sign in to comment.