Skip to content

Ornstein-Uhlenbeck diffusion in k-dimensional space #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/continuous.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#OU process
# OU process
struct OrnsteinUhlenbeckDiffusion{T <: Real} <: GaussianStateProcess
mean::T
volatility::T
Expand All @@ -13,19 +13,25 @@ var(model::OrnsteinUhlenbeckDiffusion) = (model.volatility^2) / (2 * model.rever

eq_dist(model::OrnsteinUhlenbeckDiffusion) = Normal(model.mean,sqrt(var(model)))

# These are for nested broadcasting
elmwiseadd(x, y) = x .+ y
elmwisesub(x, y) = x .- y
elmwisemul(x, y) = x .* y
elmwisediv(x, y) = x ./ y

function forward(process::OrnsteinUhlenbeckDiffusion, x_s::AbstractArray, s::Real, t::Real)
μ, σ, θ = process.mean, process.volatility, process.reversion
mean = @. exp(-(t - s) * θ) * (x_s - μ) + μ
var = similar(mean)
var .= ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ
# exp(-(t - s) * θ) * (x_s - μ) + μ
mean = elmwiseadd.(elmwisemul.(exp(-(t - s) * θ), elmwisesub.(x_s, μ)), μ)
var = ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ
return GaussianVariables(mean, var)
end

function backward(process::OrnsteinUhlenbeckDiffusion, x_t::AbstractArray, s::Real, t::Real)
μ, σ, θ = process.mean, process.volatility, process.reversion
mean = @. exp((t - s) * θ) * (x_t - μ) + μ
var = similar(mean)
var .= -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ
# @. exp((t - s) * θ) * (x_t - μ) + μ
mean = elmwiseadd.(elmwisemul.(exp((t - s) * θ), elmwisesub.(x_t, μ)), μ)
var = -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ
return (μ = mean, σ² = var)
end

Expand Down
10 changes: 10 additions & 0 deletions src/loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ function standardloss(
return scaledloss(loss(x̂, parent(x)), maskedindices(x), (t -> scaler(p, t)).(t))
end

function standardloss(
p::OrnsteinUhlenbeckDiffusion,
t::Union{Real,AbstractVector{<:Real}},
x̂::AbstractArray{<: SVector}, x::AbstractArray{<: SVector};
scaler=defaultscaler)
loss(x̂, x) = norm.(x̂ .- x).^2
# ugly syntax but scaler.(p, t) is not differentiable with Zygote.jl for some reason
return scaledloss(loss(x̂, parent(x)), maskedindices(x), (t -> scaler(p, t)).(t))
end

defaultscaler(p::RotationDiffusion, t::Real) = sqrt(1 - exp(-t * p.rate * 5))

function standardloss(
Expand Down
13 changes: 7 additions & 6 deletions src/randomvariable.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
# Random Variables
# ----------------

struct GaussianVariables{T, A <: AbstractArray{T}}
struct GaussianVariables{A, B}
# μ and σ² must have the same size
μ::A # mean
σ²::A # variance
μ::A # mean (array)
σ²::B # variance (scalar)
end

Base.size(X::GaussianVariables) = size(X.μ)

sample(rng::AbstractRNG, X::GaussianVariables{T}) where T = randn(rng, T, size(X)) .* .√X.σ² .+ X.μ
sample(rng::AbstractRNG, X::GaussianVariables) =
elmwisemul.(randn(rng, eltype(X.μ), size(X)), √X.σ²) .+ X.μ

function combine(X::GaussianVariables, lik)
σ² = @. inv(inv(X.σ²) + inv(lik.σ²))
μ = @. σ² * (X.μ / X.σ² + lik.μ / lik.σ²)
σ² = inv(inv(X.σ²) + inv(lik.σ²))
μ = elmwisemul.(σ², elmwisediv.(X.μ, X.σ²) .+ elmwisediv.(lik.μ, lik.σ²))
return GaussianVariables(μ, σ²)
end

Expand Down
36 changes: 30 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ end
x = one(QuatRotation{Float32})
t = 0.29999998f0
@test sampleforward(diffusion, t, [x]) isa Vector

# three-dimensional diffusion
diffusion = OrnsteinUhlenbeckDiffusion(0.0)
x_0 = fill(zero(SVector{3, Float64}), 2)
x_t = sampleforward(diffusion, 1.0, x_0)
@test x_t isa typeof(x_0)
@test size(x_t) == size(x_0)
end

@testset "Discrete Diffusions" begin
Expand Down Expand Up @@ -175,6 +182,12 @@ end
x = samplebackward((x, t) -> x + randn(size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t)
@test size(x) == size(x_t)
@test x isa Matrix

process = OrnsteinUhlenbeckDiffusion(0.0)
x_t = randn(SVector{3, Float64}, 4, 10)
x = samplebackward((x, t) -> x + randn(eltype(x), size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t)
@test size(x) == size(x_t)
@test x isa Matrix
end

@testset "Masked Diffusion" begin
Expand Down Expand Up @@ -244,23 +257,34 @@ end
end

@testset "Loss" begin
p = OrnsteinUhlenbeckDiffusion(0.0, 1.0, 0.5)
x_0 = randn(5, 10)
p = OrnsteinUhlenbeckDiffusion(0.0)
x_0 = zeros(5, 10)
t = rand(10)
@test standardloss(p, t, x_0, x_0) == 0
x = rand(5, 10)
@test standardloss(p, t, x, x_0) > 0

# unmasked elements don't contribute to the loss
x = copy(x_0)
m = x_0 .< 0
m = rand(size(x)...) .< 0.5
x[.!m] .= 1
x_0 = mask(x_0, m)
x[.!m] .= 0
@test standardloss(p, t, x, x_0) == 0
@test standardloss(p, t, x, parent(x_0)) > 0

# but masked elements do
x[m] .= 0
p = OrnsteinUhlenbeckDiffusion(0.0)
x_0 = fill(zero(SVector{3, Float64}), 10)
t = rand(10)
@test standardloss(p, t, x_0, x_0) == 0
x = [rand(SVector{3, Float64}) for _ in eachindex(x_0)]
@test standardloss(p, t, x, x_0) > 0

x = copy(x_0)
m = rand(size(x)...) .< 0.5
x[.!m] .= (ones(SVector{3, Float64}),)
x_0 = mask(x_0, m)
@test standardloss(p, t, x, x_0) == 0
@test standardloss(p, t, x, parent(x_0)) > 0
end

@testset "Autodiff" begin
Expand Down