Skip to content

Commit

Permalink
Implement Collect
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatera committed Apr 26, 2021
1 parent 7196a95 commit 3a78500
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ New builtins
* ``Series``, ``O`` and ``SeriesData``
* ``StringReverse``
* Add all of the named colors, e.g. ``Brown`` or ``LighterMagenta``.

* ``Collect``


Enhancements
Expand Down
205 changes: 202 additions & 3 deletions mathics/builtin/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
Atom,
Expression,
Integer,
Integer0,
Integer1,
RationalOneHalf,
Number,
Symbol,
SymbolFalse,
SymbolNull,
SymbolTrue,
)
from mathics.core.convert import from_sympy, sympy_symbol_prefix
from mathics.core.rules import Pattern

import sympy

Expand Down Expand Up @@ -62,7 +65,6 @@ def _expand(expr):

if kwargs["modulus"] is not None and kwargs["modulus"] <= 0:
return Integer(0)

This comment has been minimized.

Copy link
@rocky

rocky Apr 26, 2021

Member

Integer0 now. (There were a number of places in my recent pull requests where I knew I could have done this, but due to the order of merges Integer0 wasn't around.

This comment has been minimized.

Copy link
@rocky

rocky Apr 26, 2021

Member

Also I should say that when I tried to do this across the board a while back in adding Integer1 there was a problem, so it might not be able to be done always, without some sort of extra care.


# A special case for trigonometric functions
if "trig" in kwargs and kwargs["trig"]:
if expr.has_form("Sin", 1):
Expand Down Expand Up @@ -149,7 +151,6 @@ def unconvert_subexprs(expr):
)

sympy_expr = convert_sympy(expr)

if deep:
# thread over everything
for (i, sub_expr,) in enumerate(sub_exprs):
Expand Down Expand Up @@ -192,7 +193,6 @@ def unconvert_subexprs(expr):
sympy_expr = sympy_expr.expand(**hints)
result = from_sympy(sympy_expr)
result = unconvert_subexprs(result)

return result


Expand Down Expand Up @@ -1413,3 +1413,202 @@ def apply(self, expr, form, h, evaluation):
return Expression(
"List", *[Expression(h, *[i for i in s]) for s in exponents]
)


class Collect(Builtin):
"""
<dl>
<dt>'Collect[$expr$, $x$]'
<dd> Expands $expr$ and collect together terms having the same power of $x$.
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}]'
<dd> Expands $expr$ and collect together terms having the same powers of
$x_1$, $x_2$, ....
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]'
<dd> After collect the terms, applies $filter$ to each coefficient.
</dl>
>> Collect[(x+y)^3, y]
= x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3
>> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y]
= 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z]
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y]
= 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}]
= 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h]
= x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1]
"""

rules = {
"Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]",
}

def apply_var_filter(self, expr, varlst, filt, evaluation):
"""Collect[expr_, varlst_, filt_]"""
from mathics.builtin.patterns import match

if varlst.is_symbol():
var_exprs = [varlst]
elif varlst.has_form("List", None):
var_exprs = varlst.get_leaves()
else:
var_exprs = [varlst]

if len(var_exprs) > 1:
target_pat = Pattern.create(Expression("Alternatives", *var_exprs))
var_pats = [Pattern.create(var) for var in var_exprs]
else:
target_pat = Pattern.create(varlst)
var_pats = [target_pat]

expr = expand(
expr,
numer=True,
denom=False,
deep=False,
trig=False,
modulus=None,
target_pat=target_pat,
)
if filt == Symbol("Identity"):
filt = None

def key_powers(lst):
key = Expression("Plus", *lst)
key = key.evaluate(evaluation)
if key.is_numeric():
return key.to_python()
return 0

def powers_list(pf):
powers = [Integer0 for i, p in enumerate(var_pats)]
if pf is None:
return powers
if pf.is_symbol():
for i, pat in enumerate(var_pats):
if match(pf, pat, evaluation):
powers[i] = Integer(1)
return powers
if pf.has_form("Sqrt", 1):
for i, pat in enumerate(var_pats):
if match(pf._leaves[0], pat, evaluation):
powers[i] = RationalOneHalf
return powers
if pf.has_form("Power", 2):
for i, pat in enumerate(var_pats):
matchval = match(pf._leaves[0], pat, evaluation)
if matchval:
powers[i] = pf._leaves[1]
return powers
if pf.has_form("Times", None):
contrib = [powers_list(factor) for factor in pf._leaves]
for i in range(len(var_pats)):
powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate(
evaluation
)
return powers
return powers

def split_coeff_pow(term: Expression):
"""
This function factorizes term in a coefficent free
of powers of the target variables, and a factor with
that powers.
"""
coeffs = []
powers = []
# First, split factors on those which are powers of the variables
# and the rest.
if term.is_free(target_pat, evaluation):
coeffs.append(term)
elif (
term.is_symbol()
or term.has_form("Power", 2)
or term.has_form("Sqrt", 1)
):
powers.append(term)
elif term.has_form("Times", None):
for factor in term.leaves:
if factor.is_free(target_pat, evaluation):
coeffs.append(factor)
elif match(factor, target_pat, evaluation):
powers.append(factor)
elif (
factor.has_form("Power", 2) or factor.has_form("Sqrt", 1)
) and match(factor._leaves[0], target_pat, evaluation):
powers.append(factor)
else:
coeffs.append(factor)
else:
coeffs.append(term)
# Now, rebuild both factors
if len(coeffs) == 0:
coeffs = None
elif len(coeffs) == 1:
coeffs = coeffs[0]
else:
coeffs = Expression("Times", *coeffs)
if len(powers) == 0:
powers = None
elif len(powers) == 1:
powers = powers[0]
else:
powers = Expression("Times", *sorted(powers))
return coeffs, powers

if expr.is_free(target_pat, evaluation):
if filt:
return Expression(filt, expr).evaluate(evaluation)
else:
return expr
elif expr.is_symbol() or expr.has_form("Power", 2) or expr.has_form("Sqrt", 1):
if filt:
return Expression(
"Times", Expression(filt, Integer1).evaluate(evaluation), expr
)
else:
return expr
elif expr.has_form("Plus", None):
coeff_dict = {}
powers_dict = {}
powers_order = {}
for term in expr._leaves:
coeff, powers = split_coeff_pow(term)
pl = powers_list(powers)
key = str(pl)
if not key in powers_dict:
powers_dict[key] = powers
coeff_dict[key] = []
powers_order[key] = key_powers(pl)

coeff_dict[key].append(Integer1 if coeff is None else coeff)

terms = []
for key in sorted(
coeff_dict, key=lambda kv: powers_order[kv], reverse=False
):
val = coeff_dict[key]
if len(val) == 0:
continue
elif len(val) == 1:
coeff = val[0]
else:
coeff = Expression("Plus", *val)
if filt:
coeff = Expression(filt, coeff).evaluate(evaluation)

powerfactor = powers_dict[key]
if powerfactor:
terms.append(Expression("Times", coeff, powerfactor))
else:
terms.append(coeff)

return Expression("Plus", *terms)
else:
if filt:
return Expression(filt, expr).evaluate(evaluation)
else:
return expr


# tejimeto
5 changes: 4 additions & 1 deletion mathics/builtin/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,10 @@ class _StopGeneratorMatchQ(StopGenerator):

class Matcher(object):
def __init__(self, form):
self.form = Pattern.create(form)
if isinstance(form, Pattern):
self.form = form
else:
self.form = Pattern.create(form)

def match(self, expr, evaluation):
def yield_func(vars, rest):
Expand Down
4 changes: 2 additions & 2 deletions mathics/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,10 +2258,9 @@ def __neg__(self) -> "Integer":
def is_zero(self) -> bool:
return self.value == 0


Integer0 = Integer(0)
Integer1 = Integer(1)


class Rational(Number):
@lru_cache()
def __new__(cls, numerator, denominator=1) -> "Rational":
Expand Down Expand Up @@ -2355,6 +2354,7 @@ def is_zero(self) -> bool:
self.numerator().is_zero
) # (implicit) and not (self.denominator().is_zero)

RationalOneHalf = Rational(1, 2)

class Real(Number):
def __new__(cls, value, p=None) -> "Real":
Expand Down

0 comments on commit 3a78500

Please sign in to comment.