A reduced functional that updates a control over space and time in one nonlinear solve. #3592

Draft
wants to merge 35 commits into base: master
Changes from 9 commits
35 commits
f1ca5d4
initial sketch functional
dham May 22, 2024
52bf904
aaorf wip
JHopeCollins Jul 24, 2024
fa7e75d
Merge branch 'master' into allatoncereducedfunctional
JHopeCollins Aug 15, 2024
7f0e42e
initial impl for strong constraint 4DVar reduced functional
JHopeCollins Aug 16, 2024
7315a4e
aaorf: separate error calculations and inner products for observation…
JHopeCollins Aug 17, 2024
9dfd674
aaorf: decorator for strong constraint method implementations
JHopeCollins Aug 17, 2024
80fcbe4
aaorf: make strong_reduced_functional a cached_property
JHopeCollins Aug 20, 2024
bc8a868
aaorf: optimize_tape method
JHopeCollins Aug 20, 2024
9a388e8
aaorf docstring
JHopeCollins Aug 20, 2024
29c5f8c
Update firedrake/adjoint/all_at_once_reduced_functional.py
JHopeCollins Aug 20, 2024
8eadf31
aaorf docstrings
JHopeCollins Aug 20, 2024
551ad2c
aaorf - weak constraint __call__ impl
JHopeCollins Aug 21, 2024
9345a82
aaorf derivative initial impl
JHopeCollins Aug 23, 2024
78aff18
Merge branch 'master' into allatoncereducedfunctional
JHopeCollins Aug 23, 2024
bbfb40d
fixed mark controls in aaorf
JHopeCollins Sep 19, 2024
30d8389
aaorf - use inbuilt cached_property
JHopeCollins Sep 19, 2024
f00ab63
Merge branch 'master' into allatoncereducedfunctional
JHopeCollins Sep 30, 2024
23f9d11
aaorf - use _accumulate_functional for strong constraint, weak constr…
JHopeCollins Oct 10, 2024
eee3cb2
aaorf - hessian by post-processing tapes of all sub-ReducedFunctionals
JHopeCollins Oct 10, 2024
b85fcc6
aaorf - forward __getattr__ to strong constraint ReducedFunctional
JHopeCollins Oct 10, 2024
d6ccf94
aaorf - each sub-ReducedFunctional has its own tape
JHopeCollins Oct 10, 2024
43da7cf
aaorf - separate tapes for error vectors and error reductions
JHopeCollins Oct 10, 2024
b7cee1d
Merge branch 'allatoncereducedfunctional' of https://github.com/fired…
JHopeCollins Oct 10, 2024
c10cd1a
Merge branch 'master' into allatoncereducedfunctional
JHopeCollins Oct 10, 2024
3f42d03
aaorf - type hinting in signature not docstring
JHopeCollins Oct 11, 2024
393ebbe
aaorf - split observation error tape - vector and reduction
JHopeCollins Oct 22, 2024
086877a
aaorf - split tapes: background to error vector and error reduction; …
JHopeCollins Oct 22, 2024
b7fecb2
aaorf - tidy up after splitting tapes: remove old code; pass riesz op…
JHopeCollins Oct 22, 2024
e48aba7
aaorf - docstrings
JHopeCollins Oct 22, 2024
b1020d2
Update firedrake/adjoint/all_at_once_reduced_functional.py
JHopeCollins Oct 24, 2024
afa377c
Merge branch 'master' into allatoncereducedfunctional
JHopeCollins Nov 4, 2024
948c818
aaorf - swap firedrake specific functions for pyadjoint ones
JHopeCollins Nov 4, 2024
89a10dc
Merge branch 'master' into allatoncereducedfunctional
JHopeCollins Nov 20, 2024
119de1a
Merge branch 'master' into allatoncereducedfunctional
JHopeCollins Nov 29, 2024
886f831
Merge remote-tracking branch 'origin' into allatoncereducedfunctional
JHopeCollins Nov 29, 2024
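
For reference, the cost function assembled by the new reduced functional, in the notation used by the docstrings in the diff below (here $B$, $R_i$ and $Q_i$ stand for the background, observation and model error covariances, which can be folded into the user-supplied inner products):

$$
J(x_0, \ldots, x_N) = \|x_0 - x_b\|^2_{B^{-1}}
    + \sum_{i} \|y_i - \mathcal{H}_i(x_i)\|^2_{R_i^{-1}}
    + \sum_{i=1}^{N} \|x_i - \mathcal{M}_i(x_{i-1})\|^2_{Q_i^{-1}}
$$

In the weak constraint formulation each $x_i$ is an independent control; the strong constraint formulation enforces $x_i = \mathcal{M}_i(x_{i-1})$ exactly, so only the initial condition $x_0$ remains as a control and the model error term drops out. The $i = 0$ observation term is only included if `observation_err` is passed to the constructor.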
244 changes: 244 additions & 0 deletions firedrake/adjoint/all_at_once_reduced_functional.py
@@ -0,0 +1,244 @@
from pyadjoint import ReducedFunctional, stop_annotating, Control, \
    overloaded_type
from pyadjoint.enlisting import Enlist
from pyop2.utils import cached_property
from firedrake import assemble, inner, dx
from functools import wraps

__all__ = ['AllAtOnceReducedFunctional']


def sc_passthrough(func):
    """
    Wraps standard ReducedFunctional methods to dispatch to either the
    strong or weak constraint implementation.

    If using the strong constraint, makes sure strong_reduced_functional
    is instantiated, then passes args/kwargs through to the
    corresponding strong_reduced_functional method.

    If using the weak constraint, returns the AllAtOnceReducedFunctional
    method definition.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.weak_constraint:
            return func(self, *args, **kwargs)
        else:
            sc_func = getattr(self.strong_reduced_functional, func.__name__)
            return sc_func(*args, **kwargs)
    return wrapper


def l2prod(x):
    """Squared L2 norm of x: the default error inner product."""
    return assemble(inner(x, x)*dx)


class AllAtOnceReducedFunctional(ReducedFunctional):
    """ReducedFunctional for 4DVar data assimilation.

    Creates either the strong constraint or weak constraint system
    incrementally by logging observations through the initial forward model
    run.

    Warning: Weak constraint 4DVar not implemented yet.

    Parameters
    ----------
    control : pyadjoint.Control
        The initial condition :math:`x_{0}`. Its starting value is used as the
        background (prior) data :math:`x_{b}`.
    background_iprod : Callable[pyadjoint.OverloadedType], optional
        The inner product used to calculate the background error functional
        from the background error :math:`x_{0} - x_{b}`. Can include the error
        covariance matrix. Defaults to L2.
    observation_err : Callable[pyadjoint.OverloadedType], optional
        Given a state :math:`x`, returns the observation error
        :math:`y_{0} - \\mathcal{H}_{0}(x)`, where :math:`y_{0}` are the
        observations at the initial time and :math:`\\mathcal{H}_{0}` is the
        observation operator at the initial time.
    observation_iprod : Callable[pyadjoint.OverloadedType], optional
        The inner product used to calculate the observation error functional
        from the observation error :math:`y_{0} - \\mathcal{H}_{0}(x)`. Can
        include the error covariance matrix. Ignored if observation_err is
        not provided. Defaults to L2.
    weak_constraint : bool
        Whether to use the weak or strong constraint 4DVar formulation.

    See Also
    --------
    :class:`pyadjoint.ReducedFunctional`.
    """

    def __init__(self, control: Control, background_iprod=None,
                 observation_err=None, observation_iprod=None,
                 weak_constraint=True):
        self.weak_constraint = weak_constraint
        self.initial_observations = observation_err is not None

        # default to l2 inner products for all functionals
        background_iprod = background_iprod or l2prod
        observation_iprod = observation_iprod or l2prod

        # We need a copy for the prior, but this shouldn't be part of the tape
        with stop_annotating():
            self.background = control.copy_data()

        if self.weak_constraint:
            self.background_iprod = background_iprod  # Inner product for background error (possibly including error covariance)
            self.controls = [control]  # The solution at the beginning of each time-chunk
            self.states = []  # The model propagation at the end of each time-chunk
            self.forward_model_stages = []  # ReducedFunctional for each model propagation (returns state)
            self.forward_model_iprods = []  # Inner product for model errors (possibly including error covariance)
            self.observations = []  # ReducedFunctional for each observation set (returns observation error)
            self.observation_iprods = []  # Inner product for observation errors (possibly including error covariance)

            if self.initial_observations:
                self.observations.append(
                    ReducedFunctional(observation_err(control.control), control))
                self.observation_iprods.append(observation_iprod)

        else:
            # initial conditions guess to be updated
            self.controls = Enlist(control)

            # Strong constraint functional to be converted to ReducedFunctional later

            # penalty for straying from prior
            self.functional = background_iprod(control.control - self.background)

            # penalty for not hitting observations at initial time
            if self.initial_observations:
                self.functional += observation_iprod(observation_err(control.control))

    def set_observation(self, state: overloaded_type, observation_err,
                        observation_iprod=None, forward_model_iprod=None):
        """
        Record an observation at the time of `state`.

        Parameters
        ----------
        state : pyadjoint.OverloadedType
            The state at the current observation time.
        observation_err : Callable[pyadjoint.OverloadedType]
            Given a state :math:`x`, returns the observation error
            :math:`y_{i} - \\mathcal{H}_{i}(x)`, where :math:`y_{i}` are the observations
            at the current observation time and :math:`\\mathcal{H}_{i}` is the
            observation operator for the current observation time.
        observation_iprod : Callable[pyadjoint.OverloadedType], optional
            The inner product used to calculate the observation error functional from the
            observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`. Can include the error
            covariance matrix.
            Defaults to L2.
        forward_model_iprod : Callable[pyadjoint.OverloadedType], optional
            The inner product used to calculate the model error functional from the
            model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`, where :math:`x_{i}`
            is the state at the current observation time, :math:`x_{i-1}` is the state
            at the previous observation time, and :math:`\\mathcal{M}_{i}` is the forward
            model from the previous observation time. Can include the error covariance
            matrix. Ignored if using the strong constraint formulation.
            Defaults to L2.
        """
        observation_iprod = observation_iprod or l2prod
        if self.weak_constraint:

            forward_model_iprod = forward_model_iprod or l2prod

            self.states.append(state.block_variable)
            self.forward_model_iprods.append(forward_model_iprod)

            # Cut the tape into separate time-chunks.
            # State is output from previous control i.e. forward model
            # propagation over previous time-chunk.
            with stop_annotating(modifies=state):
                self.forward_model_stages.append(
                    ReducedFunctional(state,
                                      controls=self.controls[-1])
                )
                # Beginning of next time-chunk is the control for this observation
                # and the state at the end of the next time-chunk.
                next_control = Control(state)
                self.controls.append(next_control)

            # Observations after tape cut because this is now a control, not a state
            self.observations.append(
                ReducedFunctional(observation_err(state), next_control)
            )
            self.observation_iprods.append(observation_iprod)

        else:

            if hasattr(self, "_strong_reduced_functional"):
                msg = "Cannot add observations once strong constraint ReducedFunctional instantiated"
                raise ValueError(msg)
            self.functional += observation_iprod(observation_err(state))

    @cached_property
    def strong_reduced_functional(self):
        """A :class:`pyadjoint.ReducedFunctional` for the strong constraint 4DVar system."""
        if self.weak_constraint:
            msg = "Strong constraint ReducedFunctional not instantiated for weak constraint 4DVar"
            raise AttributeError(msg)
        self._strong_reduced_functional = ReducedFunctional(
            self.functional, self.controls)
        return self._strong_reduced_functional

    @sc_passthrough
    def __call__(self, control_value):
        # update controls so derivative etc is evaluated at the correct point
        for old, new in zip(self.controls, control_value):
            old.update(new)

        controls = self.controls

        # Shift lists so indexing matches standard nomenclature:
        # index 0 is the initial condition. Model i propagates from i-1 to i.

        forward_models = [None, *self.forward_model_stages]
        model_iprods = [None, *self.forward_model_iprods]

        observations = (self.observations if self.initial_observations
                        else [None, *self.observations])
        observation_iprods = (self.observation_iprods if self.initial_observations
                              else [None, *self.observation_iprods])

        # Initial condition functionals
        J = self.background_iprod(controls[0].control - self.background)

        if self.initial_observations:
            J += observation_iprods[0](observations[0](controls[0]))

        for i in range(1, len(forward_models)):
            # Propagate forward over the previous time-chunk
            end_state = forward_models[i](controls[i-1])

            # Cache the end state here so we can reuse it in other functions
            self.states[i-1] = end_state.block_variable

            # Model error - does propagation from the last control match this control?
            model_err = end_state - controls[i].control
            J += model_iprods[i](model_err)

            # Observation error - do we match the 'real world'?
            obs_err = observations[i](controls[i])
            J += observation_iprods[i](obs_err)

        return J

    @sc_passthrough
    def derivative(self, *args, **kwargs):
        # All the magic goes here.
        raise NotImplementedError

    @sc_passthrough
    def hessian(self, *args, **kwargs):
        raise NotImplementedError

    def hessian_matrix(self):
        # Other reduced functionals don't have this.
        if not self.weak_constraint:
            raise AttributeError("Strong constraint 4DVar does not form a Hessian matrix")
        raise NotImplementedError

    @sc_passthrough
    def optimize_tape(self):
        raise NotImplementedError
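
A minimal usage sketch (not part of the diff) of the strong constraint path, which is the formulation already wired through to a standard `ReducedFunctional` in this commit range. The toy pointwise-scaling "model", the step count and the synthetic observations are illustrative assumptions only; the import path mirrors the new module added above.

```python
# Hypothetical strong-constraint 4DVar usage sketch; the scaling "model",
# nsteps and the synthetic observations are illustration-only assumptions.
from firedrake import UnitIntervalMesh, FunctionSpace, Function
from firedrake.adjoint import continue_annotation, pause_annotation
from firedrake.adjoint.all_at_once_reduced_functional import (
    AllAtOnceReducedFunctional)
from pyadjoint import Control

mesh = UnitIntervalMesh(8)
V = FunctionSpace(mesh, "CG", 1)

# Synthetic "truth" trajectory providing observations at each step.
nsteps, alpha = 3, 0.9
u_true = Function(V).assign(1.0)
observations = []
for _ in range(nsteps):
    u_true.assign(alpha*u_true)            # toy forward "model"
    observations.append(u_true.copy(deepcopy=True))

continue_annotation()

# Background guess for the initial condition; its starting value is
# taken as the prior x_b by the reduced functional.
u0 = Function(V).assign(0.5)
Jhat = AllAtOnceReducedFunctional(Control(u0), weak_constraint=False)

# Run the forward model, recording an observation term at every step.
u = Function(V).assign(u0)
for i in range(nsteps):
    u.assign(alpha*u)
    Jhat.set_observation(u, lambda x, i=i: x - observations[i])

pause_annotation()

# The strong constraint path defers to an ordinary ReducedFunctional,
# so evaluation and differentiation follow the usual pyadjoint pattern.
J = Jhat(Function(V).assign(0.7))   # evaluate at a new initial condition
dJ = Jhat.derivative()              # gradient w.r.t. the initial condition
```

From here the functional could, in principle, be handed to pyadjoint's optimisation routines; the weak constraint derivative is still marked `NotImplementedError` at this point in the branch.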