diff --git a/Project.toml b/Project.toml index f0b2bd1..3436d6f 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" +KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" @@ -33,6 +34,7 @@ DocStringExtensions = "0.8, 0.9" FiniteDifferences = "0.12" GLM = "1" IteratorInterfaceExtensions = "0.1.1, 1" +KernelDensity = "0.6" LogExpFunctions = "0.2.0, 0.3" MCMCDiagnosticTools = "0.3.4" OffsetArrays = "1" diff --git a/docs/src/api.md b/docs/src/api.md index c1b20c0..3f9cec2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -56,7 +56,14 @@ Stacking loo_pit ``` -### Utilities +## Density utilities + +```@docs +kde +bandwidth_silverman +``` + +## Utilities ```@docs PosteriorStats.smooth_data diff --git a/src/PosteriorStats.jl b/src/PosteriorStats.jl index f4bab41..4bd1060 100644 --- a/src/PosteriorStats.jl +++ b/src/PosteriorStats.jl @@ -5,6 +5,7 @@ using DataInterpolations: DataInterpolations using Distributions: Distributions using DocStringExtensions: FIELDS, FUNCTIONNAME, TYPEDEF, TYPEDFIELDS, SIGNATURES using IteratorInterfaceExtensions: IteratorInterfaceExtensions +using KernelDensity: KernelDensity using LinearAlgebra: mul!, norm using LogExpFunctions: LogExpFunctions using Markdown: @doc_str @@ -20,6 +21,9 @@ using StatsBase: StatsBase using Tables: Tables using TableTraits: TableTraits +# Density +export bandwidth_silverman, kde + # PSIS export PSIS, PSISResult, psis, psis! 
@@ -43,6 +47,7 @@ const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)
 
 include("utils.jl")
 include("hdi.jl")
+include("kde.jl")
 include("elpdresult.jl")
 include("loo.jl")
 include("waic.jl")
diff --git a/src/kde.jl b/src/kde.jl
new file mode 100644
index 0000000..325d2bf
--- /dev/null
+++ b/src/kde.jl
@@ -0,0 +1,138 @@
+# Approximate interquartile range of the standard normal distribution, used by
+# `bandwidth_silverman` to put the data's IQR on the scale of a standard deviation.
+const STD_NORM_IQR = rationalize(1.34)
+
+"""
+    bandwidth_silverman(x; kwargs...) -> Real
+
+Compute the kernel bandwidth for data `x` using Silverman's rule of thumb.
+
+The bandwidth is `alpha * min(std, IQR / 1.34) * n^(-1/5)`, where `IQR` is the
+interquartile range of `x` and `n` is its length.
+
+# Arguments
+- `x`: data array
+
+# Keyword arguments
+- `alpha::Real`: scale factor applied to the width. Defaults to `9//10`.
+- `std::Real`: standard deviation of `x`. Defaults to `Statistics.std(x)`.
+"""
+function bandwidth_silverman(
+    x::AbstractVector{<:Real}; alpha::Real=9//10, std::Real=Statistics.std(x)
+)
+    n = length(x)
+    iqr = StatsBase.iqr(x)
+    quantile_width = iqr / STD_NORM_IQR
+    width = min(std, quantile_width)
+    T = typeof(one(width))
+    return alpha * width * T(n)^(-1//5)
+end
+
+# Upper `prob_tail / 2` quantile of `kernel`, expressed in units of the kernel's standard
+# deviation and rounded up to an integer.
+function _padding_factor(
+    kernel::Distributions.ContinuousUnivariateDistribution, prob_tail::Real
+)
+    return Int(cld(Distributions.cquantile(kernel, prob_tail / 2), Statistics.std(kernel)))
+end
+
+# Build a kernel of type `T` centered at its median and rescaled to have standard
+# deviation `bw`. Throws an `ArgumentError` if the kernel is not symmetric.
+function _kernel_with_bandwidth(
+    T::Type{<:Distributions.ContinuousUnivariateDistribution}, bw
+)
+    d = T()
+    dcentered = (d - Statistics.median(d)) * bw / Statistics.std(d)
+    skew = StatsBase.skewness(dcentered)
+    # `skew ≈ 0` with default tolerances is an exact equality test, so use an explicit
+    # absolute tolerance to accept symmetric kernels with round-off in the skewness
+    atol = sqrt(eps(typeof(float(skew))))
+    isapprox(skew, 0; atol) || throw(ArgumentError("Kernel must be symmetric."))
+    return dcentered
+end
+
+"""
+    kde(x; kwargs...) -> KernelDensity.UnivariateKDE
+
+Compute the univariate kernel density estimate of data `x`.
+
+# Arguments
+- `x`: data array
+
+# Keyword arguments
+- `bandwidth::Real`: bandwidth of the kernel. Defaults to [`bandwidth_silverman(x)`](@ref).
+- `kernel::Type{<:Distributions.ContinuousUnivariateDistribution}`: type of kernel to build.
+  Defaults to `Normal`.
+- `bound_correction::Bool`: whether to perform boundary correction. Defaults to `true`.
+  If `false`, the resulting truncated KDE is not normalized to 1.
+- `npoints::Int`: number of points at which the resulting KDE is evaluated. Defaults to
+  `512`. At least `100` points are always used.
+- `pad_factor::Union{Real,Nothing}`: number of bandwidths by which the grid is padded on
+  each side while computing the density, to limit wraparound at the boundary. Must be
+  non-negative. Defaults to `nothing`, in which case it is computed from the kernel.
+
+Throws an `ArgumentError` if `x` does not span a non-zero range and a `DomainError` if
+`pad_factor` is negative.
+"""
+function kde(
+    x::AbstractVector;
+    bandwidth::Real=bandwidth_silverman(x),
+    kernel=Distributions.Normal,
+    bound_correction::Bool=true,
+    npoints::Int=512,
+    pad_factor::Union{Real,Nothing}=nothing,
+)
+    grid_size = max(npoints, 100)
+    grid_min, grid_max = extrema(x)
+    bin_width = (grid_max - grid_min) / grid_size
+    # a zero (or NaN) bin width would otherwise surface below as an uninformative
+    # `InexactError` when converting the number of padding bins to `Int`
+    bin_width > zero(bin_width) ||
+        throw(ArgumentError("Data must span a non-zero range to compute a KDE."))
+
+    if pad_factor === nothing
+        # work out how much padding to add to guarantee that extra density due to wraparound
+        # is negligible
+        prob_tail = 1e-3
+        _kernel = _kernel_with_bandwidth(kernel, bandwidth)
+        _pad_factor = _padding_factor(_kernel, prob_tail)
+    elseif pad_factor < 0
+        throw(DomainError(pad_factor, "Padding factor must be non-negative."))
+    else
+        _pad_factor = pad_factor
+    end
+    # always pad by at least 1 bin on each side to ensure that the boundary passed to kde
+    # contains all data points (otherwise, they will be ignored)
+    grid_pad_size = 2 * max(1, Int(cld(_pad_factor * bandwidth, bin_width)))
+    npad_left = npad_right = grid_pad_size ÷ 2
+
+    # pad to avoid wraparound at the boundary
+    grid_min -= bin_width * npad_left
+    grid_max += bin_width * npad_right
+    grid_size += grid_pad_size
+
+    # compute density
+    boundary = (grid_min + bin_width / 2, grid_max - bin_width / 2)
+    k = KernelDensity.kde(x; npoints=grid_size, boundary, bandwidth, kernel)
+    midpoints = k.x
+    density = k.density
+
+    if bound_correction
+        nbin_reflect = min(npad_left, npad_right)
+        il = firstindex(density) + npad_left
+        ir = lastindex(density) - npad_right
+        # reflect density at the boundary (x_min, x_max)
+        density[range(il; length=nbin_reflect)] .+= @view density[range(
+            il - 1; step=-1, length=nbin_reflect
+        )]
+        density[range(ir; step=-1, length=nbin_reflect)] .+= @view density[range(
+            ir + 1; length=nbin_reflect
+        )]
+    end
+
+    # remove padding
+    midpoints = midpoints[(begin + npad_left):(end - npad_right)]
+    density = density[(begin + npad_left):(end - npad_right)]
+
+    return KernelDensity.UnivariateKDE(midpoints, density)
+end