diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 9fff43364c..ee223c09b7 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,5 +1,5 @@ @testsetup module ConvSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib expand(_, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -43,19 +43,19 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, @test eltype(y) == promote_type(Tw, Tx) - @constinferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) if mode != "amdgpu" && activation !== anonact - @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any else try - @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any catch e e isa ErrorException || rethrow() - @constinferred_broken Zygote.gradient( + @test_broken @inferred(Zygote.gradient( sumabs2conv, activation, weight, x, bias, cdims - ) + )) end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 6e65b46547..9689a5ca87 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,5 +1,5 @@ @testsetup module DenseSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs anonact = x -> x^3 @@ -27,14 +27,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) - @constinferred fused_dense_bias_activation(activation, w, x, bias) + @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any @jet fused_dense_bias_activation(activation, w, x, bias) atol = 1.0f-3 rtol = 1.0f-3 if activation !== anonact - @constinferred Zygote.gradient(sumabs2dense, activation, w, x, bias) + @test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any end skip_backends = [] diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index d47c542d63..ea8cb02e2d 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -69,7 +69,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, end end - @constinferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa + Any @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} @@ -88,9 +89,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, end if anonact !== act - lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( - x, sc, b, rm, rv, tr, act, ϵ))) - @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, training, act, epsilon) + @test @inferred(Zygote.gradient( + sumabs2first, x, scale, bias, rm, rv, training, act, epsilon)) isa Any end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 891c68715b..6302bc6dd9 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs using LuxTestUtils: check_approx function setup_groupnorm(rng, aType, T, sz, affine) @@ -58,12 +58,12 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) @test ∂bias≈∂bias_simple atol=atol rtol=rtol end - @constinferred groupnorm(x, scale, bias, groups, act, epsilon) + @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any @jet groupnorm(x, scale, bias, groups, act, epsilon) if anonact !== act - lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) - @constinferred Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon) + @test @inferred(Zygote.gradient( + sumabs2groupnorm, x, scale, bias, groups, act, epsilon)) isa Any end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index cc8b1e81b6..a0e9e21308 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module InstanceNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib is_training(::Val{training}) where {training} = training @@ -24,12 +24,12 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) atol = 1.0f-2 rtol = 1.0f-2 - @constinferred instancenorm(x, scale, bias, training, act, epsilon) + @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any @jet instancenorm(x, scale, bias, training, act, epsilon) if anonact !== act && is_training(training) lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) - @constinferred Zygote.gradient(lfn, x, scale, bias, act, epsilon) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any end @test y isa aType{T, length(sz)} @@ -46,13 +46,14 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) - @constinferred instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) + @test @inferred(instancenorm( + x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) if anonact !== act && is_training(training) - lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm( - x, sc, b, rm, rv, Val(true), act, m, ϵ))) - @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon) + @test @inferred(Zygote.gradient( + sumabs2instancenorm, x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa + Any end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 940e95c06b..43b9896157 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -40,7 +40,7 @@ function run_layernorm_testing_core( epsilon = LuxLib.Utils.default_epsilon(T) _f = (args...) -> layernorm(args..., act, dims, epsilon) - @constinferred layernorm(x, scale, bias, act, dims, epsilon) + @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any @jet layernorm(x, scale, bias, act, dims, epsilon) y = _f(x, scale, bias) @@ -60,8 +60,8 @@ function run_layernorm_testing_core( soft_fail=[AutoFiniteDiff()]) if anonact !== act - lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) - @constinferred Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon) + @test @inferred(Zygote.gradient( + sumabs2layernorm, x, scale, bias, act, dims, epsilon)) isa Any end end