Skip to content

Commit

Permalink
Make unimodal HDI more efficient (#38)
Browse files Browse the repository at this point in the history
* Add copymutable utility

* Use partial sorting to speed up HDI

* Add sorted keyword to HDI

* Use PartialQuickSort to ensure the tails are internally sorted
  • Loading branch information
sethaxen authored Sep 8, 2024
1 parent c4e7c39 commit a666cd5
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 24 deletions.
63 changes: 43 additions & 20 deletions src/hdi.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
hdi(samples::AbstractVecOrMat{<:Real}; [prob]) -> IntervalSets.ClosedInterval
hdi(samples::AbstractArray{<:Real}; [prob]) -> Array{<:IntervalSets.ClosedInterval}
hdi(samples::AbstractVecOrMat{<:Real}; [prob, sorted]) -> IntervalSets.ClosedInterval
hdi(samples::AbstractArray{<:Real}; [prob, sorted]) -> Array{<:IntervalSets.ClosedInterval}
Estimate the unimodal highest density interval (HDI) of `samples` for the probability `prob`.
Expand All @@ -15,6 +15,7 @@ This implementation uses the algorithm of [^ChenShao1999].
# Keywords
- `prob`: the probability mass to be contained in the HDI. Default is
`$(DEFAULT_INTERVAL_PROB)`.
- `sorted=false`: if `true`, the input samples are assumed to be sorted.
# Returns
- `intervals`: If `samples` is a vector or matrix, then a single
Expand Down Expand Up @@ -59,46 +60,68 @@ julia> hdi(x)
8.032604346765654 .. 11.900283185492153
```
"""
function hdi(x::AbstractArray{<:Real}; kwargs...)
xcopy = similar(x)
copyto!(xcopy, x)
return hdi!(xcopy; kwargs...)
function hdi(x::AbstractArray{<:Real}; sorted::Bool=false, kwargs...)
return hdi!(sorted ? x : _copymutable(x); sorted, kwargs...)
end

"""
hdi!(samples::AbstractArray{<:Real}; [prob])
hdi!(samples::AbstractArray{<:Real}; [prob, sorted])
A version of [`hdi`](@ref) that sorts `samples` in-place while computing the HDI.
"""
function hdi!(x::AbstractArray{<:Real}; prob::Real=DEFAULT_INTERVAL_PROB)
function hdi!(
x::AbstractArray{<:Real}; prob::Real=DEFAULT_INTERVAL_PROB, sorted::Bool=false
)
0 < prob < 1 || throw(DomainError(prob, "HDI `prob` must be in the range `(0, 1)`."))
return _hdi!(x, prob)
ndims(x) > 0 ||
throw(ArgumentError("HDI cannot be computed for a 0-dimensional array."))
isempty(x) && throw(ArgumentError("HDI cannot be computed for an empty array."))
return _hdi!(x, prob, sorted)
end

function _hdi!(x::AbstractVector{<:Real}, prob::Real)
isempty(x) && throw(ArgumentError("HDI cannot be computed for an empty array."))
function _hdi!(x::AbstractVector{<:Real}, prob::Real, sorted::Bool)
n = length(x)
interval_length = floor(Int, prob * n) + 1
if any(isnan, x) || interval_length == n
if any(isnan, x)
lower = upper = eltype(x)(NaN)
elseif interval_length == n && !sorted
lower, upper = extrema(x)
else
npoints_to_check = n - interval_length + 1
sort!(x)
lower_range = @views x[begin:(begin - 1 + npoints_to_check)]
upper_range = @views x[(begin - 1 + interval_length):end]
sorted || _hdi_sort!(x, interval_length, npoints_to_check)
lower_range = @view x[begin:(begin - 1 + npoints_to_check)]
upper_range = @view x[(begin - 1 + interval_length):end]
lower, upper = argmax(Base.splat(-), zip(lower_range, upper_range))
end
return IntervalSets.ClosedInterval(lower, upper)
end
_hdi!(x::AbstractMatrix{<:Real}, prob::Real) = _hdi!(vec(x), prob)
function _hdi!(x::AbstractArray{<:Real}, prob::Real)
ndims(x) > 0 ||
throw(ArgumentError("HDI cannot be computed for a 0-dimensional array."))
_hdi!(x::AbstractMatrix{<:Real}, prob::Real, sorted::Bool) = _hdi!(vec(x), prob, sorted)
function _hdi!(x::AbstractArray{<:Real}, prob::Real, sorted::Bool)
axes_out = _param_axes(x)
T = eltype(x)
interval = similar(x, IntervalSets.ClosedInterval{T}, axes_out)
for (i, x_slice) in zip(eachindex(interval), _eachparam(x))
interval[i] = _hdi!(x_slice, prob)
interval[i] = _hdi!(x_slice, prob, sorted)
end
return interval
end

# Sort `x` in-place only as far as the HDI search needs.
#
# The HDI search reads the first `npoints_to_check` entries of `x` (candidate lower
# bounds) and the entries from position `interval_length` to the end (candidate upper
# bounds). When these two tails are separated by a gap, each tail is partially sorted
# on its own with `PartialQuickSort`; otherwise the whole array is sorted.
#
# Returns `x`.
function _hdi_sort!(x, interval_length, npoints_to_check)
    # the tails touch or overlap, so fall back to a full sort
    if npoints_to_check >= interval_length - 1
        sort!(x)
        return x
    end

    lo = firstindex(x)
    hi = lastindex(x)

    # place the smallest `npoints_to_check` values, in order, at the front of `x`
    lower_tail = lo:(lo - 1 + npoints_to_check)
    sort!(x; alg=Base.Sort.PartialQuickSort(lower_tail))

    # the lower tail is final; work only on what follows it so it is left untouched
    rest = @view x[(lo + npoints_to_check):hi]
    # positions of the upper tail expressed relative to `rest`
    upper_tail = (interval_length - npoints_to_check):(hi - lo + 1 - npoints_to_check)
    sort!(rest; alg=Base.Sort.PartialQuickSort(upper_tail))

    return x
end
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ function _assimilar(x::NamedTuple, y)
return z
end

# Stand-in for `Base.copymutable`, which is not public: allocate a new array with
# `similar(x)` and copy the contents of `x` into it, leaving `x` untouched.
function _copymutable(x::AbstractArray)
    return copyto!(similar(x), x)
end

function _skipmissing(x::AbstractArray)
Missing <: eltype(x) && return skipmissing(x)
return x
Expand Down
9 changes: 5 additions & 4 deletions test/hdi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ using Test
@testset "AbstractVector" begin
@testset for n in (10, 100, 1_000),
prob in (1 / n, 0.5, 0.73, 0.96, (n - 1 + 0.1) / n),
T in (Float32, Float64, Int64)
T in (Float32, Float64, Int64),
sorted in (true, false)

x = T <: Integer ? rand(T(1):T(30), n) : randn(T, n)
r = @inferred hdi(x; prob)
xsort = sort(x)
r = @inferred hdi(sorted ? xsort : x; prob, sorted)
@test r isa ClosedInterval{T}
l, u = IntervalSets.endpoints(r)
interval_length = floor(Int, prob * n) + 1
Expand All @@ -20,13 +22,12 @@ using Test
else
@test sum(x -> l ≤ x ≤ u, x) == interval_length
end
xsort = sort(x)
lind = 1:(n - interval_length + 1)
uind = interval_length:n
@assert all(collect(uind) .- collect(lind) .+ 1 .== interval_length)
@test minimum(xsort[uind] - xsort[lind]) ≈ u - l

@test hdi!(copy(x); prob) == r
@test hdi!(sorted ? xsort : x; prob, sorted) == r
end
end

Expand Down

0 comments on commit a666cd5

Please sign in to comment.