Skip to content

Commit

Permalink
Merge pull request #117 from fjebaker/fergus/wrappers
Browse files Browse the repository at this point in the history
feat: model wrappers and AutoCache
  • Loading branch information
fjebaker authored Jun 25, 2024
2 parents f85b322 + dc78403 commit c336468
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ include("composite-models.jl")

include("reflection.jl")

include("meta-models/wrappers.jl")
include("meta-models/table-models.jl")
include("meta-models/surrogate-models.jl")
include("meta-models/caching.jl")

include("poisson.jl")

Expand Down
5 changes: 4 additions & 1 deletion src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ end
)
invoke!(output, domain, model)
# perform additive normalisation
@. output *= model.K
K = normalisation(model)
@. output *= K
output
end
@inline function invokemodel!(
Expand All @@ -302,6 +303,8 @@ end
output
end

normalisation(model::AbstractSpectralModel{T,Additive}) where {T} = model.K

"""
allocate_model_output(model::AbstractSpectralModel, domain::AbstractVector)
Expand Down
89 changes: 89 additions & 0 deletions src/meta-models/caching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
mutable struct CacheEntry{T}
cache::Vector{T}
params::Vector{T}
domain_limits::Tuple{T,T}
size_of_element::Int

function CacheEntry(params::AbstractVector{<:Number})
T = eltype(params)
cache = zeros(T, 1)
new{T}(cache, params, (zero(T), zero(T)), sizeof(T))
end
end

struct AutoCache{M,T,K,C<:CacheEntry} <: AbstractModelWrapper{M,T,K}
model::M
cache::C
abstol::Float64
function AutoCache(
model::AbstractSpectralModel{T,K},
cache::CacheEntry,
abstol,
) where {T,K}
new{typeof(model),T,K,typeof(cache)}(model, cache, abstol)
end
end

function AutoCache(model::AbstractSpectralModel{T,K}; abstol = 1e-3) where {T,K}
params = [get_value.(parameter_tuple(model))...]
cache = CacheEntry(params)
AutoCache(model, cache, abstol)
end

function _reinterpret_dual(::Type, v::AbstractArray, n::Int)
needs_resize = n > length(v)
if needs_resize
@warn "AutoCache: Growing dual buffer..."
resize!(v, n)
end
view(v, 1:n), needs_resize
end
function _reinterpret_dual(
DualType::Type{<:ForwardDiff.Dual},
v::AbstractArray{T},
n::Int,
) where {T}
n_elems = div(sizeof(DualType), sizeof(T)) * n
needs_resize = n_elems > length(v)
if needs_resize
@warn "AutoCache: Growing dual buffer..."
resize!(v, n_elems)
end
reinterpret(DualType, view(v, 1:n_elems)), needs_resize
end

function invoke!(output, domain, model::AutoCache{M,T}) where {M,T}
D = promote_type(eltype(domain), T)

_new_params = parameter_tuple(model.model)
_new_limits = (first(domain), last(domain))

output_cache, out_resized = _reinterpret_dual(D, model.cache.cache, length(output))
param_cache, _ = _reinterpret_dual(D, model.cache.params, length(_new_params))

same_domain = model.cache.domain_limits == _new_limits

# if the parameter size has changed, need to rerun the model
if (!out_resized) && (model.cache.size_of_element == sizeof(D)) && (same_domain)
# if all parameters within some tolerance, then just return the cache
within_tolerance = all(zip(param_cache, _new_params)) do I
p, pm = I
abs((get_value(p) - get_value(pm)) / p) < model.abstol
end

if within_tolerance
@. output = output_cache
return output
end
end

model.cache.size_of_element = sizeof(D)
invoke!(output_cache, domain, model.model)
# update the auto cache infos
model.cache.domain_limits = _new_limits

@. param_cache = get_value(_new_params)
@. output = output_cache
end

export AutoCache
57 changes: 57 additions & 0 deletions src/meta-models/wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
abstract type AbstractModelWrapper{M,T,K} <: AbstractSpectralModel{T,K} end
First field of the struct must be `model`.
"""
abstract type AbstractModelWrapper{M<:AbstractSpectralModel,T,K} <:
AbstractSpectralModel{T,K} end

normalisation(model::AbstractModelWrapper{M,T,Additive}) where {M,T} =
normalisation(model.model)

function Reflection.get_closure_symbols(
M::Type{<:AbstractModelWrapper{Model}},
) where {Model}
# we ignore the `model` field, since that will be given by the constructor
(fieldnames(M)[2:end]..., Reflection.get_closure_symbols(Model)...)
end

Reflection.get_parameter_symbols(::Type{<:AbstractModelWrapper{M}}) where {M} =
Reflection.get_parameter_symbols(M)

function Reflection.closure_parameter_lenses(
M::Type{<:AbstractModelWrapper},
info::Reflection.ModelInfo,
)
num_closures = fieldcount(M) - 1 # ignore the `model` field

my_closures = map(info.closure_symbols[1:num_closures]) do s
:(getfield($(info.lens), $(Meta.quot(s))))
end
model_closures = map(info.closure_symbols[num_closures+1:end]) do s
:(getfield($(info.lens).model, $(Meta.quot(s))))
end
vcat(my_closures, model_closures)
end

function Reflection.parameter_lenses(
::Type{<:AbstractModelWrapper},
info::Reflection.ModelInfo,
)
map(info.symbols) do s
:(getfield($(info.lens).model, $(Meta.quot(s))))
end
end

function Reflection.make_constructor(
M::Type{<:AbstractModelWrapper{Model}},
closures::Vector,
params::Vector,
T::Type,
) where {Model}
num_closures = fieldcount(M) - 1 # ignore the `model` field
my_closures = closures[1:num_closures]
model_constructor =
Reflection.make_constructor(Model, closures[num_closures+1:end], params, T)
:($(Base.typename(M).name)($(model_constructor), $(my_closures...)))
end
15 changes: 10 additions & 5 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ function reassemble_model(model::Type{<:AbstractSpectralModel}, parameters)
T = eltype(parameters)
info = get_info(model, DEFAULT_LENS)
closure_assigns = Expr[]
for f in info.closure_symbols
c_lens = :(getfield($(info.lens), $(Meta.quot(f))))
for c_lens in closure_parameter_lenses(info)
push!(closure_assigns, c_lens)
end

Expand Down Expand Up @@ -320,8 +319,8 @@ function assemble_composite_model_call(model::Type{<:CompositeModel}, parameters
push!(parameter_assigns, assignment)
end

for (f, c) in zip(info.closure_symbols, info.generated_closure_symbols)
c_lens = :(getfield($(info.lens), $(Meta.quot(f))))
for (c_lens, c) in
zip(closure_parameter_lenses(info), info.generated_closure_symbols)
c_assingment = :($c = $c_lens)
push!(closure_assigns, c_assingment)
end
Expand Down Expand Up @@ -407,9 +406,15 @@ function parameter_lenses(::Type{<:AbstractSpectralModel}, info::ModelInfo)
:(getfield($(info.lens), $(Meta.quot(symb))))
end
end
# parameter_lenses(infos::Vector{<:ModelInfo}) = reduce(vcat, map(parameter_lenses(info, info.symbols), infos))
parameter_lenses(info::ModelInfo) = parameter_lenses(info.model, info)

function closure_parameter_lenses(::Type{<:AbstractSpectralModel}, info::ModelInfo)
map(info.closure_symbols) do symb
:(getfield($(info.lens), $(Meta.quot(symb))))
end
end
closure_parameter_lenses(info::ModelInfo) = closure_parameter_lenses(info.model, info)

end # module Reflection

# public API wrappers
Expand Down
67 changes: 67 additions & 0 deletions test/models/test-auto-cache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using Test
using SpectralFitting

include("../dummies.jl")

struct EvalCountingModel{D,T} <: AbstractTableModel{T,Additive}
table::D
K::T
a::T
end

function EvalCountingModel(; K = FitParam(1.0), a = FitParam(3.0))
EvalCountingModel(Int[0], K, a)
end

function SpectralFitting.invoke!(output, domain, model::EvalCountingModel)
model.table[1] += 1
@. output = domain[1:end-1] .+ model.a
end

domain = collect(range(0.0, 10.0, 100))

model = AutoCache(EvalCountingModel())

# running the model several times should only hit the counter once
for i = 1:100
invokemodel(domain, model)
end
@test model.model.table[1] == 1

# changing the parameter should hit the counter again
set_value!(model.model.a, 5.0)
for i = 1:100
invokemodel(domain, model)
end
@test model.model.table[1] == 2

# modifying the domain should change the cache as well
domain = collect(range(0.1, 5.0, 10))
for i = 1:100
invokemodel(domain, model)
end
@test model.model.table[1] == 3

# now as a composite
domain = collect(range(0.0, 10.0, 100))
model = DummyMultiplicative() * AutoCache(EvalCountingModel())

# running the model several times should only hit the counter once
for i = 1:100
invokemodel(domain, model)
end
@test model.a1.model.table[1] == 1

# changing the parameter should hit the counter again
set_value!(model.a_1, 5.0)
for i = 1:100
invokemodel(domain, model)
end
@test model.a1.model.table[1] == 2

# modifying the domain should change the cache as well
domain = collect(range(0.1, 5.0, 10))
for i = 1:100
invokemodel(domain, model)
end
@test model.a1.model.table[1] == 3
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ end
include("models/test-model-consistency.jl")
include("models/test-table-models.jl")
include("models/test-surrogate-models.jl")
include("models/test-auto-cache.jl")

# only test XSPEC models when not using CI
# since model data access is annoying
Expand Down

0 comments on commit c336468

Please sign in to comment.