Skip to content

Commit

Permalink
Merge pull request #51 from firedrakeproject/dham/cofunction_is_terminal
Browse files Browse the repository at this point in the history
Make cofunctionals terminal, and test
  • Loading branch information
dham authored Jul 17, 2024
2 parents fbd288e + 5730858 commit c1a8afb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ufl.finiteelement import FiniteElement
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1
from ufl.exprcontainers import ExprList


def test_comparison_of_coefficients():
Expand Down Expand Up @@ -69,6 +70,10 @@ def test_comparison_of_cofunctions():
assert not v1 == u1
assert not v2 == u2

# Objects in ExprList as happens when taking derivatives.
assert ExprList(v1, v1) == ExprList(v1, v1b)
assert not ExprList(v1, v2) == ExprList(v1, v1)


def test_comparison_of_products():
V = FiniteElement("Lagrange", triangle, 1, (), identity_pullback, H1)
Expand Down
1 change: 1 addition & 0 deletions ufl/coefficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Cofunction(BaseCoefficient, BaseForm):
)
_primal = False
_dual = True
_ufl_is_terminal_ = True

__eq__ = BaseForm.__eq__

Expand Down
16 changes: 16 additions & 0 deletions ufl/formatting/ufl2unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,26 @@ def coefficient(self, o):
return f"{var}{subscript_number(i)}"
return self.coefficient_names[o.count()]

def cofunction(self, o):
"""Format a cofunction."""
if self.coefficient_names is None:
i = o.count()
var = "cofunction"
if len(o.ufl_shape) == 1:
var += UC.combining_right_arrow_above
elif len(o.ufl_shape) > 1 and self.colorama_bold:
var = f"{colorama.Style.BRIGHT}{var}{colorama.Style.RESET_ALL}"
return f"{var}{subscript_number(i)}"
return self.coefficient_names[o.count()]

def base_form_operator(self, o):
"""Format a base_form_operator."""
return "BaseFormOperator"

def action(self, o, a, b):
"""Format an Action."""
return f"Action({a}, {b})"

def constant(self, o):
"""Format a constant."""
i = o.count()
Expand Down

0 comments on commit c1a8afb

Please sign in to comment.