Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add eti to API #39

Merged
merged 7 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ default_summary_stats
summarize
```

## General statistics
## Credible intervals

```@docs
hdi
hdi!
eti
eti!
```

## LOO and WAIC
Expand Down
5 changes: 4 additions & 1 deletion src/PosteriorStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ export ModelComparisonResult, compare
export SummaryStats, summarize
export default_diagnostics, default_stats, default_summary_stats

# Credible intervals
export eti, eti!, hdi, hdi!

# Others
export loo_pit, r2_score

const DEFAULT_INTERVAL_PROB = 0.94
const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)
Expand Down
89 changes: 87 additions & 2 deletions src/eti.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,97 @@
"""
    eti(samples::AbstractVecOrMat{<:Real}; [prob, kwargs...]) -> IntervalSets.ClosedInterval
    eti(samples::AbstractArray{<:Real}; [prob, kwargs...]) -> Array{<:IntervalSets.ClosedInterval}

Estimate the equal-tailed interval (ETI) of `samples` for the probability `prob`.

The ETI of a given probability is the credible interval with the property that the
probability of being below the interval is equal to the probability of being above it.
That is, it is defined by the `(1-prob)/2` and `1 - (1-prob)/2` quantiles of the samples.

See also: [`eti!`](@ref), [`hdi`](@ref), [`hdi!`](@ref).

# Arguments
- `samples`: an array of shape `(draws[, chains[, params...]])`. If multiple parameters are
    present, a marginal ETI is computed for each parameter.

# Keywords
- `prob`: the probability mass to be contained in the ETI. Default is
    `$(DEFAULT_INTERVAL_PROB)`.
- `sorted`: if `true`, `samples` is used directly instead of being copied, and
    `sorted=true` is forwarded together with the remaining keywords. Default is `false`.
- `kwargs`: remaining keywords are passed to `Statistics.quantile`.

# Returns
- `intervals`: If `samples` is a vector or matrix, then a single
    `IntervalSets.ClosedInterval` is returned. Otherwise, an array with the shape
    `(params...,)` is returned, containing a marginal ETI for each parameter.

!!! note
    Any default value of `prob` is arbitrary. The default value of
    `prob=$(DEFAULT_INTERVAL_PROB)` instead of a more common default like `prob=0.95` is
    chosen to remind the user of this arbitrariness.

# Examples

Here we calculate the 83% ETI for a normal random variable:

```jldoctest eti; setup = :(using Random; Random.seed!(78))
julia> x = randn(2_000);

julia> eti(x; prob=0.83)
-1.3740585250299766 .. 1.2860771129421198
```

We can also calculate the ETI for a 3-dimensional array of samples:

```jldoctest eti; setup = :(using Random; Random.seed!(67))
julia> x = randn(1_000, 1, 1) .+ reshape(0:5:10, 1, 1, :);

julia> eti(x)
3-element Vector{IntervalSets.ClosedInterval{Float64}}:
 -1.951006825019686 .. 1.9011666217153793
 3.048993174980314 .. 6.9011666217153795
 8.048993174980314 .. 11.90116662171538
```
"""
function eti(
    x::AbstractArray{<:Real};
    prob::Real=DEFAULT_INTERVAL_PROB,
    sorted::Bool=false,
    kwargs...,
)
    # `eti!` works in-place, so operate on a mutable copy unless the caller has declared
    # the input already sorted, in which case it is passed through without copying.
    return eti!(sorted ? x : _copymutable(x); prob, sorted, kwargs...)
end

"""
eti!(samples::AbstractArray{<:Real}; [prob, kwargs...])

A version of [`eti`](@ref) that partially sorts `samples` in-place while computing the ETI.

See also: [`eti`](@ref), [`hdi`](@ref), [`hdi!`](@ref).
"""
function eti!(x::AbstractArray{<:Real}; prob::Real=DEFAULT_INTERVAL_PROB, kwargs...)
    # Validate the inputs at the API boundary; the checks run in this order, so a
    # 0-dimensional array is reported before an invalid `prob` or an empty array.
    if ndims(x) < 1
        throw(ArgumentError("ETI cannot be computed for a 0-dimensional array."))
    end
    # Written as a negated chained comparison so that `prob = NaN` is rejected as well.
    if !(0 < prob < 1)
        throw(DomainError(prob, "ETI `prob` must be in the range `(0, 1)`."))
    end
    if isempty(x)
        throw(ArgumentError("ETI cannot be computed for an empty array."))
    end
    # Remaining keywords are forwarded untouched to the internal implementation.
    return _eti!(x, prob; kwargs...)
end

# Compute a single ETI from all values of the vector or matrix `x` pooled together.
# `kwargs` are forwarded to `Statistics.quantile!`, which may reorder `x` in-place.
# Returns an `IntervalSets.ClosedInterval`; both endpoints are NaN if `x` contains a NaN.
function _eti!(x::AbstractVecOrMat{<:Real}, prob::Real; kwargs...)
    if any(isnan, x)
        # Short-circuit with NaN endpoints instead of calling `quantile!` on NaN data.
        T = float(promote_type(eltype(x), typeof(prob)))
        lower = upper = T(NaN)
    else
        # Equal tails: put `(1 - prob) / 2` of the mass on each side of the interval.
        alpha = (1 - prob) / 2
        lower, upper = Statistics.quantile!(vec(x), (alpha, 1 - alpha); kwargs...)
    end
    return IntervalSets.ClosedInterval(lower, upper)
end
# Compute one ETI per parameter slice of `x` and collect them in an array whose axes are
# given by `_param_axes(x)`. Keywords are forwarded to the per-slice method.
function _eti!(x::AbstractArray, prob::Real; kwargs...)
    # Endpoint type: floating-point promotion of the sample eltype and the type of `prob`.
    Tout = float(promote_type(eltype(x), typeof(prob)))
    intervals = similar(x, IntervalSets.ClosedInterval{Tout}, _param_axes(x))
    # Walk the output indices and the parameter slices in lockstep.
    for (idx, param_samples) in zip(eachindex(intervals), _eachparam(x))
        intervals[idx] = _eti!(param_samples, prob; kwargs...)
    end
    return intervals
end
6 changes: 5 additions & 1 deletion src/hdi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The HDI is the minimum width Bayesian credible interval (BCI). That is, it is th
possible interval containing `(100*prob)`% of the probability mass.[^Hyndman1996]
This implementation uses the algorithm of [^ChenShao1999].

See also: [`hdi!`](@ref), [`eti`](@ref), [`eti!`](@ref).

# Arguments
- `samples`: an array of shape `(draws[, chains[, params...]])`. If multiple parameters are
present
Expand Down Expand Up @@ -67,7 +69,9 @@ end
"""
hdi!(samples::AbstractArray{<:Real}; [prob, sorted])

A version of [`hdi`](@ref) that sorts `samples` in-place while computing the HDI.
A version of [`hdi`](@ref) that partially sorts `samples` in-place while computing the HDI.

See also: [`hdi`](@ref), [`eti`](@ref), [`eti!`](@ref).
"""
function hdi!(
x::AbstractArray{<:Real}; prob::Real=DEFAULT_INTERVAL_PROB, sorted::Bool=false
Expand Down
17 changes: 8 additions & 9 deletions src/summarize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ Compute the summary stats focusing on `Statistics.median`:
```jldoctest summarize
julia> summarize(x, default_summary_stats(median)...; var_names=[:a, :b, :c])
SummaryStats
median mad eti_94% mcse_median ess_tail ess_median rhat
a 0.004 0.978 -0.0738 .. 0.0731 0.020 3567 3336 1.00
b 10.02 0.995 9.93 .. 10.1 0.023 3841 3787 1.00
c 19.99 0.979 19.9 .. 20.0 0.020 3892 3829 1.00
median mad eti_94% mcse_median ess_tail ess_median rhat
a 0.004 0.978 -1.83 .. 1.89 0.020 3567 3336 1.00
b 10.02 0.995 8.17 .. 11.9 0.023 3841 3787 1.00
c 19.99 0.979 18.1 .. 21.9 0.020 3892 3829 1.00
```

Compute multiple quantiles simultaneously:
Expand Down Expand Up @@ -313,15 +313,14 @@ end
Default statistics to be computed with [`summarize`](@ref).

The value of `focus` determines the statistics to be returned:
- `Statistics.mean`: `mean`, `std`, `hdi_3%`, `hdi_97%`
- `Statistics.median`: `median`, `mad`, `eti_3%`, `eti_97%`
- `Statistics.mean`: `mean`, `std`, `hdi_94%`
- `Statistics.median`: `median`, `mad`, `eti_94%`

If `prob_interval` is set to a different value than the default, then different HDI and ETI
statistics are computed accordingly. [`hdi`](@ref) refers to the highest-density interval,
while `eti` refers to the equal-tailed interval (i.e. the credible interval computed from
symmetric quantiles).
while [`eti`](@ref) refers to the equal-tailed interval.

See also: [`hdi`](@ref)
See also: [`hdi`](@ref), [`eti`](@ref)
"""
function default_stats end
# With no `focus` given, fall back to the `Statistics.mean`-focused statistics.
function default_stats(; kwargs...)
    return default_stats(Statistics.mean; kwargs...)
end
Expand Down
68 changes: 56 additions & 12 deletions test/eti.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,87 @@
using IntervalSets
using OffsetArrays
using PosteriorStats
using Statistics
using Test

# Tests for the exported `eti`/`eti!`: single-interval results for vectors, edge cases and
# input validation, per-parameter results for higher-dimensional arrays, and preservation
# of axes for `OffsetArray` inputs.
@testset "eti/eti!" begin
    @testset "AbstractVecOrMat" begin
        @testset for sz in (100, 1_000, (1_000, 2)),
            prob in (0.7, 0.76, 0.8, 0.88),
            T in (Float32, Float64, Int64)

            n = prod(sz)
            S = Base.promote_eltype(one(T), prob)
            x = T <: Integer ? rand(T(1):T(30), n) : randn(T, n)
            r = @inferred eti(x; prob)
            @test r isa ClosedInterval{S}
            # Coverage and tail-count checks are only run for floating-point samples
            # (presumably because ties among integer samples make them inexact).
            if !(T <: Integer)
                l, u = IntervalSets.endpoints(r)
                frac_in_interval = mean(∈(r), x)
                @test frac_in_interval ≈ prob
                @test count(<(l), x) == count(>(u), x)
            end

            # The in-place variant must agree with the copying one.
            @test eti!(copy(x); prob) == r
        end
    end

    @testset "edge cases and errors" begin
        @testset "NaNs returned if contains NaNs" begin
            x = randn(1000)
            x[3] = NaN
            @test isequal(eti(x), NaN .. NaN)
        end

        @testset "errors for empty array" begin
            x = Float64[]
            @test_throws ArgumentError eti(x)
        end

        @testset "errors for 0-dimensional array" begin
            x = fill(1.0)
            @test_throws ArgumentError eti(x)
        end

        @testset "test errors when prob is not in (0, 1)" begin
            x = randn(1_000)
            @testset for prob in (0, 1, -0.1, 1.1, NaN)
                @test_throws DomainError eti(x; prob)
            end
        end
    end

    @testset "AbstractArray consistent with AbstractVector" begin
        @testset for sz in ((100, 2), (100, 2, 3), (100, 2, 3, 4)),
            prob in (0.72, 0.81),
            T in (Float32, Float64, Int64)

            x = T <: Integer ? rand(T(1):T(30), sz) : randn(T, sz)
            r = @inferred eti(x; prob)
            if ndims(x) == 2
                # A matrix is treated as a single parameter.
                @test r isa ClosedInterval
                @test r == eti(vec(x); prob)
            else
                # Trailing dimensions index parameters; compare with slice-wise results.
                @test r isa Array{<:ClosedInterval,ndims(x) - 2}
                r_slices = dropdims(
                    mapslices(x -> eti(x; prob), x; dims=(1, 2)); dims=(1, 2)
                )
                @test r == r_slices
            end

            @test eti!(copy(x); prob) == r
        end
    end

    @testset "OffsetArray" begin
        @testset for n in (100, 1_000), prob in (0.732, 0.864), T in (Float32, Float64)
            x = randn(T, (n, 2, 3, 4))
            xoff = OffsetArray(x, (-1, 2, -3, 4))
            r = eti(x; prob)
            roff = @inferred eti(xoff; prob)
            # The result keeps the offset axes of the parameter dimensions.
            @test roff isa OffsetMatrix{<:ClosedInterval}
            @test axes(roff) == (axes(xoff, 3), axes(xoff, 4))
            @test collect(roff) == r
        end
    end
end
Loading