diff --git a/src/composite-models.jl b/src/composite-models.jl index 283b02ba..f8d16a4f 100644 --- a/src/composite-models.jl +++ b/src/composite-models.jl @@ -157,7 +157,7 @@ function Base.show(io::IO, @nospecialize(model::CompositeModel)) ) end -function _print_param(io, free, name, val, q0, q1, q2, q3, q4) +function _print_param(io, free, name, val, q0, q1, q2, q3, q4; binding = nothing) print(io, lpad("$name", q0), " ->") if val isa FitParam info = get_info_tuple(val) @@ -165,6 +165,17 @@ function _print_param(io, free, name, val, q0, q1, q2, q3, q4) if free print(io, " ± ", rpad(info[2], q2)) print(io, " ∈ [", lpad(info[3], q3), ", ", rpad(info[4], q4), "]") + end + + if !isnothing(binding) + print( + io, + " ", + Crayons.Crayon(foreground = :magenta), + lpad(binding, 7), + Crayons.Crayon(reset = true), + ) + elseif free print( io, Crayons.Crayon(foreground = :green), @@ -174,6 +185,7 @@ function _print_param(io, free, name, val, q0, q1, q2, q3, q4) else print( io, + " ", Crayons.Crayon(foreground = :cyan), lpad("FROZEN", 15 + q1 + q2 + q3 + q4), Crayons.Crayon(reset = true), @@ -183,7 +195,7 @@ function _print_param(io, free, name, val, q0, q1, q2, q3, q4) println(io) end -function _printinfo(io::IO, @nospecialize(model::CompositeModel)) +function _printinfo(io::IO, @nospecialize(model::CompositeModel); bindings = nothing) expr_buffer = 5 sym_buffer = 5 @@ -208,6 +220,7 @@ function _printinfo(io::IO, @nospecialize(model::CompositeModel)) ) println(io, "Model key and parameters:") + param_index = 1 for (sym, m) in destructed.model_map param_syms = destructed.parameter_symbols[sym] basename = Base.typename(typeof(m)).name @@ -224,7 +237,13 @@ function _printinfo(io::IO, @nospecialize(model::CompositeModel)) for ps in param_syms param = destructed.parameter_map[ps] free = param isa FitParam ? !isfrozen(param) : true - _print_param(io, free, ps, param, param_offset, q1, q2, q3, q4) + val, binding = if !isnothing(bindings) && !isempty(bindings) + get(bindings, param_index, param => nothing) + else + param, nothing + end + _print_param(io, free, ps, val, param_offset, q1, q2, q3, q4; binding) + param_index += 1 end end end diff --git a/src/fitting/binding.jl b/src/fitting/binding.jl index d90b312b..2fbe0245 100644 --- a/src/fitting/binding.jl +++ b/src/fitting/binding.jl @@ -50,12 +50,6 @@ function _get_index_of_symbol(model::AbstractSpectralModel, symbol)::Int if isnothing(i) error("Could not match symbol $symbol !") end - # don't count frozen parameters if they are prior to the parameter of interest - for s = 1:i - if isfrozen(pnt[s]) - i -= 1 - end - end i end diff --git a/src/fitting/cache.jl b/src/fitting/cache.jl index 41c8c99d..d9765ffa 100644 --- a/src/fitting/cache.jl +++ b/src/fitting/cache.jl @@ -111,24 +111,73 @@ function _invoke_and_transform!(cache::MultiModelCache, domain, params) all_outputs end +""" + adjust_free_bindings(model::FittableMultiModel, bindings::Vector{Vector{Pair{Int,Int}}}) + +Returns a new parameter binding list with the parameter indices adjusted to +omit the frozen parameter. That is, if a model has three parameters `a`, `b` +(frozen) and `c`, and parameter `c` (index 3) is bound; then the new bindings +will instead give the index of `c` as 2. + +In that sense, the new bindings refer to the _free parameter_ vector, and not the full parameter vector. + +If the binding refers to a frozen parameter, it is removed from the binding list. +""" +function adjust_free_bindings(model::FittableMultiModel, bindings) + new_bindings = map(bindings) do binding + pairs = Vector{Pair{Int,Int}}() + for pair in binding + model_index, parameter_index = pair + + new_index = parameter_index + refers_to_frozen = false + + for (i, param) in enumerate(parameter_tuple(model.m[model_index])) + if i > parameter_index + break + end + + if isfrozen(param) + if i == parameter_index + refers_to_frozen = true + break + end + new_index -= 1 + end + end + + if !refers_to_frozen + push!(pairs, model_index => new_index) + end + end + pairs + end + + # remove those that have become redundant + filter(i -> length(i) > 1, new_bindings) +end + function _build_parameter_mapping(model::FittableMultiModel{M}, bindings) where {M} T = paramtype(M.parameters[1]) - all_parameters = Vector{T}() + all_free_parameters = Vector{T}() # use the tuple hack to enforce type stability and unroll the loop parameters = map((1:length(M.parameters)...,)) do i m = model.m[i] v::Vector{T} = collect(filter(isfree, parameter_tuple(m))) - append!(all_parameters, v) + append!(all_free_parameters, v) v end parameters_counts = _accumulated_indices(map(length, parameters)) - parameter_mapping, remove = _construct_bound_mapping(bindings, parameters_counts) + # need to adjust the bindings to remove the frozen indices + free_bindings = adjust_free_bindings(model, bindings) + + parameter_mapping, remove = _construct_bound_mapping(free_bindings, parameters_counts) # remove duplicate parameters that are bound - deleteat!(all_parameters, remove) + deleteat!(all_free_parameters, remove) - all_parameters, parameter_mapping + all_free_parameters, parameter_mapping end function _build_mapping_length(f, itt::Tuple) diff --git a/src/fitting/problem.jl b/src/fitting/problem.jl index 85bd8602..e91e2f7f 100644 --- a/src/fitting/problem.jl +++ b/src/fitting/problem.jl @@ -12,12 +12,45 @@ struct FittableMultiModel{M} FittableMultiModel(model::Vararg{<:AbstractSpectralModel}) = new{typeof(model)}(model) end +function translate_bindings( + model_index::Int, + m::FittableMultiModel, + bindings::Vector{Vector{Pair{Int,Int}}}, +) + # map the parameter index to a string to display + translation = Dict{Int,Pair{paramtype(m.m[1]),String}}() + for b in bindings + # we skip the first item in the list since that is the root of the binding + root = b[1] + + params = parameter_named_tuple(m.m[first(root)]) + symbol = propertynames(params)[last(root)] + value = params[last(root)] + + for pair in @views b[2:end] + # check if this binding applies to the current model + if first(pair) == model_index + translation[last(pair)] = value => "~Model $(first(root)) $(symbol)" + end + end + end + + translation +end + function Base.show(io::IO, ::MIME"text/plain", @nospecialize(model::FittableMultiModel)) buff = IOBuffer() println(buff, "Models:") - for m in model.m + for (i, m) in enumerate(model.m) buf = IOBuffer() - print(buf, "- ") + print( + buf, + "\n", + Crayons.Crayon(foreground = :yellow), + "Model $i", + Crayons.Crayon(reset = true), + ": ", + ) _printinfo(buf, m) print(buff, indent(String(take!(buf)), 2)) end @@ -135,11 +168,38 @@ function Base.show(io::IO, ::MIME"text/plain", @nospecialize(prob::FittingProble buff, " . ", Crayons.Crayon(foreground = :green), - "Free (DOF)", + "Free", Crayons.Crayon(reset = true), - " : $(free)", + " : $(free)", ) print(io, encapsulate(String(take!(buff)))) end -export FittingProblem, FittableMultiModel, FittableMultiDataset + +""" + details(prob::FittingProblem) + +Show details about the fitting problem, including the specific model parameters that are bound together. +""" +function details(prob::FittingProblem) + buff = IOBuffer() + println(buff, "Models:") + for (i, m) in enumerate(prob.model.m) + buf = IOBuffer() + print( + buf, + "\n", + Crayons.Crayon(foreground = :yellow), + "Model $i", + Crayons.Crayon(reset = true), + ": ", + ) + + _printinfo(buf, m; bindings = translate_bindings(i, prob.model, prob.bindings)) + print(buff, indent(String(take!(buf)), 2)) + end + + print(encapsulate(String(take!(buff)))) +end + +export FittingProblem, FittableMultiModel, FittableMultiDataset, details diff --git a/test/fitting/test-binding.jl b/test/fitting/test-binding.jl index ea05ad0c..9af44133 100644 --- a/test/fitting/test-binding.jl +++ b/test/fitting/test-binding.jl @@ -85,6 +85,11 @@ prob = FittingProblem(model1 => dummy_data1, model1 => dummy_data1, model1 => du model1.K_1.frozen = true bind!(prob, :a_1) bind!(prob, :a_2) + +# check that the free parameter binding adjustment works okay +new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings) +@test new_bindings == [[1 => 1, 2 => 1, 3 => 1], [1 => 3, 2 => 3, 3 => 3]] + _, mapping = SpectralFitting._build_parameter_mapping(prob.model, prob.bindings) @test mapping == ([1, 2, 3], [1, 4, 3], [1, 5, 3]) @@ -94,5 +99,27 @@ model1.K_1.frozen = false model1.a_2.frozen = true bind!(prob, :K_1) bind!(prob, :a_1) + +new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings) +@test new_bindings == [[1 => 1, 2 => 1, 3 => 1], [1 => 2, 2 => 2, 3 => 2]] + _, mapping = SpectralFitting._build_parameter_mapping(prob.model, prob.bindings) @test mapping == ([1, 2, 3], [1, 2, 4], [1, 2, 5]) + + +prob = FittingProblem(model1 => dummy_data1, model1 => dummy_data1, model1 => dummy_data1) + +# bind model 1's K_1 parameter to model 2's K_2 parameter +bind!(prob, 1 => :K_1, 2 => :K_2) + +new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings) +@test new_bindings == [[1 => 1, 2 => 3]] + +# bind model 1's a_1 parameter to model 2's a_2 to model 3's a_2 +# these are all frozen so should not appear in the adjust_free_bindings +bind!(prob, 1 => :a_1, 2 => :a_2, 3 => :a_2) + +new_bindings = SpectralFitting.adjust_free_bindings(prob.model, prob.bindings) +@test new_bindings == [[1 => 1, 2 => 3]] + +# TODO: free parameters should not be allowed to bind to frozen parameters diff --git a/test/io/test-printing.jl b/test/io/test-printing.jl index 1e1b1770..7eb9cf80 100644 --- a/test/io/test-printing.jl +++ b/test/io/test-printing.jl @@ -17,7 +17,7 @@ expected = """ ┌ DummyAdditive │ K -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m │ a -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b -> 5\e[36m FROZEN\e[0m +│ b -> 5 \e[36m FROZEN\e[0m └ """ @test string == expected @@ -32,14 +32,14 @@ expected = """┌ CompositeModel with 3 model components: │ \e[36m a1\e[0m => \e[36mDummyAdditive\e[0m │ K_1 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m │ a_1 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_1 -> 5\e[36m FROZEN\e[0m +│ b_1 -> 5 \e[36m FROZEN\e[0m │ \e[36m a2\e[0m => \e[36mDummyAdditive\e[0m │ K_2 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m │ a_2 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_2 -> 5\e[36m FROZEN\e[0m +│ b_2 -> 5 \e[36m FROZEN\e[0m │ \e[36m m1\e[0m => \e[36mDummyMultiplicative\e[0m │ a_3 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_3 -> 5\e[36m FROZEN\e[0m +│ b_3 -> 5 \e[36m FROZEN\e[0m └ """ @test string == expected @@ -55,20 +55,20 @@ expected = """┌ CompositeModel with 5 model components: │ \e[36m a1\e[0m => \e[36mDummyAdditive\e[0m │ K_1 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m │ a_1 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_1 -> 5\e[36m FROZEN\e[0m +│ b_1 -> 5 \e[36m FROZEN\e[0m │ \e[36m m1\e[0m => \e[36mDummyMultiplicative\e[0m │ a_2 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_2 -> 5\e[36m FROZEN\e[0m +│ b_2 -> 5 \e[36m FROZEN\e[0m │ \e[36m a2\e[0m => \e[36mDummyAdditive\e[0m │ K_2 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m │ a_3 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_3 -> 5\e[36m FROZEN\e[0m +│ b_3 -> 5 \e[36m FROZEN\e[0m │ \e[36m m2\e[0m => \e[36mDummyMultiplicative\e[0m │ a_4 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_4 -> 5\e[36m FROZEN\e[0m +│ b_4 -> 5 \e[36m FROZEN\e[0m │ \e[36m m3\e[0m => \e[36mDummyMultiplicative\e[0m │ a_5 -> 1 ± 0.1 ∈ [ 0, Inf ]\e[32m FREE\e[0m -│ b_5 -> 5\e[36m FROZEN\e[0m +│ b_5 -> 5 \e[36m FROZEN\e[0m └ """ @test string == expected