Skip to content

Commit

Permalink
Merge pull request #81 from gaurav-arya/ag-smoothing
Browse files Browse the repository at this point in the history
Add stochastic triple smoothing backend
  • Loading branch information
gaurav-arya authored Apr 6, 2023
2 parents 8a67258 + 1a85493 commit 1c4395b
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 125 deletions.
2 changes: 2 additions & 0 deletions benchmark/game_of_life.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ suite["PrunedFIs"] = @benchmarkable derivative_estimate($play, $p;
backend = PrunedFIsBackend())
suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($play, $p;
backend = PrunedFIsAggressiveBackend())
suite["SmoothedFIs"] = @benchmarkable derivative_estimate($play, $p;
backend = SmoothedFIsBackend())

end
2 changes: 2 additions & 0 deletions benchmark/random_walk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ suite["PrunedFIs"] = @benchmarkable derivative_estimate($fX, $p;
backend = PrunedFIsBackend())
suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($fX, $p;
backend = PrunedFIsAggressiveBackend())
suite["SmoothedFIs"] = @benchmarkable derivative_estimate($fX, $p;
backend = SmoothedFIsBackend())
forwarddiff_func = p -> fX(p; hardcode_leftright_step = true)
suite["ForwardDiff_smoothing"] = @benchmarkable derivative($forwarddiff_func, $p)

Expand Down
2 changes: 1 addition & 1 deletion src/StochasticAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module StochasticAD
export stochastic_triple, derivative_contribution, perturbations # For working with stochastic triples
export derivative_estimate, StochasticModel, stochastic_gradient # Higher level functionality
export new_weight # Particle resampling
export PrunedFIsBackend, PrunedFIsAggressiveBackend, DictFIsBackend
export PrunedFIsBackend, PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend

### Imports

Expand Down
17 changes: 9 additions & 8 deletions src/backends/dict.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DictFIsModule

export DictFIsBackend
export DictFIsBackend, DictFIs

import ..StochasticAD
using Dictionaries
Expand Down Expand Up @@ -99,11 +99,7 @@ function StochasticAD.map_Δs(f, Δs::DictFIs; kwargs...)
DictFIs(dict, Δs.state)
end

function StochasticAD.filter_state(Δs::DictFIs{V}, key) where {V}
haskey(Δs.dict, key) ? Δs.dict[key] : zero(V)
end

StochasticAD.alltrue(Δs::DictFIs{Bool}) = all(Δs.dict)
StochasticAD.alltrue(f, Δs::DictFIs) = all(map(f, collect(Δs.dict)))

### Coupling

Expand All @@ -117,7 +113,8 @@ function StochasticAD.get_rep(::Type{<:DictFIs}, Δs_all)
end

function StochasticAD.couple(FIs::Type{<:DictFIs}, Δs_all;
rep = StochasticAD.get_rep(FIs, Δs_all))
rep = StochasticAD.get_rep(FIs, Δs_all),
out_rep = nothing)
all_keys = Iterators.map(StochasticAD.structural_iterate(Δs_all)) do Δs
keys(Δs.dict)
end
Expand All @@ -138,7 +135,7 @@ function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all;
DictFIs(Δs_combined_dict, rep.state)
end

function StochasticAD.scalarize(Δs::DictFIs)
function StochasticAD.scalarize(Δs::DictFIs; out_rep = nothing)
tupleify(Δ1, Δ2) = StochasticAD.structural_map(tuple, Δ1, Δ2)
Δ_all_allkeys = foldl(tupleify, values(Δs.dict))
Δ_all_rep = first(values(Δs.dict))
Expand All @@ -148,6 +145,10 @@ function StochasticAD.scalarize(Δs::DictFIs)
end
end

# Look up the perturbation stored under `key` in `Δs.dict`; when the dictionary
# has no entry for that key, fall back to `zero(V)`.
function StochasticAD.filter_state(Δs::DictFIs{V}, key) where {V}
    if haskey(Δs.dict, key)
        return Δs.dict[key]
    end
    return zero(V)
end

### Miscellaneous

StochasticAD.similar_type(::Type{<:DictFIs}, V::Type) = DictFIs{V}
Expand Down
16 changes: 9 additions & 7 deletions src/backends/pruned.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module PrunedFIsModule

import ..StochasticAD

export PrunedFIsBackend
export PrunedFIsBackend, PrunedFIs

"""
PrunedFIsBackend <: StochasticAD.AbstractFIsBackend
Expand Down Expand Up @@ -82,9 +82,7 @@ isapproxzero(Δs::PrunedFIs) = isempty(Δs) || isapprox(Δs.Δ, zero(Δs.Δ))
pruned_value(Δs::PrunedFIs{V}) where {V} = isempty(Δs) ? zero(V) : Δs.Δ
pruned_value(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ
pruned_value(Δs::PrunedFIs{<:AbstractArray}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ
function StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V}
Δs.state === state ? pruned_value(Δs) : zero(V)
end

StochasticAD.derivative_contribution(Δs::PrunedFIs) = pruned_value(Δs) * Δs.state.weight
StochasticAD.perturbations(Δs::PrunedFIs) = ((pruned_value(Δs), Δs.state.weight),)

Expand All @@ -94,7 +92,7 @@ function StochasticAD.map_Δs(f, Δs::PrunedFIs; kwargs...)
PrunedFIs(f(Δs.Δ, Δs.state), Δs.tag, Δs.state)
end

StochasticAD.alltrue(Δs::PrunedFIs{Bool}) = Δs.Δ
StochasticAD.alltrue(f, Δs::PrunedFIs) = f(Δs.Δ)

### Coupling

Expand Down Expand Up @@ -135,7 +133,7 @@ end
# for pruning, coupling amounts to getting rid of perturbed values that have been
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
# rep is unused.
function StochasticAD.couple(::Type{<:PrunedFIs}, Δs_all; rep = nothing)
function StochasticAD.couple(::Type{<:PrunedFIs}, Δs_all; rep = nothing, out_rep = nothing)
state = get_pruned_state(Δs_all)
Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
PrunedFIs(Δ_coupled, state.active_tag, state)
Expand All @@ -148,12 +146,16 @@ function StochasticAD.combine(::Type{<:PrunedFIs}, Δs_all; rep = nothing)
PrunedFIs(Δ_combined, state.active_tag, state)
end

function StochasticAD.scalarize(Δs::PrunedFIs)
function StochasticAD.scalarize(Δs::PrunedFIs; out_rep = nothing)
return StochasticAD.structural_map(Δs.Δ) do Δ
return PrunedFIs(Δ, Δs.tag, Δs.state)
end
end

# Return the lazily pruned value of `Δs` only when `Δs.state` is the very same
# object as `state` (identity comparison with `===`); otherwise return `zero(V)`.
function StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V}
    if Δs.state !== state
        return zero(V)
    end
    return pruned_value(Δs)
end

### Miscellaneous

StochasticAD.similar_type(::Type{<:PrunedFIs}, V::Type) = PrunedFIs{V}
Expand Down
12 changes: 7 additions & 5 deletions src/backends/pruned_aggressive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module PrunedFIsAggressiveModule

import ..StochasticAD

export PrunedFIsAggressiveBackend
export PrunedFIsAggressiveBackend, PrunedFIsAggressive

"""
PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend
Expand Down Expand Up @@ -91,7 +91,6 @@ Base.iszero(Δs::PrunedFIsAggressive) = isempty(Δs) || iszero(Δs.Δ)

# we lazily prune, so check if empty first
pruned_value(Δs::PrunedFIsAggressive{V}) where {V} = isempty(Δs) ? zero(V) : Δs.Δ
StochasticAD.filter_state(Δs::PrunedFIsAggressive, _) = pruned_value(Δs)

function StochasticAD.derivative_contribution(Δs::PrunedFIsAggressive)
pruned_value(Δs) * Δs.state.weight
Expand All @@ -105,7 +104,7 @@ function StochasticAD.map_Δs(f, Δs::PrunedFIsAggressive; kwargs...)
PrunedFIsAggressive(f(Δs.Δ, nothing), Δs.tag, Δs.state)
end

StochasticAD.alltrue(Δs::PrunedFIsAggressive{Bool}) = Δs.Δ
StochasticAD.alltrue(f, Δs::PrunedFIsAggressive) = f(Δs.Δ)

### Coupling

Expand All @@ -118,7 +117,8 @@ end
# for pruning, coupling amounts to getting rid of perturbed values that have been
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
function StochasticAD.couple(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
rep = StochasticAD.get_rep(FIs, Δs_all))
rep = StochasticAD.get_rep(FIs, Δs_all),
out_rep = nothing)
state = rep.state
Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
PrunedFIsAggressive(Δ_coupled, state.active_tag, state)
Expand All @@ -132,12 +132,14 @@ function StochasticAD.combine(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
PrunedFIsAggressive(Δ_combined, state.active_tag, state)
end

function StochasticAD.scalarize(Δs::PrunedFIsAggressive)
function StochasticAD.scalarize(Δs::PrunedFIsAggressive; out_rep = nothing)
return StochasticAD.structural_map(Δs.Δ) do Δ
return PrunedFIsAggressive(Δ, Δs.tag, Δs.state)
end
end

StochasticAD.filter_state(Δs::PrunedFIsAggressive, _) = pruned_value(Δs)

### Miscellaneous

StochasticAD.similar_type(::Type{<:PrunedFIsAggressive}, V::Type) = PrunedFIsAggressive{V}
Expand Down
94 changes: 80 additions & 14 deletions src/backends/smoothed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,102 @@ module SmoothedFIsModule

import ..StochasticAD

export SmoothedFIs
export SmoothedFIsBackend, SmoothedFIs

"""
SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend
A backend algorithm that smooths perturbations together.
"""
struct SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend end

"""
SmoothedFIs{V} <: StochasticAD.AbstractFIs{V}
A backend that smooths perturbations together.
The full backend interface is not supported, rather only the functions necessary for defining chain rules.
The implementing backend structure for SmoothedFIsBackend.
"""
struct SmoothedFIs{V, Vfloat} <: StochasticAD.AbstractFIs{V}
δ::Vfloat
# TODO: make type of δ generic
struct SmoothedFIs{V, V_float} <: StochasticAD.AbstractFIs{V}
δ::V_float
function SmoothedFIs{V}(δ) where {V}
# hardcode Float64 representation for now, for simplicity.
δ_f64 = StochasticAD.structural_map(Base.Fix1(convert, Float64), δ)
return new{V, typeof(δ_f64)}(δ_f64)
end
end

SmoothedFIs{V}(δ::Vfloat) where {V, Vfloat} = SmoothedFIs{V, Vfloat}(δ)
### Empty / no perturbation

StochasticAD.similar_empty(::SmoothedFIs, V::Type) = SmoothedFIs{V}(0.0)
Base.empty(::Type{<:SmoothedFIs{V}}) where {V} = SmoothedFIs{V}(0.0)
Base.empty(Δs::SmoothedFIs) = empty(typeof(Δs))

### Create a new perturbation with infinitesimal probability

StochasticAD.similar_empty(::SmoothedFIs, V::Type) = SmoothedFIs{V}(zero(float(V)))
function StochasticAD.similar_new(::SmoothedFIs, Δ::V, w::Real) where {V}
SmoothedFIs{V}(float(Δ) * w)
SmoothedFIs{V}(Δ * w)
end

StochasticAD.new_Δs_strategy(::SmoothedFIs) = StochasticAD.TwoSidedStrategy()

### Scale a perturbation

# Multiply the stored smoothed perturbation `δ` by the real factor `scale`,
# keeping the value-type parameter `V` unchanged.
function StochasticAD.scale(Δs::SmoothedFIs{V}, scale::Real) where {V}
    scaled_δ = Δs.δ * scale
    return SmoothedFIs{V}(scaled_δ)
end

### Create Δs backend for the first stochastic triple of computation

StochasticAD.create_Δs(::SmoothedFIsBackend, V) = SmoothedFIs{V}(0.0)

### Convert type of a backend

# Re-wrap an existing `SmoothedFIs` under the value type `V` by passing its stored
# `δ` through the `SmoothedFIs{V}` constructor.
function (::Type{<:SmoothedFIs{V}})(Δs::SmoothedFIs) where {V}
SmoothedFIs{V}(Δs.δ)
end
(::Type{SmoothedFIs{V}})(Δs::SmoothedFIs) where {V} = SmoothedFIs{V}(Δs.δ)

### Getting information about perturbations

# A `SmoothedFIs` is never reported as empty.
Base.isempty(Δs::SmoothedFIs) = false
# The perturbation counts as zero exactly when its stored `δ` is zero.
Base.iszero(Δs::SmoothedFIs) = iszero(Δs.δ)
# Tuple-valued case: zero only when every component of `δ` is zero.
Base.iszero(Δs::SmoothedFIs{<:Tuple}) = all(iszero.(Δs.δ))
# The derivative contribution is the stored `δ` itself.
StochasticAD.derivative_contribution(Δs::SmoothedFIs) = Δs.δ

### Unary propagation

# Unary propagation for the smoothed backend. `f` itself is not called here:
# the caller-supplied `deriv` is applied to the stored `δ`, and the result is
# wrapped with the type of `out_rep` as its value type.
function StochasticAD.map_Δs(f, Δs::SmoothedFIs; deriv, out_rep)
    V_out = typeof(out_rep)
    δ_out = deriv(Δs.δ)
    return SmoothedFIs{V_out}(δ_out)
end

StochasticAD.alltrue(f, Δs::SmoothedFIs) = true

### Coupling

StochasticAD.get_rep(::Type{<:SmoothedFIs}, Δs_all) = first(Δs_all)

# Couple a structure of `SmoothedFIs` into a single one: pull the `δ` out of
# every element (preserving the structure of `Δs_all`) and wrap the result with
# the type of `out_rep` as its value type. `rep` is accepted but not used.
function StochasticAD.couple(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep)
    δ_all = StochasticAD.structural_map(Δs_all) do Δs
        Δs.δ
    end
    return SmoothedFIs{typeof(out_rep)}(δ_all)
end

function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing)
V_out = StochasticAD.valtype(first(StochasticAD.structural_iterate(Δs_all)))
Δ_combined = sum(Δs -> Δs.δ, StochasticAD.structural_iterate(Δs_all))
# TODO: using eltype below will not work in general, and the proper fix
# could be a caller-provided type. This is not yet needed for this function's
# limited internal use.
eltype(Δs_all)(Δ_combined)
SmoothedFIs{V_out}(Δ_combined)
end

StochasticAD.derivative_contribution(Δs::SmoothedFIs) = Δs.δ
# Split a structured `SmoothedFIs` into a structure of scalar ones: `out_rep` and
# `Δs.δ` are walked together, and each `δ` leaf is wrapped using the type of the
# matching `out_rep` leaf as its value type.
function StochasticAD.scalarize(Δs::SmoothedFIs; out_rep)
    wrap_leaf(out, δ) = SmoothedFIs{typeof(out)}(δ)
    return StochasticAD.structural_map(wrap_leaf, out_rep, Δs.δ)
end

### Miscellaneous

StochasticAD.similar_type(::Type{<:SmoothedFIs}, V::Type) = SmoothedFIs{V}
StochasticAD.valtype(::Type{<:SmoothedFIs{V}}) where {V} = V

function Base.show(io::IO, mime::MIME"text/plain", Δs::SmoothedFIs)
function Base.show(io::IO, Δs::SmoothedFIs)
print(io, "$(Δs.δ)ε")
end

Expand Down
31 changes: 22 additions & 9 deletions src/discrete_randomness.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
## Rules for univariate uniparameter discrete distributions
# Strategy tag: a new perturbation is built from a single evaluation at `δ`.
struct SingleSidedStrategy end
# Strategy tag: a new perturbation is built by combining evaluations at `δ` and `-δ`.
struct TwoSidedStrategy end

# Default strategy for any `Δs`; a backend can add its own method to select another.
new_Δs_strategy(Δs) = SingleSidedStrategy()

"""
δtoΔs(d, val, δ, Δs::AbstractFIs)
Given the parameter `val` of a distribution `d` and an infinitesimal change `δ`,
return the discrete change in the output, with a similar representation to `Δs`.
"""
function δtoΔs(d::Geometric, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
δtoΔs(d, val, δ, Δs) = δtoΔs(d, val, δ, Δs, new_Δs_strategy(Δs))
δtoΔs(d, val, δ, Δs, ::SingleSidedStrategy) = _δtoΔs(d, val, δ, Δs)
# Two-sided rule: compute the perturbation for `+δ` and for `-δ`, weight them by
# 0.5 and -0.5 respectively, and combine the two scaled results.
function δtoΔs(d, val, δ, Δs, ::TwoSidedStrategy)
    Δs_forward = _δtoΔs(d, val, δ, Δs)
    Δs_backward = _δtoΔs(d, val, -δ, Δs)
    half_forward = scale(Δs_forward, 0.5)
    half_backward = scale(Δs_backward, -0.5)
    return combine((half_forward, half_backward))
end

## Rules for univariate uniparameter discrete distributions

function _δtoΔs(d::Geometric, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
p = succprob(d)
if δ > 0
return val > 0 ? similar_new(Δs, -one(V), δ * val / p / (1 - p)) :
Expand All @@ -18,7 +31,7 @@ function δtoΔs(d::Geometric, val::V, δ::Real, Δs::AbstractFIs) where {V <: S
end
end

function δtoΔs(d::Bernoulli, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
function _δtoΔs(d::Bernoulli, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
p = succprob(d)
if δ > 0
return isone(val) ? similar_empty(Δs, V) : similar_new(Δs, one(V), δ / (1 - p))
Expand All @@ -29,7 +42,7 @@ function δtoΔs(d::Bernoulli, val::V, δ::Real, Δs::AbstractFIs) where {V <: S
end
end

function δtoΔs(d::Binomial, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
function _δtoΔs(d::Binomial, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
p = succprob(d)
n = ntrials(d)
if δ > 0
Expand All @@ -42,7 +55,7 @@ function δtoΔs(d::Binomial, val::V, δ::Real, Δs::AbstractFIs) where {V <: Si
end
end

function δtoΔs(d::Poisson, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
function _δtoΔs(d::Poisson, val::V, δ::Real, Δs::AbstractFIs) where {V <: Signed}
p = mean(d) # rate
if δ > 0
return similar_new(Δs, 1, δ)
Expand Down Expand Up @@ -72,7 +85,7 @@ for (dist, i) in [(:Geometric, :1), (:Bernoulli, :1), (:Binomial, :2), (:Poisson
alt_val = quantile(alt_d, rand(RNG) * (high - low) + low)
convert(Signed, alt_val - val)
end
Δs2 = map(map_func, st.Δs)
Δs2 = map(map_func, st.Δs; deriv = δ -> smoothed_delta(d, val, δ), out_rep = val)

StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs1)) # ensure that tags are in order in combine, in case backend wishes to exploit this
end
Expand Down Expand Up @@ -134,7 +147,7 @@ end

### Rule for Categorical variable

function δtoΔs(d::Categorical, val::V, δs, Δs::AbstractFIs) where {V <: Signed}
function _δtoΔs(d::Categorical, val::V, δs, Δs::AbstractFIs) where {V <: Signed}
p = params(d)[1]
left_sum = sum(δs[1:(val - 1)], init = zero(V))
right_sum = -sum(δs[(val + 1):end], init = zero(V))
Expand Down Expand Up @@ -190,13 +203,13 @@ function Base.rand(rng::AbstractRNG,

low = cdf(d, val - 1)
high = cdf(d, val)
Δs_coupled = couple(Δs_all; rep = Δs_rep) # TODO: again, there are possible allocations here
Δs_coupled = couple(Δs_all; rep = Δs_rep, out_rep = p) # TODO: again, there are possible allocations here

function map_func(Δ)
alt_val = quantile(Categorical(p .+ Δ), rand(RNG) * (high - low) + low)
convert(Signed, alt_val - val)
end
Δs2 = map(map_func, Δs_coupled)
Δs2 = map(map_func, Δs_coupled; deriv = δ -> smoothed_delta(d, val, δ), out_rep = val)

StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs_rep))
end
5 changes: 5 additions & 0 deletions src/finite_infinitesimals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,8 @@ function map_Δs end
# `map` over a backend forwards to `map_Δs`, adapting the one-argument `f` to a
# two-argument callback whose second argument is ignored. Keyword arguments are
# passed through unchanged.
function Base.map(f, Δs::AbstractFIs; kwargs...)
    return StochasticAD.map_Δs(Δs; kwargs...) do Δ, _
        f(Δ)
    end
end

function new_Δs_strategy end

# Currently only supported / thought through for SmoothedFIs.
function scale end
Loading

0 comments on commit 1c4395b

Please sign in to comment.