diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py
new file mode 100644
index 0000000000..429000b2a2
--- /dev/null
+++ b/firedrake/adjoint/all_at_once_reduced_functional.py
@@ -0,0 +1,592 @@
+from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \
+    stop_annotating, no_annotations, get_working_tape, set_working_tape
+from pyadjoint.enlisting import Enlist
+from functools import wraps, cached_property
+from typing import Callable, Optional
+
+__all__ = ['AllAtOnceReducedFunctional']
+
+
+def sc_passthrough(func):
+    """
+    Wraps standard ReducedFunctional methods to dispatch between the strong
+    and weak constraint implementations.
+
+    If using the strong constraint, makes sure strong_reduced_functional
+    is instantiated, then passes args/kwargs through to the
+    corresponding strong_reduced_functional method.
+
+    If using the weak constraint, uses the AllAtOnceReducedFunctional's
+    own method definition.
+    """
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if self.weak_constraint:
+            return func(self, *args, **kwargs)
+        else:
+            sc_func = getattr(self.strong_reduced_functional, func.__name__)
+            return sc_func(*args, **kwargs)
+    return wrapper
+
+
+def _rename(obj, name):
+    if hasattr(obj, "rename"):
+        obj.rename(name)
+
+
+def _ad_sub(left, right):
+    result = right._ad_copy()
+    result._ad_imul(-1)
+    result._ad_iadd(left)
+    return result
+
+
+class AllAtOnceReducedFunctional(ReducedFunctional):
+    """ReducedFunctional for 4DVar data assimilation.
+
+    Creates either the strong constraint or weak constraint system
+    incrementally by logging observations through the initial forward
+    model run.
+
+    Warning: the weak constraint implementation is still experimental.
+
+    Parameters
+    ----------
+
+    control
+        The initial condition :math:`x_{0}`. Starting value is used as the
+        background (prior) data :math:`x_{b}`.
+
+    background_iprod
+        The inner product to calculate the background error functional
+        from the background error :math:`x_{0} - x_{b}`. Can include the
+        error covariance matrix.
+
+    observation_err
+        Given a state :math:`x`, returns the observation error
+        :math:`y_{0} - \\mathcal{H}_{0}(x)` where :math:`y_{0}` are the
+        observations at the initial time and :math:`\\mathcal{H}_{0}` is
+        the observation operator for the initial time. Optional.
+
+    observation_iprod
+        The inner product to calculate the observation error functional
+        from the observation error :math:`y_{0} - \\mathcal{H}_{0}(x)`.
+        Can include the error covariance matrix. Must be provided if
+        observation_err is provided.
+
+    weak_constraint
+        Whether to use the weak or strong constraint 4DVar formulation.
+
+    tape
+        The tape to record on.
+
+    See Also
+    --------
+    :class:`pyadjoint.ReducedFunctional`.
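+
+    Notes
+    -----
+    In the weak constraint case, the recorded stages compose the objective
+
+    .. math::
+
+        J(x_{0}, \\ldots, x_{N}) =
+            \\|x_{0} - x_{b}\\|_{B}
+            + \\sum_{i} \\|y_{i} - \\mathcal{H}_{i}(x_{i})\\|_{R}
+            + \\sum_{i=1}^{N} \\|x_{i} - \\mathcal{M}_{i}(x_{i-1})\\|_{Q}
+
+    where each norm stands for the corresponding user-supplied inner
+    product (which may absorb the covariance weighting and any constant
+    factors), :math:`\\mathcal{M}_{i}` is the forward model over stage
+    :math:`i`, and the :math:`i = 0` observation term is present only if
+    observation_err is provided. In the strong constraint case the model
+    error terms are dropped and the :math:`x_{i}` are determined by
+    propagating :math:`x_{0}`.
+
+    A minimal usage sketch. ``B_iprod``, ``R_iprod``, ``Q_iprod``,
+    ``obs_error``, ``timestepper``, ``x0``, ``nstages`` and
+    ``stage_guesses`` are hypothetical stand-ins for user code::
+
+        Jhat = AllAtOnceReducedFunctional(
+            Control(x0), background_iprod=B_iprod)
+        x = x0.copy(deepcopy=True)
+        for _ in range(nstages):
+            timestepper.advance(x)
+            Jhat.set_observation(x, obs_error, R_iprod, Q_iprod)
+
+        # weak constraint: one control per observation stage
+        J = Jhat([x0, *stage_guesses])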
+ """ + + def __init__(self, control: Control, + background_iprod: Callable[[OverloadedType], AdjFloat], + observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None, + observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None, + weak_constraint: bool = True, + tape: Optional[Tape] = None, + _annotate_accumulation: bool = False): + + self.tape = get_working_tape() if tape is None else tape + + self.weak_constraint = weak_constraint + self.initial_observations = observation_err is not None + + # We need a copy for the prior, but this shouldn't be part of the tape + with stop_annotating(): + self.background = control.copy_data() + + if self.weak_constraint: + self._annotate_accumulation = _annotate_accumulation + + # new tape for background error vector + with set_working_tape() as tape: + # start from a control independent of any other tapes + with stop_annotating(): + control_copy = control.copy_data() + _rename(control_copy, "Control_0_bkg_copy") + + # vector of x_0 - x_b + bkg_err_vec = _ad_sub(control_copy, self.background) + _rename(bkg_err_vec, "bkg_err_vec") + + # RF to recover x_0 - x_b + self.background_error = ReducedFunctional( + bkg_err_vec, Control(control_copy), tape=tape) + + # new tape for background error reduction + with set_working_tape() as tape: + # start from a control independent of any other tapes + with stop_annotating(): + bkg_err_vec_copy = bkg_err_vec._ad_copy() + _rename(bkg_err_vec_copy, "bkg_err_vec_copy") + + # inner product |x_0 - x_b|_B + bkg_err = background_iprod(bkg_err_vec_copy) + + # RF to recover |x_0 - x_b|_B + self.background_rf = ReducedFunctional( + bkg_err, Control(bkg_err_vec_copy), tape=tape) + + self.controls = [control] # The solution at the beginning of each time-chunk + self.states = [] # The model propogation at the end of each time-chunk + self.forward_model_stages = [] # ReducedFunctional for each model propogation (returns state) + self.forward_model_errors = [] # Inner product for model errors (possibly including error covariance) + self.forward_model_rfs = [] # Inner product for model errors (possibly including error covariance) + self.observation_errors = [] # ReducedFunctional for each observation set (returns observation error) + self.observation_rfs = [] # Inner product for observation errors (possibly including error covariance) + + if self.initial_observations: + + # new tape for observation error vector + with set_working_tape() as tape: + # start from a control independent of any other tapes + with stop_annotating(): + control_copy = control.copy_data() + _rename(control_copy, "Control_0_obs_copy") + + # vector of H(x_0) - y_0 + obs_err_vec = observation_err(control_copy) + _rename(obs_err_vec, "obs_err_vec_0") + + # RF to recover H(x_0) - y_0 + self.observation_errors.append(ReducedFunctional( + obs_err_vec, Control(control_copy), tape=tape) + ) + + # new tape for observation error reduction + with set_working_tape() as tape: + # start from a control independent of any othe tapes + with stop_annotating(): + obs_err_vec_copy = obs_err_vec._ad_copy() + _rename(obs_err_vec_copy, "obs_err_vec_0_copy") + + # inner product |H(x_0) - y_0|_R + obs_err = observation_iprod(obs_err_vec_copy) + + # RF to recover |H(x_0) - y_0|_R + self.observation_rfs.append(ReducedFunctional( + obs_err, Control(obs_err_vec_copy), tape=tape) + ) + + # new tape for the next stage + set_working_tape() + self._stage_tape = get_working_tape() + + else: + self._annotate_accumulation = True + + # initial conditions guess to be 
+
+    def set_observation(self, state: OverloadedType,
+                        observation_err: Callable[[OverloadedType], OverloadedType],
+                        observation_iprod: Callable[[OverloadedType], AdjFloat],
+                        forward_model_iprod: Optional[Callable[[OverloadedType], AdjFloat]]):
+        """
+        Record an observation at the time of `state`.
+
+        Parameters
+        ----------
+
+        state
+            The state at the current observation time.
+
+        observation_err
+            Given a state :math:`x`, returns the observation error
+            :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are
+            the observations at the current observation time and
+            :math:`\\mathcal{H}_{i}` is the observation operator for the
+            current observation time.
+
+        observation_iprod
+            The inner product to calculate the observation error functional
+            from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`.
+            Can include the error covariance matrix.
+
+        forward_model_iprod
+            The inner product to calculate the model error functional from
+            the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can
+            include the error covariance matrix. Ignored if using the strong
+            constraint formulation.
+        """
+        if self.weak_constraint:
+
+            stage_index = len(self.controls)
+
+            # Cut the tape into separate time-chunks.
+            # State is the output from the previous control, i.e. the forward
+            # model propagation over the previous time-chunk.
+
+            # get the tape used for this stage and make sure it's the right one
+            prev_stage_tape = get_working_tape()
+            if prev_stage_tape is not self._stage_tape:
+                raise ValueError(
+                    "Working tape at the end of the observation stage"
+                    " differs from the tape at the stage beginning."
+                )
+
+            # record the forward propagation
+            with set_working_tape(prev_stage_tape.copy()) as tape:
+                prev_control = self.controls[-1]
+                self.forward_model_stages.append(ReducedFunctional(
+                    state._ad_copy(), controls=prev_control, tape=tape)
+                )
+
+            # The beginning of the next time-chunk is the control for this
+            # observation and the initial state for the propagation over the
+            # next time-chunk.
+            with stop_annotating():
+                # smuggle the initial guess at this time into the control without the tape seeing
+                next_control_state = state._ad_copy()
+                _rename(next_control_state, f"Control_{len(self.controls)}")
+                next_control = Control(next_control_state)
+                self.controls.append(next_control)
+
+            # the model error links time-chunks by depending on both the
+            # previous and current controls
+
+            # new tape for model error vector
+            with set_working_tape() as tape:
+                # start from controls independent of any other tapes
+                with stop_annotating():
+                    state_copy = state._ad_copy()
+                    _rename(state_copy, f"state_{stage_index}_copy")
+                    next_control_copy = next_control.copy_data()
+                    _rename(next_control_copy, f"Control_{stage_index}_model_copy")
+
+                # vector of M_i - x_i
+                model_err_vec = _ad_sub(state_copy, next_control_copy)
+                _rename(model_err_vec, f"model_err_vec_{stage_index}")
+
+                # RF to recover M_i - x_i
+                fmcontrols = [Control(state_copy), Control(next_control_copy)]
+                self.forward_model_errors.append(ReducedFunctional(
+                    model_err_vec, fmcontrols, tape=tape)
+                )
+
+            # new tape for model error reduction
+            with set_working_tape() as tape:
+                # start from a control independent of any other tapes
+                with stop_annotating():
+                    model_err_vec_copy = model_err_vec._ad_copy()
+                    _rename(model_err_vec_copy, f"model_err_vec_{stage_index}_copy")
+
+                # inner product |M_i - x_i|_Q
+                model_err = forward_model_iprod(model_err_vec_copy)
+
+                # RF to recover |M_i - x_i|_Q
+                self.forward_model_rfs.append(ReducedFunctional(
+                    model_err, Control(model_err_vec_copy), tape=tape)
+                )
+
+            # Observations come after the tape cut because the state here is
+            # now a control, not a state.
+
+            # new tape for observation error vector
+            with set_working_tape() as tape:
+                # start from a control independent of any other tapes
+                with stop_annotating():
+                    next_control_copy = next_control.copy_data()
+                    _rename(next_control_copy, f"Control_{stage_index}_obs_copy")
+
+                # vector of H(x_i) - y_i
+                obs_err_vec = observation_err(next_control_copy)
+                _rename(obs_err_vec, f"obs_err_vec_{stage_index}")
+
+                # RF to recover H(x_i) - y_i
+                self.observation_errors.append(ReducedFunctional(
+                    obs_err_vec, Control(next_control_copy), tape=tape)
+                )
+
+            # new tape for observation error reduction
+            with set_working_tape() as tape:
+                # start from a control independent of any other tapes
+                with stop_annotating():
+                    obs_err_vec_copy = obs_err_vec._ad_copy()
+                    _rename(obs_err_vec_copy, f"obs_err_vec_{stage_index}_copy")
+
+                # inner product |H(x_i) - y_i|_R
+                obs_err = observation_iprod(obs_err_vec_copy)
+
+                # RF to recover |H(x_i) - y_i|_R
+                self.observation_rfs.append(ReducedFunctional(
+                    obs_err, Control(obs_err_vec_copy), tape=tape)
+                )
+
+            # new tape for the next stage
+            set_working_tape()
+            self._stage_tape = get_working_tape()
+
+            # Look, we're starting this time-chunk from an "unrelated" value... really!
+            state.assign(next_control.control)
+
+        else:
+
+            if hasattr(self, "_strong_reduced_functional"):
+                msg = "Cannot add observations once the strong constraint ReducedFunctional is instantiated"
+                raise ValueError(msg)
+            self._accumulate_functional(
+                observation_iprod(observation_err(state)))
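+
+    # After each observation stage, the weak-constraint lists hold one
+    # ReducedFunctional per term of the objective (schematic, with M_i
+    # short for M_i(x_{i-1}), and list indices shown for the case with
+    # initial observations):
+    #
+    #     forward_model_stages[i-1]: x_{i-1}       -> M_i
+    #     forward_model_errors[i-1]: (M_i, x_i)    -> M_i - x_i
+    #     forward_model_rfs[i-1]:    M_i - x_i     -> |M_i - x_i|_Q
+    #     observation_errors[i]:     x_i           -> H(x_i) - y_i
+    #     observation_rfs[i]:        H(x_i) - y_i  -> |H(x_i) - y_i|_R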
+
+    @cached_property
+    def strong_reduced_functional(self):
+        """A :class:`pyadjoint.ReducedFunctional` for the strong constraint 4DVar system.
+
+        Only instantiated if using the strong constraint formulation, and cannot be used
+        before all observations are recorded.
+        """
+        if self.weak_constraint:
+            msg = "Strong constraint ReducedFunctional not instantiated for weak constraint 4DVar"
+            raise AttributeError(msg)
+        self._strong_reduced_functional = ReducedFunctional(
+            self._total_functional, self.controls, tape=self.tape)
+        return self._strong_reduced_functional
+
+    def __getattr__(self, attr):
+        """
+        If using the strong constraint, delegate attribute lookup to
+        self.strong_reduced_functional.
+        """
+        if self.weak_constraint:
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
+        else:
+            return getattr(self.strong_reduced_functional, attr)
+
+    @sc_passthrough
+    @no_annotations
+    def __call__(self, values: OverloadedType):
+        """Computes the reduced functional with the supplied control values.
+
+        Parameters
+        ----------
+
+        values
+            If you have multiple controls this should be a list of new values
+            for each control in the order you listed the controls to the constructor.
+            If you have a single control it can either be a list or a single object.
+            Each new value should have the same type as the corresponding control.
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The computed value. Typically an instance of :class:`pyadjoint.AdjFloat`.
+
+        """
+        # update the outer controls; the sub-ReducedFunctionals take care
+        # of updating their own controls as they are evaluated
+        for c, v in zip(self.controls, values):
+            c.control.assign(v)
+
+        # Shift lists so indexing matches standard nomenclature:
+        # index 0 is the initial condition.
+        # Model i propagates from i-1 to i.
+        # Observation i is at i.
+        model_stages = [None, *self.forward_model_stages]
+        model_errors = [None, *self.forward_model_errors]
+        model_rfs = [None, *self.forward_model_rfs]
+
+        observation_errors = (self.observation_errors if self.initial_observations
+                              else [None, *self.observation_errors])
+
+        observation_rfs = (self.observation_rfs if self.initial_observations
+                           else [None, *self.observation_rfs])
+
+        # Initial condition functionals
+        bkg_err_vec = self.background_error(values[0])
+        J = self.background_rf(bkg_err_vec)
+
+        # observations at time 0
+        if self.initial_observations:
+            obs_err_vec = observation_errors[0](values[0])
+            J += observation_rfs[0](obs_err_vec)
+
+        for i in range(1, len(observation_rfs)):
+            prev_control = values[i-1]
+            this_control = values[i]
+
+            # observation error - do we match the 'real world'?
+            obs_err_vec = observation_errors[i](this_control)
+            J += observation_rfs[i](obs_err_vec)
+
+            # Model error - does the propagation from the last control match this control?
+            Mi = model_stages[i](prev_control)
+            model_err_vec = model_errors[i]([Mi, this_control])
+            J += model_rfs[i](model_err_vec)
+
+        return J
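+
+    # derivative() below propagates the adjoint through each chain of
+    # stage functionals. Schematically, for the stage-i terms (an
+    # illustrative sketch of the chain rule, not executable code):
+    #
+    #     dJ/dx_i     += observation_errors[i].T @ observation_rfs[i]'
+    #     dJ/dx_i     += (model_errors[i].T wrt x_i) @ model_rfs[i]'
+    #     dJ/dx_{i-1} += model_stages[i].T @ (model_errors[i].T wrt M_i) @ model_rfs[i]'
+    #
+    # The intermediate adjoint values are kept as cofunctions
+    # (riesz_representation=None) so that each one can be passed as the
+    # adj_input of the next ReducedFunctional in the chain.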
+ """ + if self.weak_constraint: + msg = "Strong constraint ReducedFunctional not instantiated for weak constraint 4DVar" + raise AttributeError(msg) + self._strong_reduced_functional = ReducedFunctional( + self._total_functional, self.controls, tape=self.tape) + return self._strong_reduced_functional + + def __getattr__(self, attr): + """ + If using strong constraint then grab attributes from self.strong_reduced_functional. + """ + if self.weak_constraint: + raise AttributeError(f"'{type(self)}' object has no attribute '{attr}'") + else: + return getattr(self.strong_reduced_functional, attr) + + @sc_passthrough + @no_annotations + def __call__(self, values: OverloadedType): + """Computes the reduced functional with supplied control value. + + Parameters + ---------- + + values + If you have multiple controls this should be a list of new values + for each control in the order you listed the controls to the constructor. + If you have a single control it can either be a list or a single object. + Each new value should have the same type as the corresponding control. + + Returns + ------- + pyadjoint.OverloadedType + The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. + + """ + # controls are updated by the sub ReducedFunctionals + # so we don't need to do it ourselves + + # Shift lists so indexing matches standard nomenclature: + # index 0 is initial condition. + # Model i propogates from i-1 to i. + # Observation i is at i. + + for c, v in zip(self.controls, values): + c.control.assign(v) + + model_stages = [None, *self.forward_model_stages] + model_errors = [None, *self.forward_model_errors] + model_rfs = [None, *self.forward_model_rfs] + + observation_errors = (self.observation_errors if self.initial_observations + else [None, *self.observation_errors]) + + observation_rfs = (self.observation_rfs if self.initial_observations + else [None, *self.observation_rfs]) + + # Initial condition functionals + bkg_err_vec = self.background_error(values[0]) + J = self.background_rf(bkg_err_vec) + + # observations at time 0 + if self.initial_observations: + obs_err_vec = observation_errors[0](values[0]) + J += observation_rfs[0](obs_err_vec) + + for i in range(1, len(observation_rfs)): + prev_control = values[i-1] + this_control = values[i] + + # observation error - do we match the 'real world'? + obs_err_vec = observation_errors[i](this_control) + J += observation_rfs[i](obs_err_vec) + + # Model error - does propogation from last control match this control? + Mi = model_stages[i](prev_control) + model_err_vec = model_errors[i]([Mi, this_control]) + J += model_rfs[i](model_err_vec) + + return J + + @sc_passthrough + @no_annotations + def derivative(self, adj_input: float = 1.0, options: dict = {}): + """Returns the derivative of the functional w.r.t. the control. + Using the adjoint method, the derivative of the functional with + respect to the control, around the last supplied value of the + control, is computed and returned. + + Parameters + ---------- + adj_input + The adjoint input. + + options + Additional options for the derivative computation. + + Returns + ------- + pyadjoint.OverloadedType + The derivative with respect to the control. + Should be an instance of the same type as the control. 
+ """ + # create a list of overloaded types to put derivative into + derivatives = [] + + # chaining ReducedFunctionals means we need to pass Cofunctions not Functions + intermediate_options = { + 'riesz_representation': None, + **{k: v for k, v in options.items() + if (k != 'riesz_representation')} + } + + # Shift lists so indexing matches standard nomenclature: + # index 0 is initial condition. Model i propogates from i-1 to i. + model_stages = [None, *self.forward_model_stages] + model_errors = [None, *self.forward_model_errors] + model_rfs = [None, *self.forward_model_rfs] + + observation_errors = (self.observation_errors if self.initial_observations + else [None, *self.observation_errors]) + + observation_rfs = (self.observation_rfs if self.initial_observations + else [None, *self.observation_rfs]) + + # initial condition derivatives + bkg_deriv = self.background_rf.derivative(adj_input=adj_input, + options=intermediate_options) + derivatives.append(self.background_error.derivative(adj_input=bkg_deriv, + options=options)) + + # observations at time 0 + if self.initial_observations: + obs_deriv = observation_rfs[0].derivative(adj_input=adj_input, + options=intermediate_options) + derivatives[0] += observation_errors[0].derivative(adj_input=obs_deriv, + options=options) + + for i in range(1, len(observation_rfs)): + obs_deriv = observation_rfs[i].derivative(adj_input=adj_input, + options=intermediate_options) + derivatives.append(observation_errors[i].derivative(adj_input=obs_deriv, + options=options)) + + # derivative of model error will split: + # wrt x_i through error vector + # wrt x_i-1 through stage propogation + model_deriv = model_rfs[i].derivative(adj_input=adj_input, + options=intermediate_options) + model_err_derivs = model_errors[i].derivative(adj_input=model_deriv, + options=intermediate_options) + model_stage_deriv = model_stages[i].derivative(adj_input=model_err_derivs[0], + options=options) + + derivatives[i-1] += model_stage_deriv + derivatives[i] += model_err_derivs[1].riesz_representation() + + return derivatives + + @sc_passthrough + @no_annotations + def hessian(self, m_dot: OverloadedType, options: dict = {}): + """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. + + Using the second-order adjoint method, the action of the Hessian of the + functional with respect to the control, around the last supplied value + of the control, is computed and returned. + + Parameters + ---------- + + m_dot + The direction in which to compute the action of the Hessian. + + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian in the direction m_dot. + Should be an instance of the same type as the control. + """ + # create a list of overloaded types to put hessian into + hessians = [] + + kwargs = {'options': options} + + # Shift lists so indexing matches standard nomenclature: + # index 0 is initial condition. Model i propogates from i-1 to i. 
+
+    @no_annotations
+    def hessian_matrix(self):
+        # Other reduced functionals don't have this.
+        if not self.weak_constraint:
+            raise AttributeError("Strong constraint 4DVar does not form a Hessian matrix")
+        raise NotImplementedError
+
+    @sc_passthrough
+    @no_annotations
+    def optimize_tape(self):
+        for rf in (self.background_error,
+                   self.background_rf,
+                   *self.observation_errors,
+                   *self.observation_rfs,
+                   *self.forward_model_stages,
+                   *self.forward_model_errors,
+                   *self.forward_model_rfs):
+            rf.optimize_tape()
+
+    def _accumulate_functional(self, val):
+        if not self._annotate_accumulation:
+            return
+        if hasattr(self, '_total_functional'):
+            self._total_functional += val
+        else:
+            self._total_functional = val