Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Apr 5, 2023
1 parent be56603 commit 761cdf6
Showing 1 changed file with 70 additions and 40 deletions.
110 changes: 70 additions & 40 deletions test/triples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ using Random
using Zygote

const backends = [
StochasticAD.PrunedFIsBackend(),
StochasticAD.PrunedFIsAggressiveBackend(),
StochasticAD.DictFIsBackend(),
PrunedFIsBackend(),
PrunedFIsAggressiveBackend(),
DictFIsBackend(),
]

const backends_smoothed = [StochasticAD.SmoothedFIsBackend()]

@testset "Distributions w.r.t. continuous parameter" begin for backend in vcat(backends,
backends_smoothed,
:smoothing_autodiff)
MAX = 10000
nsamples = 100000
Expand All @@ -39,7 +42,7 @@ const backends = [
if backend isa DictFIsBackend
# Only test dictionary backend on Bernoulli to speed things up. Should still cover interface.
test_cases = test_cases[1:1]
elseif backend == :smoothing_autodiff
elseif backend == :smoothing_autodiff || backend in backends_smoothed
# Only test smoothing backend on each unique distribution once to seed tests up.
test_cases = vcat(test_cases[1:4], test_cases[7])
# Only test unbiasedness of smoothing for linear function
Expand Down Expand Up @@ -104,7 +107,7 @@ end
@testset "Boolean comparisons" begin for backend in backends
tested = falses(2)
while !(all(tested))
st = stochastic_triple(rand Bernoulli, 0.5; backend = backend)
st = stochastic_triple(rand Bernoulli, 0.5; backend)
x = StochasticAD.value(st)
if x == 0
# Ensure errors on unsafe/unsupported boolean comparisons
Expand All @@ -130,7 +133,8 @@ end end
return arr[index]
end
array_index_mean(p) = p / 2 * 3.5 + p / 2 * 5.2 + (1 - p) * 8.4
triple_array_index_deriv = mean(derivative_estimate(array_index, p) for i in 1:100000)
triple_array_index_deriv = mean(derivative_estimate(array_index, p; backend)
for i in 1:10000)
exact_array_index_deriv = ForwardDiff.derivative(array_index_mean, p)
@test isapprox(triple_array_index_deriv, exact_array_index_deriv, rtol = 5e-2)
# Test indexing into array of stochastic triples with stochastic triple index
Expand All @@ -140,7 +144,8 @@ end end
return arr[index]
end
array_index2_mean(p) = p / 2 * 3.5p + p / 2 * 5.2p + (1 - p) * 8.4p
triple_array_index2_deriv = mean(derivative_estimate(array_index2, p) for i in 1:100000)
triple_array_index2_deriv = mean(derivative_estimate(array_index2, p; backend)
for i in 1:10000)
exact_array_index2_deriv = ForwardDiff.derivative(array_index2_mean, p)
@test isapprox(triple_array_index2_deriv, exact_array_index2_deriv, rtol = 5e-2)
# Test case where triple and alternate array value are coupled
Expand All @@ -150,7 +155,8 @@ end end
return arr[st + 1]
end
array_index3_mean(p) = -5 * (1 - p) + 1 * p
triple_array_index3_deriv = mean(derivative_estimate(array_index3, p) for i in 1:100000)
triple_array_index3_deriv = mean(derivative_estimate(array_index3, p; backend)
for i in 1:10000)
exact_array_index3_deriv = ForwardDiff.derivative(array_index3_mean, p)
@test isapprox(triple_array_index3_deriv, exact_array_index3_deriv, rtol = 5e-2)
end
Expand Down Expand Up @@ -284,7 +290,10 @@ end
@test_throws ArgumentError convert(typeof(st1), st2)
end

@testset "Finite perturbation backend interface" begin for backend in backends
@testset "Finite perturbation backend interface" begin for backend in vcat(backends,
backends_smoothed)
# this boolean may need to become more fine-grained in the future
is_smoothed_backend = backend in backends_smoothed
#=
Test the backend interface across the finite perturbation backends,
which is currently a bit implicitly defined.
Expand All @@ -303,7 +312,7 @@ end
for (Δs, V) in ((Δs0, V0), (Δs1, V0), (Δs2, V0), (Δs3, V1))
@test StochasticAD.valtype(Δs) === V
@test Δs isa StochasticAD.similar_type(FIs, V)
@test isempty(Δs)
!is_smoothed_backend && @test isempty(Δs)
@test iszero(derivative_contribution(Δs))
end
# Test creation of a single perturbation
Expand All @@ -312,24 +321,28 @@ end
Δs1 = StochasticAD.similar_new(Δs0, Δ, 3.0)
@test StochasticAD.valtype(Δs1) === typeof(Δ)
@test Δs1 isa StochasticAD.similar_type(FIs, typeof(Δ))
@test !isempty(Δs1)
!is_smoothed_backend && @test !isempty(Δs1)
@test derivative_contribution(Δs1) == 3Δ
# Test StochasticAD.alltrue
@test StochasticAD.alltrue(map(-> true, Δs1))
@test !StochasticAD.alltrue(map(-> false, Δs1))
@test StochasticAD.alltrue(_Δ -> true, Δs1)
@test !StochasticAD.alltrue(_Δ -> false, Δs1) || is_smoothed_backend
# Test map
Δs1_map = Base.map-> Δ^2, Δs1)
@test derivative_contribution(Δs1_map) Δ^2 * 3.0
# We use a dummy deriv here and below. TODO: use a more interesting dummy for better testing.
Δs1_map = Base.map-> Δ^2, Δs1; deriv = identity, out_rep = Δ)
!is_smoothed_backend && @test derivative_contribution(Δs1_map) Δ^2 * 3.0
# Test map_Δs with filter state
Δs1_plus_Δs0 = StochasticAD.map_Δs((Δ, state) -> Δ + StochasticAD.filter_state(Δs0,
state),
Δs1)
@test derivative_contribution(Δs1_plus_Δs0) Δ * 3.0
Δs1_plus_mapped = StochasticAD.map_Δs((Δ, state) -> Δ +
StochasticAD.filter_state(Δs1,
state),
Δs1_map)
@test derivative_contribution(Δs1_plus_mapped) Δ * 3.0 + Δ^2 * 3.0
if !is_smoothed_backend
Δs1_plus_Δs0 = StochasticAD.map_Δs((Δ, state) -> Δ +
StochasticAD.filter_state(Δs0,
state),
Δs1)
@test derivative_contribution(Δs1_plus_Δs0) Δ * 3.0
Δs1_plus_mapped = StochasticAD.map_Δs((Δ, state) -> Δ +
StochasticAD.filter_state(Δs1,
state),
Δs1_map)
@test derivative_contribution(Δs1_plus_mapped) Δ * 3.0 + Δ^2 * 3.0
end
end
# Test coupling
Δ_coupleds = (3, [4.0, 5.0], (2, [3.0, 4.0]))
Expand All @@ -340,19 +353,20 @@ end
Δs2 = StochasticAD.similar_new(Δs0, 1, 2.0) # perturbation 2
# A group of perturbations that all stem from perturbation 1.
Δs_all1 = StochasticAD.structural_map(Δ_coupled) do Δ
Base.map(_Δ -> Δ, Δs1)
Base.map(_Δ -> Δ, Δs1; deriv = identity, out_rep = Δ)
end
# A group of perturbations that all stem from perturbation 2.
Δs_all2 = StochasticAD.structural_map(Δ_coupled) do Δ
Base.map(_Δ -> 2 * Δ, Δs2)
Base.map(_Δ -> 2 * Δ, Δs2; deriv =-> 2δ), out_rep = Δ)
end
# Join them into a single structure that should be coupled
Δs_all = (Δs_all1, Δs_all2)
kwargs = use_get_rep ? (; rep = StochasticAD.get_rep(FIs, Δs_all)) : (;)
if do_combine
return StochasticAD.combine(FIs, Δs_all; kwargs...)
else
return StochasticAD.couple(FIs, Δs_all; kwargs...)
return StochasticAD.couple(FIs, Δs_all; out_rep = (Δ_coupled, Δ_coupled),
kwargs...)
end
end
#=
Expand All @@ -374,7 +388,8 @@ end
true))
function get_contribution()
Δs_coupled = get_Δs_coupled(; use_get_rep)
Δs_coupled_mapped = map(mapfunc, Δs_coupled)
Δs_coupled_mapped = map(mapfunc, Δs_coupled; deriv =-> 1.0),
out_rep = 0.0)
return derivative_contribution(Δs_coupled_mapped)
end
zero_Δ_coupled = StochasticAD.structural_map(zero, Δ_coupled)
Expand All @@ -383,25 +398,34 @@ end
StochasticAD.structural_map(x -> 2x,
Δ_coupled)))
expected_contribution = expected_contribution1 + expected_contribution2
@test isapprox(mean(get_contribution() for i in 1:1000),
expected_contribution; rtol = 5e-2)
if !is_smoothed_backend
@test isapprox(mean(get_contribution() for i in 1:1000),
expected_contribution; rtol = 5e-2)
end
# For a simple sum, this should be equivalent to the combine behaviour.
if check_combine
if check_combine && !is_smoothed_backend
@test isapprox(mean(derivative_contribution(get_Δs_coupled(;
do_combine = true))
for i in 1:1000), expected_contribution;
rtol = 5e-2)
end
# Check scalarize
Δs_coupled2 = StochasticAD.couple(FIs, StochasticAD.scalarize(Δs_coupled))
@test derivative_contribution(map(mapfunc, Δs_coupled))
derivative_contribution(map(mapfunc, Δs_coupled2))
Δs_coupled2 = StochasticAD.couple(FIs,
StochasticAD.scalarize(Δs_coupled;
out_rep = (Δ_coupled,
Δ_coupled)),
out_rep = (Δ_coupled, Δ_coupled))
@test derivative_contribution(map(mapfunc, Δs_coupled; deriv =-> 1.0),
out_rep = 0.0))
derivative_contribution(map(mapfunc, Δs_coupled2; deriv =-> 1.0),
out_rep = 0.0))
end
end
end
end end

@testset "Getting information about stochastic triples" begin for backend in backends
@testset "Getting information about stochastic triples" begin for backend in vcat(backends,
backends_smoothed)
Random.seed!(4321)
f(x) = rand(Bernoulli(x)) + x
st = stochastic_triple(f, 0.5; backend)
Expand All @@ -416,11 +440,17 @@ end end
@test StochasticAD.delta(st) == 1.0
@test StochasticAD.delta(dual) == 1.0

#=
NB: since the implementation of perturbations can be backend-specific, the
below property need not hold in general, but does for the current backends.
=#
@test collect(perturbations(st)) == [(1, 2.0)]
if !(backend in backends_smoothed)
#=
NB: since the implementation of perturbations can be backend-specific, the
below property need not hold in general, but does for the current non-smoothed backends.
=#
@test collect(perturbations(st)) == [(1, 2.0)]
@test derivative_contribution(st) == 3.0
else
# Since smoothed algorithm uses the two-sided strategy, we get a different derivative contribution.
@test derivative_contribution(st) == 2.0
end

@test StochasticAD.tag(st) === StochasticAD.Tag{typeof(f), Float64}
@test StochasticAD.valtype(st) === Float64
Expand Down

0 comments on commit 761cdf6

Please sign in to comment.