Skip to content

Commit

Permalink
Merge pull request #129 from fjebaker/fergus/fix-bindings
Browse files Browse the repository at this point in the history
Fix parameter bindings
  • Loading branch information
fjebaker authored Oct 2, 2024
2 parents b2f5ef4 + f536f63 commit dc00d6b
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 28 deletions.
25 changes: 22 additions & 3 deletions src/composite-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,25 @@ function Base.show(io::IO, @nospecialize(model::CompositeModel))
)
end

function _print_param(io, free, name, val, q0, q1, q2, q3, q4)
function _print_param(io, free, name, val, q0, q1, q2, q3, q4; binding = nothing)
print(io, lpad("$name", q0), " ->")
if val isa FitParam
info = get_info_tuple(val)
print(io, lpad(info[1], q1 + 1))
if free
print(io, " Β± ", rpad(info[2], q2))
print(io, " ∈ [", lpad(info[3], q3), ", ", rpad(info[4], q4), "]")
end

if !isnothing(binding)
print(
io,
" ",
Crayons.Crayon(foreground = :magenta),
lpad(binding, 7),
Crayons.Crayon(reset = true),
)
elseif free
print(
io,
Crayons.Crayon(foreground = :green),
Expand All @@ -174,6 +185,7 @@ function _print_param(io, free, name, val, q0, q1, q2, q3, q4)
else
print(
io,
" ",
Crayons.Crayon(foreground = :cyan),
lpad("FROZEN", 15 + q1 + q2 + q3 + q4),
Crayons.Crayon(reset = true),
Expand All @@ -183,7 +195,7 @@ function _print_param(io, free, name, val, q0, q1, q2, q3, q4)
println(io)
end

function _printinfo(io::IO, @nospecialize(model::CompositeModel))
function _printinfo(io::IO, @nospecialize(model::CompositeModel); bindings = nothing)
expr_buffer = 5
sym_buffer = 5

Expand All @@ -208,6 +220,7 @@ function _printinfo(io::IO, @nospecialize(model::CompositeModel))
)
println(io, "Model key and parameters:")

param_index = 1
for (sym, m) in destructed.model_map
param_syms = destructed.parameter_symbols[sym]
basename = Base.typename(typeof(m)).name
Expand All @@ -224,7 +237,13 @@ function _printinfo(io::IO, @nospecialize(model::CompositeModel))
for ps in param_syms
param = destructed.parameter_map[ps]
free = param isa FitParam ? !isfrozen(param) : true
_print_param(io, free, ps, param, param_offset, q1, q2, q3, q4)
val, binding = if !isnothing(bindings) && !isempty(bindings)
get(bindings, param_index, param => nothing)
else
param, nothing
end
_print_param(io, free, ps, val, param_offset, q1, q2, q3, q4; binding)
param_index += 1
end
end
end
Expand Down
6 changes: 0 additions & 6 deletions src/fitting/binding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ function _get_index_of_symbol(model::AbstractSpectralModel, symbol)::Int
if isnothing(i)
error("Could not match symbol $symbol !")
end
# don't count frozen parameters if they are prior to the parameter of interest
for s = 1:i
if isfrozen(pnt[s])
i -= 1
end
end
i
end

Expand Down
59 changes: 54 additions & 5 deletions src/fitting/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,73 @@ function _invoke_and_transform!(cache::MultiModelCache, domain, params)
all_outputs
end

"""
adjust_free_bindings(model::FittableMultiModel, bindings::Vector{Vector{Pair{Int,Int}}})
Returns a new parameter binding list with the parameter indices adjusted to
omit the frozen parameter. That is, if a model has three parameters `a`, `b`
(frozen) and `c`, and parameter `c` (index 3) is bound; then the new bindings
will instead give the index of `c` as 2.
In that sense, the new bindings refer to the _free parameter_ vector, and not the full parameter vector.
If the binding refers to a frozen parameter, it is removed from the binding list.
"""
function adjust_free_bindings(model::FittableMultiModel, bindings)
new_bindings = map(bindings) do binding
pairs = Vector{Pair{Int,Int}}()
for pair in binding
model_index, parameter_index = pair

new_index = parameter_index
refers_to_frozen = false

for (i, param) in enumerate(parameter_tuple(model.m[model_index]))
if i > parameter_index
break
end

if isfrozen(param)
if i == parameter_index
refers_to_frozen = true
break
end
new_index -= 1
end
end

if !refers_to_frozen
push!(pairs, model_index => new_index)
end
end
pairs
end

# remove those that have become redundant
filter(i -> length(i) > 1, new_bindings)
end

function _build_parameter_mapping(model::FittableMultiModel{M}, bindings) where {M}
T = paramtype(M.parameters[1])

all_parameters = Vector{T}()
all_free_parameters = Vector{T}()
# use the tuple hack to enforce type stability and unroll the loop
parameters = map((1:length(M.parameters)...,)) do i
m = model.m[i]
v::Vector{T} = collect(filter(isfree, parameter_tuple(m)))
append!(all_parameters, v)
append!(all_free_parameters, v)
v
end
parameters_counts = _accumulated_indices(map(length, parameters))

parameter_mapping, remove = _construct_bound_mapping(bindings, parameters_counts)
# need to adjust the bindings to remove the frozen indices
free_bindings = adjust_free_bindings(model, bindings)

parameter_mapping, remove = _construct_bound_mapping(free_bindings, parameters_counts)
# remove duplicate parameters that are bound
deleteat!(all_parameters, remove)
deleteat!(all_free_parameters, remove)

all_parameters, parameter_mapping
all_free_parameters, parameter_mapping
end

function _build_mapping_length(f, itt::Tuple)
Expand Down
70 changes: 65 additions & 5 deletions src/fitting/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,45 @@ struct FittableMultiModel{M}
FittableMultiModel(model::Vararg{<:AbstractSpectralModel}) = new{typeof(model)}(model)
end

function translate_bindings(
model_index::Int,
m::FittableMultiModel,
bindings::Vector{Vector{Pair{Int,Int}}},
)
# map the parameter index to a string to display
translation = Dict{Int,Pair{paramtype(m.m[1]),String}}()
for b in bindings
# we skip the first item in the list since that is the root of the binding
root = b[1]

params = parameter_named_tuple(m.m[first(root)])
symbol = propertynames(params)[last(root)]
value = params[last(root)]

for pair in @views b[2:end]
# check if this binding applies to the current model
if first(pair) == model_index
translation[last(pair)] = value => "~Model $(first(root)) $(symbol)"
end
end
end

translation
end

function Base.show(io::IO, ::MIME"text/plain", @nospecialize(model::FittableMultiModel))
buff = IOBuffer()
println(buff, "Models:")
for m in model.m
for (i, m) in enumerate(model.m)
buf = IOBuffer()
print(buf, "- ")
print(
buf,
"\n",
Crayons.Crayon(foreground = :yellow),
"Model $i",
Crayons.Crayon(reset = true),
": ",
)
_printinfo(buf, m)
print(buff, indent(String(take!(buf)), 2))
end
Expand Down Expand Up @@ -135,11 +168,38 @@ function Base.show(io::IO, ::MIME"text/plain", @nospecialize(prob::FittingProble
buff,
" . ",
Crayons.Crayon(foreground = :green),
"Free (DOF)",
"Free",
Crayons.Crayon(reset = true),
" : $(free)",
" : $(free)",
)
print(io, encapsulate(String(take!(buff))))
end

export FittingProblem, FittableMultiModel, FittableMultiDataset

"""
details(prob::FittingProblem)
Show details about the fitting problem, including the specific model parameters that are bound together.
"""
function details(prob::FittingProblem)
buff = IOBuffer()
println(buff, "Models:")
for (i, m) in enumerate(prob.model.m)
buf = IOBuffer()
print(
buf,
"\n",
Crayons.Crayon(foreground = :yellow),
"Model $i",
Crayons.Crayon(reset = true),
": ",
)

_printinfo(buf, m; bindings = translate_bindings(i, prob.model, prob.bindings))
print(buff, indent(String(take!(buf)), 2))
end

print(encapsulate(String(take!(buff))))
end

export FittingProblem, FittableMultiModel, FittableMultiDataset, details
27 changes: 27 additions & 0 deletions test/fitting/test-binding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ prob = FittingProblem(model1 => dummy_data1, model1 => dummy_data1, model1 => du
model1.K_1.frozen = true
bind!(prob, :a_1)
bind!(prob, :a_2)

# check that the free parameter binding adjustment works okay
new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings)
@test new_bindings == [[1 => 1, 2 => 1, 3 => 1], [1 => 3, 2 => 3, 3 => 3]]

_, mapping = SpectralFitting._build_parameter_mapping(prob.model, prob.bindings)
@test mapping == ([1, 2, 3], [1, 4, 3], [1, 5, 3])

Expand All @@ -94,5 +99,27 @@ model1.K_1.frozen = false
model1.a_2.frozen = true
bind!(prob, :K_1)
bind!(prob, :a_1)

new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings)
@test new_bindings == [[1 => 1, 2 => 1, 3 => 1], [1 => 2, 2 => 2, 3 => 2]]

_, mapping = SpectralFitting._build_parameter_mapping(prob.model, prob.bindings)
@test mapping == ([1, 2, 3], [1, 2, 4], [1, 2, 5])


prob = FittingProblem(model1 => dummy_data1, model1 => dummy_data1, model1 => dummy_data1)

# bind model 1's K_1 parameter to model 2's K_2 parameter
bind!(prob, 1 => :K_1, 2 => :K_2)

new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings)
@test new_bindings == [[1 => 1, 2 => 3]]

# bind model 1's a_1 parameter to model 2's a_2 to model 3's a_2
# these are all frozen so should not appear in the adjust_free_bindings
bind!(prob, 1 => :a_1, 2 => :a_2, 3 => :a_2)

new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings)
@test new_bindings == [[1 => 1, 2 => 3]]

# TODO: free parameters should not be allowed to bind to frozen parameters
18 changes: 9 additions & 9 deletions test/io/test-printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ expected = """
β”Œ DummyAdditive
β”‚ K -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ a -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b -> 5\e[36m FROZEN\e[0m
β”‚ b -> 5 \e[36m FROZEN\e[0m
β”” """
@test string == expected

Expand All @@ -32,14 +32,14 @@ expected = """β”Œ CompositeModel with 3 model components:
β”‚ \e[36m a1\e[0m => \e[36mDummyAdditive\e[0m
β”‚ K_1 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ a_1 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_1 -> 5\e[36m FROZEN\e[0m
β”‚ b_1 -> 5 \e[36m FROZEN\e[0m
β”‚ \e[36m a2\e[0m => \e[36mDummyAdditive\e[0m
β”‚ K_2 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ a_2 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_2 -> 5\e[36m FROZEN\e[0m
β”‚ b_2 -> 5 \e[36m FROZEN\e[0m
β”‚ \e[36m m1\e[0m => \e[36mDummyMultiplicative\e[0m
β”‚ a_3 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_3 -> 5\e[36m FROZEN\e[0m
β”‚ b_3 -> 5 \e[36m FROZEN\e[0m
β”” """
@test string == expected

Expand All @@ -55,20 +55,20 @@ expected = """β”Œ CompositeModel with 5 model components:
β”‚ \e[36m a1\e[0m => \e[36mDummyAdditive\e[0m
β”‚ K_1 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ a_1 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_1 -> 5\e[36m FROZEN\e[0m
β”‚ b_1 -> 5 \e[36m FROZEN\e[0m
β”‚ \e[36m m1\e[0m => \e[36mDummyMultiplicative\e[0m
β”‚ a_2 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_2 -> 5\e[36m FROZEN\e[0m
β”‚ b_2 -> 5 \e[36m FROZEN\e[0m
β”‚ \e[36m a2\e[0m => \e[36mDummyAdditive\e[0m
β”‚ K_2 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ a_3 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_3 -> 5\e[36m FROZEN\e[0m
β”‚ b_3 -> 5 \e[36m FROZEN\e[0m
β”‚ \e[36m m2\e[0m => \e[36mDummyMultiplicative\e[0m
β”‚ a_4 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_4 -> 5\e[36m FROZEN\e[0m
β”‚ b_4 -> 5 \e[36m FROZEN\e[0m
β”‚ \e[36m m3\e[0m => \e[36mDummyMultiplicative\e[0m
β”‚ a_5 -> 1 Β± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m
β”‚ b_5 -> 5\e[36m FROZEN\e[0m
β”‚ b_5 -> 5 \e[36m FROZEN\e[0m
β”” """
@test string == expected

Expand Down

0 comments on commit dc00d6b

Please sign in to comment.