diff --git a/Project.toml b/Project.toml index 8cffb1b..a5bacbf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PosteriorStats" uuid = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5" authors = ["Seth Axen and contributors"] -version = "0.1.4" +version = "0.2.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -14,6 +14,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PSIS = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -21,6 +22,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -39,6 +41,7 @@ MCMCDiagnosticTools = "0.3.4" Markdown = "1.6" OffsetArrays = "1" Optim = "1" +OrderedCollections = "1" PSIS = "0.9.1" PrettyTables = "2.1, 2.2" Printf = "1.6" @@ -47,6 +50,7 @@ Random = "1.6" Setfield = "1" Statistics = "1.6" StatsBase = "0.32, 0.33, 0.34" +TableOperations = "1" TableTraits = "0.4, 1" Tables = "1" julia = "1.6" diff --git a/src/PosteriorStats.jl b/src/PosteriorStats.jl index f4bab41..0f49a15 100644 --- a/src/PosteriorStats.jl +++ b/src/PosteriorStats.jl @@ -10,6 +10,7 @@ using LogExpFunctions: LogExpFunctions using Markdown: @doc_str using MCMCDiagnosticTools: MCMCDiagnosticTools using Optim: Optim +using OrderedCollections: OrderedCollections using PrettyTables: PrettyTables using Printf: Printf using PSIS: PSIS, PSISResult, psis, psis! 
diff --git a/src/summarize.jl b/src/summarize.jl index a4c4c70..dd5a86e 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -1,42 +1,102 @@ """ -$(SIGNATURES) +$(TYPEDEF) A container for a column table of values computed by [`summarize`](@ref). -This object implements the Tables and TableTraits interfaces and has a custom `show` method. +This object implements the Tables and TableTraits column table interfaces. It has a custom +`show` method. -$(FIELDS) +`SummaryStats` behaves like an `OrderedDict` of columns, where the columns can be accessed +using either `Symbol`s or a 1-based integer index. + +$(TYPEDFIELDS) + + SummaryStats([name::String,] data[, parameter_names]) + SummaryStats(data[, parameter_names]; name::String="SummaryStats") + +Construct a `SummaryStats` from tabular `data` with optional stats `name` and `parameter_names`. + +`data` must not contain a column `:parameter`, as this is reserved for the parameter names, +which are always in the first column. """ -struct SummaryStats{D<:NamedTuple} +struct SummaryStats{D,V<:AbstractVector} "The name of the collection of summary statistics, used as the table title in display." name::String - """The summary statistics for each parameter, with an optional first column `parameter` - containing the parameter names.""" + """The summary statistics for each parameter. 
It must implement the Tables interface.""" data::D + "Names of the parameters" + parameter_names::V + function SummaryStats(name::String, data, parameter_names::V) where {V} + coltable = Tables.columns(data) # names and row count are read from this column form + :parameter ∈ Tables.columnnames(coltable) && + throw(ArgumentError("Column `:parameter` is reserved for parameter names.")) + length(parameter_names) == Tables.rowcount(coltable) || throw( + DimensionMismatch( + "length $(length(parameter_names)) of `parameter_names` does not match number of rows $(Tables.rowcount(coltable)) in `data`.", + ), + ) + return new{typeof(coltable),V}(name, coltable, parameter_names) + end +end +function SummaryStats( + data, + parameter_names::AbstractVector=Base.OneTo(Tables.rowcount(Tables.columns(data))); + name::String="SummaryStats", +) + return SummaryStats(name, data, parameter_names) end -function SummaryStats(data::NamedTuple; name::String="SummaryStats") - n = length(first(data)) - return SummaryStats(name, merge((parameter=1:n,), data)) +function SummaryStats(name::String, data) + return SummaryStats(name, data, Base.OneTo(Tables.rowcount(Tables.columns(data)))) +end + +function _ordereddict(stats::SummaryStats) + return OrderedCollections.OrderedDict( + k => Tables.getcolumn(stats, k) for k in Tables.columnnames(stats) + ) end # forward key interfaces from its parent Base.parent(stats::SummaryStats) = getfield(stats, :data) -Base.keys(stats::SummaryStats) = keys(parent(stats)) -Base.haskey(stats::SummaryStats, nm::Symbol) = haskey(parent(stats), nm) -Base.length(stats::SummaryStats) = length(parent(stats)) -Base.getindex(stats::SummaryStats, i::Int) = getindex(parent(stats), i) -Base.getindex(stats::SummaryStats, nm::Symbol) = getindex(parent(stats), nm) -function Base.iterate(stats::SummaryStats, i::Int=firstindex(parent(stats))) - return iterate(parent(stats), i) +Base.keys(stats::SummaryStats) = map(Symbol, Tables.columnnames(stats)) +Base.haskey(stats::SummaryStats, nm::Symbol) = nm ∈ keys(stats) +Base.length(stats::SummaryStats) = length(Tables.columnnames(parent(stats))) 
+ 1 +Base.getindex(stats::SummaryStats, i::Union{Int,Symbol}) = Tables.getcolumn(stats, i) +function Base.iterate(stats::SummaryStats) + ncols = length(stats) + return stats.parameter_names, (2, ncols) end -function Base.merge(stats::SummaryStats, other_stats::SummaryStats...) - return SummaryStats(stats.name, merge(parent(stats), map(parent, other_stats)...)) +function Base.iterate(stats::SummaryStats, (i, ncols)::NTuple{2,Int}) + i > ncols && return nothing + return Tables.getcolumn(stats, i), (i + 1, ncols) end -function Base.isequal(stats::SummaryStats, other_stats::SummaryStats) - return isequal(parent(stats), parent(other_stats)) +function Base.merge( + stats::SummaryStats{<:NamedTuple}, other_stats::SummaryStats{<:NamedTuple}... +) + isempty(other_stats) && return stats + stats_all = (stats, other_stats...) + stats_last = last(stats_all) + return SummaryStats( + stats_last.name, merge(map(parent, stats_all)...), stats_last.parameter_names + ) end -function Base.:(==)(stats::SummaryStats, other_stats::SummaryStats) - return (parent(stats) == parent(other_stats)) +function Base.merge(stats::SummaryStats, other_stats::SummaryStats...) + isempty(other_stats) && return stats + stats_all = (stats, other_stats...) + data_merged = merge(map(_ordereddict, stats_all)...) + parameter_names = pop!(data_merged, :parameter) + return SummaryStats(last(stats_all).name, data_merged, parameter_names) +end +for f in (:(==), :isequal) + @eval begin + function Base.$(f)(stats::SummaryStats, other_stats::SummaryStats) + colnames1 = Tables.columnnames(stats) + colnames2 = Tables.columnnames(other_stats) + vals1 = (Tables.getcolumn(stats, k) for k in colnames1) + vals2 = (Tables.getcolumn(other_stats, k) for k in colnames2) + # `zip` stops at the shorter input, so require equal column counts first + length(colnames1) == length(colnames2) || return false + return all(Base.splat($f), zip(colnames1, colnames2)) && + all(Base.splat($f), zip(vals1, vals2)) + end + end end #### custom tabular show methods @@ -49,7 +109,7 @@ function Base.show(io::IO, mime::MIME"text/html", stats::SummaryStats; kwargs... 
end function _show(io::IO, mime::MIME, stats::SummaryStats; kwargs...) - data = NamedTuple{eachindex(stats)[2:end]}(parent(stats)) + data = parent(stats) rhat_formatter = _prettytables_rhat_formatter(data) extra_formatters = rhat_formatter === nothing ? () : (rhat_formatter,) return _show_prettytable( @@ -57,7 +117,7 @@ function _show(io::IO, mime::MIME, stats::SummaryStats; kwargs...) mime, data; title=stats.name, - row_labels=parent(stats).parameter, + row_labels=Tables.getcolumn(stats, :parameter), extra_formatters, kwargs..., ) @@ -68,12 +128,29 @@ end Tables.istable(::Type{<:SummaryStats}) = true Tables.columnaccess(::Type{<:SummaryStats}) = true Tables.columns(s::SummaryStats) = s -Tables.columnnames(s::SummaryStats) = Tables.columnnames(parent(s)) -Tables.getcolumn(s::SummaryStats, i::Int) = Tables.getcolumn(parent(s), i) -Tables.getcolumn(s::SummaryStats, nm::Symbol) = Tables.getcolumn(parent(s), nm) -Tables.rowaccess(::Type{<:SummaryStats}) = true -Tables.rows(s::SummaryStats) = Tables.rows(parent(s)) -Tables.schema(s::SummaryStats) = Tables.schema(parent(s)) +function Tables.columnnames(s::SummaryStats) + data_cols = Tables.columnnames(parent(s)) + data_cols isa Tuple && return (:parameter, data_cols...) 
+ return collect(Iterators.flatten(((:parameter,), data_cols))) +end +function Tables.getcolumn(stats::SummaryStats, i::Int) + i == 1 && return stats.parameter_names + return Tables.getcolumn(parent(stats), i - 1) +end +function Tables.getcolumn(stats::SummaryStats, nm::Symbol) + nm === :parameter && return stats.parameter_names + return Tables.getcolumn(parent(stats), nm) +end +function Tables.schema(s::SummaryStats) + data_schema = Tables.schema(parent(s)) + data_schema === nothing && return nothing + T = eltype(s.parameter_names) + if data_schema isa Tables.Schema{Nothing,Nothing} + return Tables.Schema([:parameter; data_schema.names], [T; data_schema.types]) + else + return Tables.Schema((:parameter, data_schema.names...), (T, data_schema.types...)) + end +end IteratorInterfaceExtensions.isiterable(::SummaryStats) = true function IteratorInterfaceExtensions.getiterator(s::SummaryStats) @@ -194,8 +271,7 @@ Compute the summary statistics in `stats_funs` on each param in `data`, with siz fnames = map(first, names_and_funs) _check_function_names(fnames) funs = map(last, names_and_funs) - nt = merge((; parameter=var_names), _summarize(data, funs, fnames)...) - return SummaryStats(name, nt) + return SummaryStats(name, _summarize(data, funs, fnames), var_names) end function _check_function_names(fnames) @@ -295,14 +371,15 @@ function _prob_interval_to_strings(interval_type, prob; digits=2) end end -function _summarize(data::AbstractArray{<:Any,3}, funs, fun_names) - return map(fun_names, funs) do fname, f +# aggressive constprop allows summarize to be type-inferrable when called by +# another function + +@constprop :aggressive function _summarize(data::AbstractArray{<:Any,3}, funs, fun_names) + return merge(map(fun_names, funs) do fname, f return _map_over_params(fname, f, data) - end + end...) 
end -# aggressive constprop allows summarize to be type-inferrable when called by -# another function @constprop :aggressive function _map_over_params(fname, f, data) vals = _map_paramslices(f, data) return _namedtuple_of_vals(f, fname, vals) diff --git a/src/utils.jl b/src/utils.jl index 16fdc9c..5ede522 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -257,13 +257,15 @@ function ft_printf_sigdigits_matching_se( end function _prettytables_rhat_formatter(data) - cols = findall(x -> x === :rhat, keys(data)) + cols = findall(x -> x === :rhat, Tables.columnnames(data)) isempty(cols) && return nothing return PrettyTables.ft_printf("%.2f", cols) end function _prettytables_integer_formatter(data) - cols = findall(v -> eltype(v) <: Integer, values(data)) + sch = Tables.schema(data) + sch === nothing && return nothing + cols = findall(t -> t <: Integer, sch.types) isempty(cols) && return nothing return PrettyTables.ft_printf("%d", cols) end @@ -272,20 +274,24 @@ end # see https://ronisbr.github.io/PrettyTables.jl/stable/man/formatters/ function _default_prettytables_formatters(data; sigdigits_se=2, sigdigits_default=3) formatters = [] - for (i, k) in enumerate(keys(data)) + col_names = Tables.columnnames(data) + for (i, k) in enumerate(col_names) for mcse_key in (Symbol("mcse_$k"), Symbol("$(k)_mcse")) - if haskey(data, mcse_key) - push!(formatters, ft_printf_sigdigits_matching_se(data[mcse_key], [i])) + if mcse_key ∈ col_names # `haskey` is not part of the Tables interface + push!( + formatters, + ft_printf_sigdigits_matching_se(Tables.getcolumn(data, mcse_key), [i]), + ) continue end end end - mcse_cols = findall(keys(data)) do k + mcse_cols = findall(col_names) do k s = string(k) return startswith(s, "mcse_") || endswith(s, "_mcse") end isempty(mcse_cols) || push!(formatters, ft_printf_sigdigits(sigdigits_se, mcse_cols)) - ess_cols = findall(_is_ess_label, keys(data)) + ess_cols = findall(_is_ess_label, col_names) isempty(ess_cols) || push!(formatters, PrettyTables.ft_printf("%d", ess_cols)) ft_integer = _prettytables_integer_formatter(data) 
ft_integer === nothing || push!(formatters, ft_integer) @@ -300,12 +306,10 @@ function _show_prettytable( extra_formatters..., _default_prettytables_formatters(data; sigdigits_se, sigdigits_default)..., ) - alignment = fill(:r, length(data)) - for (i, v) in enumerate(values(data)) - if !(eltype(v) <: Real) - alignment[i] = :l - end - end + col_names = Tables.columnnames(data) + alignment = [ + eltype(Tables.getcolumn(data, col_name)) <: Real ? :r : :l for col_name in col_names + ] kwargs_new = merge( ( show_subheader=false, @@ -331,13 +335,16 @@ function _show_prettytable( newline_at_end=false, kwargs..., ) - alignment_anchor_regex = Dict{Int,Vector{Regex}}( - i => [r"\.", r"e", r"^NaN$", r"Inf$"] for (i, (k, v)) in enumerate(pairs(data)) if - (eltype(v) <: Real && !(eltype(v) <: Integer) && !_is_ess_label(k)) - ) + alignment_anchor_regex = Dict{Int,Vector{Regex}}() + for (i, k) in enumerate(Tables.columnnames(data)) + v = Tables.getcolumn(data, k) + if eltype(v) <: Real && !(eltype(v) <: Integer) && !_is_ess_label(k) + alignment_anchor_regex[i] = [r"\.", r"e", r"^NaN$", r"Inf$"] + end + end alignment_anchor_fallback = :r alignment_anchor_fallback_override = Dict( - i => :r for (i, k) in enumerate(keys(data)) if _is_ess_label(k) + i => :r for (i, k) in enumerate(Tables.columnnames(data)) if _is_ess_label(k) ) return _show_prettytable( io, diff --git a/test/summarize.jl b/test/summarize.jl index 357433d..c0bcd13 100644 --- a/test/summarize.jl +++ b/test/summarize.jl @@ -1,5 +1,6 @@ using IteratorInterfaceExtensions using MCMCDiagnosticTools +using OrderedCollections using PosteriorStats using Statistics using StatsBase @@ -23,63 +24,77 @@ _mean_and_std(x) = (mean=mean(x), std=std(x)) @testset "summary statistics" begin @testset "SummaryStats" begin - data = ( - parameter=["a", "bb", "ccc", "d", "e"], - est=randn(5), - mcse_est=randn(5), - rhat=rand(5), - ess=rand(5), - ) + parameter_names = ["a", "bb", "ccc", "d", "e"] + data = (est=randn(5), mcse_est=rand(5), 
rhat=rand(5), ess=rand(5)) - stats = @inferred SummaryStats(data; name="Stats") + @inferred SummaryStats(data; name="Stats") + stats = @inferred SummaryStats(data, parameter_names; name="Stats") @testset "basic interfaces" begin @test parent(stats) === data @test stats.name == "Stats" @test SummaryStats("MoreStats", data).name == "MoreStats" @test SummaryStats(data; name="MoreStats").name == "MoreStats" - @test keys(stats) == keys(data) + @test keys(stats) == (:parameter, keys(data)...) for k in keys(stats) - @test haskey(stats, k) == haskey(data, k) - @test getindex(stats, k) == getindex(data, k) + @test haskey(stats, k) + if k === :parameter + @test getindex(stats, k) == parameter_names + else + @test getindex(stats, k) == getindex(data, k) + end end @test !haskey(stats, :foo) - @test length(stats) == length(data) + @test length(stats) == length(data) + 1 + @test getindex(stats, 1) == parameter_names for i in 1:length(data) - @test getindex(stats, i) == getindex(data, i) + @test stats[i + 1] == data[i] end - @test Base.iterate(stats) == Base.iterate(data) - @test Base.iterate(stats, 2) == Base.iterate(data, 2) + @test Base.iterate(stats) == (parameter_names, (2, length(stats))) + @test Base.iterate(stats, (2, length(stats))) == (stats[2], (3, length(stats))) data_copy1 = deepcopy(data) - stats2 = SummaryStats(data_copy1) + stats2 = SummaryStats(data_copy1, parameter_names) @test stats2 == stats @test isequal(stats2, stats) data_copy2 = deepcopy(data) - stats3 = SummaryStats(data_copy2; name="Stats") - @test stats3 == stats2 - @test isequal(stats3, stats2) - stats3[:parameter][1] = "foo" + parameter_names2 = copy(parameter_names) + parameter_names2[1] = "foo" + stats3 = SummaryStats(data_copy2, parameter_names2; name="Stats") @test stats3 != stats2 @test !isequal(stats3, stats2) - stats3[:parameter][1] = "a" + stats3 = SummaryStats(data_copy2, parameter_names; name="Stats") stats3[:est][2] = NaN @test stats3 != stats2 @test !isequal(stats3, stats2) stats2[:est][2] 
= NaN @test stats3 != stats2 @test isequal(stats3, stats2) + end - stats4 = SummaryStats((; est=randn(5), est2=randn(5)); name="MoreStats") - @test parent(stats4).parameter == 1:5 - stats_merged1 = merge(stats, stats4) - @test stats_merged1.name == "Stats" - @test parent(stats_merged1) == merge(parent(stats), parent(stats4)) + @testset "merge" begin + stats_dict = SummaryStats( + OrderedDict(pairs(data)), parameter_names; name="Stats" + ) + @test merge(stats) === stats + @test merge(stats, stats) == stats + @test merge(stats_dict) === stats_dict + @test merge(stats_dict, stats_dict) == stats_dict + @test merge(stats, stats_dict) == stats + @test merge(stats_dict, stats) == stats_dict - stats_merged2 = merge(stats4, stats) - @test stats_merged2.name == "MoreStats" - @test parent(stats_merged2) == merge(parent(stats4), parent(stats)) + data2 = (ess=randn(5), rhat=rand(5), mcse_est=rand(5), est2=rand(5)) + stats2 = SummaryStats(data2, 1:5; name="Stats2") + stats2_dict = SummaryStats(OrderedDict(pairs(data2)), 1:5; name="Stats2") + for stats_a in (stats, stats_dict), stats_b in (stats2, stats2_dict) + @test merge(stats_a, stats_b) == + SummaryStats(merge(data, data2), stats_b.parameter_names) + @test merge(stats_a, stats_b).name == stats_b.name + @test merge(stats_b, stats_a) == + SummaryStats(merge(data2, data), stats_a.parameter_names) + @test merge(stats_b, stats_a).name == stats_a.name + end end @testset "Tables interface" begin @@ -88,14 +103,13 @@ _mean_and_std(x) = (mean=mean(x), std=std(x)) @test Tables.columns(stats) === stats @test Tables.columnnames(stats) == keys(stats) table = Tables.columntable(stats) - @test table == data + @test table == (; parameter=parameter_names, data...) 
for (i, k) in enumerate(Tables.columnnames(stats)) @test Tables.getcolumn(stats, i) == Tables.getcolumn(stats, k) end @test_throws ErrorException Tables.getcolumn(stats, :foo) - @test Tables.rowaccess(typeof(stats)) - @test Tables.rows(stats) == Tables.rows(parent(stats)) - @test Tables.schema(stats) == Tables.schema(parent(stats)) + @test !Tables.rowaccess(typeof(stats)) + @test Tables.schema(stats) == Tables.schema(Tables.columntable(stats)) end @testset "TableTraits interface" begin @@ -124,15 +138,15 @@ _mean_and_std(x) = (mean=mean(x), std=std(x)) end @testset "show" begin + parameter_names = ["a", "bb", "ccc", "d", "e"] data = ( - parameter=["a", "bb", "ccc", "d", "e"], est=[111.11, 1.2345e-6, 5.4321e8, Inf, NaN], mcse_est=[0.0012345, 5.432e-5, 2.1234e5, Inf, NaN], rhat=vcat(1.009, 1.011, 0.99, Inf, NaN), ess=vcat(312.45, 23.32, 1011.98, Inf, NaN), ess_bulk=vcat(9.2345, 876.321, 999.99, Inf, NaN), ) - stats = SummaryStats(data) + stats = SummaryStats(data, parameter_names) @test sprint(show, "text/plain", stats) == """ SummaryStats est mcse_est rhat ess ess_bulk @@ -156,12 +170,14 @@ _mean_and_std(x) = (mean=mean(x), std=std(x)) end @test stats1 isa SummaryStats @test getfield(stats1, :name) == "SummaryStats" - @test stats1 == SummaryStats(( - parameter=axes(x, 3), - mean=map(mean, eachslice(x; dims=3)), - std=map(std, eachslice(x; dims=3)), - median=map(median, eachslice(x; dims=3)), - )) + @test stats1 == SummaryStats( + ( + mean=map(mean, eachslice(x; dims=3)), + std=map(std, eachslice(x; dims=3)), + median=map(median, eachslice(x; dims=3)), + ), + axes(x, 3), + ) function _compute_stats(x) return summarize(x, (:mean, :std) => mean_and_std, :median => median) @@ -175,11 +191,10 @@ _mean_and_std(x) = (mean=mean(x), std=std(x)) stats3 = summarize(x, mean, std; var_names=["a", "b", "c"], name="Stats") @test getfield(stats3, :name) == "Stats" - @test stats3 == SummaryStats(( - parameter=["a", "b", "c"], - mean=map(mean, eachslice(x; dims=3)), - std=map(std, 
eachslice(x; dims=3)), - )) + @test stats3 == SummaryStats( + (mean=map(mean, eachslice(x; dims=3)), std=map(std, eachslice(x; dims=3))), + ["a", "b", "c"], + ) stats4 = summarize(x; var_names=["a", "b", "c"], name="Stats") @test getfield(stats4, :name) == "Stats"