-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #117 from fjebaker/fergus/wrappers
feat: model wrappers and AutoCache
- Loading branch information
Showing
7 changed files
with
230 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
mutable struct CacheEntry{T} | ||
cache::Vector{T} | ||
params::Vector{T} | ||
domain_limits::Tuple{T,T} | ||
size_of_element::Int | ||
|
||
function CacheEntry(params::AbstractVector{<:Number}) | ||
T = eltype(params) | ||
cache = zeros(T, 1) | ||
new{T}(cache, params, (zero(T), zero(T)), sizeof(T)) | ||
end | ||
end | ||
|
||
struct AutoCache{M,T,K,C<:CacheEntry} <: AbstractModelWrapper{M,T,K} | ||
model::M | ||
cache::C | ||
abstol::Float64 | ||
function AutoCache( | ||
model::AbstractSpectralModel{T,K}, | ||
cache::CacheEntry, | ||
abstol, | ||
) where {T,K} | ||
new{typeof(model),T,K,typeof(cache)}(model, cache, abstol) | ||
end | ||
end | ||
|
||
function AutoCache(model::AbstractSpectralModel{T,K}; abstol = 1e-3) where {T,K} | ||
params = [get_value.(parameter_tuple(model))...] | ||
cache = CacheEntry(params) | ||
AutoCache(model, cache, abstol) | ||
end | ||
|
||
function _reinterpret_dual(::Type, v::AbstractArray, n::Int) | ||
needs_resize = n > length(v) | ||
if needs_resize | ||
@warn "AutoCache: Growing dual buffer..." | ||
resize!(v, n) | ||
end | ||
view(v, 1:n), needs_resize | ||
end | ||
function _reinterpret_dual( | ||
DualType::Type{<:ForwardDiff.Dual}, | ||
v::AbstractArray{T}, | ||
n::Int, | ||
) where {T} | ||
n_elems = div(sizeof(DualType), sizeof(T)) * n | ||
needs_resize = n_elems > length(v) | ||
if needs_resize | ||
@warn "AutoCache: Growing dual buffer..." | ||
resize!(v, n_elems) | ||
end | ||
reinterpret(DualType, view(v, 1:n_elems)), needs_resize | ||
end | ||
|
||
function invoke!(output, domain, model::AutoCache{M,T}) where {M,T} | ||
D = promote_type(eltype(domain), T) | ||
|
||
_new_params = parameter_tuple(model.model) | ||
_new_limits = (first(domain), last(domain)) | ||
|
||
output_cache, out_resized = _reinterpret_dual(D, model.cache.cache, length(output)) | ||
param_cache, _ = _reinterpret_dual(D, model.cache.params, length(_new_params)) | ||
|
||
same_domain = model.cache.domain_limits == _new_limits | ||
|
||
# if the parameter size has changed, need to rerun the model | ||
if (!out_resized) && (model.cache.size_of_element == sizeof(D)) && (same_domain) | ||
# if all parameters within some tolerance, then just return the cache | ||
within_tolerance = all(zip(param_cache, _new_params)) do I | ||
p, pm = I | ||
abs((get_value(p) - get_value(pm)) / p) < model.abstol | ||
end | ||
|
||
if within_tolerance | ||
@. output = output_cache | ||
return output | ||
end | ||
end | ||
|
||
model.cache.size_of_element = sizeof(D) | ||
invoke!(output_cache, domain, model.model) | ||
# update the auto cache infos | ||
model.cache.domain_limits = _new_limits | ||
|
||
@. param_cache = get_value(_new_params) | ||
@. output = output_cache | ||
end | ||
|
||
export AutoCache |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
abstract type AbstractModelWrapper{M,T,K} <: AbstractSpectralModel{T,K} end | ||
First field of the struct must be `model`. | ||
""" | ||
abstract type AbstractModelWrapper{M<:AbstractSpectralModel,T,K} <: | ||
AbstractSpectralModel{T,K} end | ||
|
||
normalisation(model::AbstractModelWrapper{M,T,Additive}) where {M,T} = | ||
normalisation(model.model) | ||
|
||
function Reflection.get_closure_symbols( | ||
M::Type{<:AbstractModelWrapper{Model}}, | ||
) where {Model} | ||
# we ignore the `model` field, since that will be given by the constructor | ||
(fieldnames(M)[2:end]..., Reflection.get_closure_symbols(Model)...) | ||
end | ||
|
||
Reflection.get_parameter_symbols(::Type{<:AbstractModelWrapper{M}}) where {M} = | ||
Reflection.get_parameter_symbols(M) | ||
|
||
function Reflection.closure_parameter_lenses( | ||
M::Type{<:AbstractModelWrapper}, | ||
info::Reflection.ModelInfo, | ||
) | ||
num_closures = fieldcount(M) - 1 # ignore the `model` field | ||
|
||
my_closures = map(info.closure_symbols[1:num_closures]) do s | ||
:(getfield($(info.lens), $(Meta.quot(s)))) | ||
end | ||
model_closures = map(info.closure_symbols[num_closures+1:end]) do s | ||
:(getfield($(info.lens).model, $(Meta.quot(s)))) | ||
end | ||
vcat(my_closures, model_closures) | ||
end | ||
|
||
function Reflection.parameter_lenses( | ||
::Type{<:AbstractModelWrapper}, | ||
info::Reflection.ModelInfo, | ||
) | ||
map(info.symbols) do s | ||
:(getfield($(info.lens).model, $(Meta.quot(s)))) | ||
end | ||
end | ||
|
||
function Reflection.make_constructor( | ||
M::Type{<:AbstractModelWrapper{Model}}, | ||
closures::Vector, | ||
params::Vector, | ||
T::Type, | ||
) where {Model} | ||
num_closures = fieldcount(M) - 1 # ignore the `model` field | ||
my_closures = closures[1:num_closures] | ||
model_constructor = | ||
Reflection.make_constructor(Model, closures[num_closures+1:end], params, T) | ||
:($(Base.typename(M).name)($(model_constructor), $(my_closures...))) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
using Test | ||
using SpectralFitting | ||
|
||
include("../dummies.jl") | ||
|
||
struct EvalCountingModel{D,T} <: AbstractTableModel{T,Additive} | ||
table::D | ||
K::T | ||
a::T | ||
end | ||
|
||
function EvalCountingModel(; K = FitParam(1.0), a = FitParam(3.0)) | ||
EvalCountingModel(Int[0], K, a) | ||
end | ||
|
||
function SpectralFitting.invoke!(output, domain, model::EvalCountingModel) | ||
model.table[1] += 1 | ||
@. output = domain[1:end-1] .+ model.a | ||
end | ||
|
||
domain = collect(range(0.0, 10.0, 100)) | ||
|
||
model = AutoCache(EvalCountingModel()) | ||
|
||
# running the model several times should only hit the counter once | ||
for i = 1:100 | ||
invokemodel(domain, model) | ||
end | ||
@test model.model.table[1] == 1 | ||
|
||
# changing the parameter should hit the counter again | ||
set_value!(model.model.a, 5.0) | ||
for i = 1:100 | ||
invokemodel(domain, model) | ||
end | ||
@test model.model.table[1] == 2 | ||
|
||
# modifying the domain should change the cache as well | ||
domain = collect(range(0.1, 5.0, 10)) | ||
for i = 1:100 | ||
invokemodel(domain, model) | ||
end | ||
@test model.model.table[1] == 3 | ||
|
||
# now as a composite | ||
domain = collect(range(0.0, 10.0, 100)) | ||
model = DummyMultiplicative() * AutoCache(EvalCountingModel()) | ||
|
||
# running the model several times should only hit the counter once | ||
for i = 1:100 | ||
invokemodel(domain, model) | ||
end | ||
@test model.a1.model.table[1] == 1 | ||
|
||
# changing the parameter should hit the counter again | ||
set_value!(model.a_1, 5.0) | ||
for i = 1:100 | ||
invokemodel(domain, model) | ||
end | ||
@test model.a1.model.table[1] == 2 | ||
|
||
# modifying the domain should change the cache as well | ||
domain = collect(range(0.1, 5.0, 10)) | ||
for i = 1:100 | ||
invokemodel(domain, model) | ||
end | ||
@test model.a1.model.table[1] == 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters