diff --git a/CHANGES.rst b/CHANGES.rst
index 126cdb7c54..cc22d9c867 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -14,7 +14,7 @@ New builtins
* ``Series``, ``O`` and ``SeriesData``
* ``StringReverse``
* Add all of the named colors, e.g. ``Brown`` or ``LighterMagenta``.
-
+* ``Collect``
Enhancements
diff --git a/mathics/builtin/algebra.py b/mathics/builtin/algebra.py
index 81929074dc..72d0c6af9f 100644
--- a/mathics/builtin/algebra.py
+++ b/mathics/builtin/algebra.py
@@ -10,7 +10,9 @@
Atom,
Expression,
Integer,
+ Integer0,
Integer1,
+ RationalOneHalf,
Number,
Symbol,
SymbolFalse,
@@ -18,6 +20,7 @@
SymbolTrue,
)
from mathics.core.convert import from_sympy, sympy_symbol_prefix
+from mathics.core.rules import Pattern
import sympy
@@ -62,7 +65,6 @@ def _expand(expr):
if kwargs["modulus"] is not None and kwargs["modulus"] <= 0:
return Integer(0)
-
# A special case for trigonometric functions
if "trig" in kwargs and kwargs["trig"]:
if expr.has_form("Sin", 1):
@@ -149,7 +151,6 @@ def unconvert_subexprs(expr):
)
sympy_expr = convert_sympy(expr)
-
if deep:
# thread over everything
for (i, sub_expr,) in enumerate(sub_exprs):
@@ -192,7 +193,6 @@ def unconvert_subexprs(expr):
sympy_expr = sympy_expr.expand(**hints)
result = from_sympy(sympy_expr)
result = unconvert_subexprs(result)
-
return result
@@ -1413,3 +1413,202 @@ def apply(self, expr, form, h, evaluation):
return Expression(
"List", *[Expression(h, *[i for i in s]) for s in exponents]
)
+
+
+class Collect(Builtin):
+ """
+
+ - 'Collect[$expr$, $x$]'
+
- Expands $expr$ and collect together terms having the same power of $x$.
+
- 'Collect[$expr$, {$x_1$, $x_2$, ...}]'
+
- Expands $expr$ and collect together terms having the same powers of
+ $x_1$, $x_2$, ....
+
- 'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]'
+
- After collect the terms, applies $filter$ to each coefficient.
+
+
+ >> Collect[(x+y)^3, y]
+ = x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3
+ >> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y]
+ = 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z]
+ >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y]
+ = 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3
+ >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}]
+ = 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3
+ >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h]
+ = x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1]
+ """
+
+ rules = {
+ "Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]",
+ }
+
+ def apply_var_filter(self, expr, varlst, filt, evaluation):
+ """Collect[expr_, varlst_, filt_]"""
+ from mathics.builtin.patterns import match
+
+ if varlst.is_symbol():
+ var_exprs = [varlst]
+ elif varlst.has_form("List", None):
+ var_exprs = varlst.get_leaves()
+ else:
+ var_exprs = [varlst]
+
+ if len(var_exprs) > 1:
+ target_pat = Pattern.create(Expression("Alternatives", *var_exprs))
+ var_pats = [Pattern.create(var) for var in var_exprs]
+ else:
+ target_pat = Pattern.create(varlst)
+ var_pats = [target_pat]
+
+ expr = expand(
+ expr,
+ numer=True,
+ denom=False,
+ deep=False,
+ trig=False,
+ modulus=None,
+ target_pat=target_pat,
+ )
+ if filt == Symbol("Identity"):
+ filt = None
+
+ def key_powers(lst):
+ key = Expression("Plus", *lst)
+ key = key.evaluate(evaluation)
+ if key.is_numeric():
+ return key.to_python()
+ return 0
+
+ def powers_list(pf):
+ powers = [Integer0 for i, p in enumerate(var_pats)]
+ if pf is None:
+ return powers
+ if pf.is_symbol():
+ for i, pat in enumerate(var_pats):
+ if match(pf, pat, evaluation):
+ powers[i] = Integer(1)
+ return powers
+ if pf.has_form("Sqrt", 1):
+ for i, pat in enumerate(var_pats):
+ if match(pf._leaves[0], pat, evaluation):
+ powers[i] = RationalOneHalf
+ return powers
+ if pf.has_form("Power", 2):
+ for i, pat in enumerate(var_pats):
+ matchval = match(pf._leaves[0], pat, evaluation)
+ if matchval:
+ powers[i] = pf._leaves[1]
+ return powers
+ if pf.has_form("Times", None):
+ contrib = [powers_list(factor) for factor in pf._leaves]
+ for i in range(len(var_pats)):
+ powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate(
+ evaluation
+ )
+ return powers
+ return powers
+
+ def split_coeff_pow(term: Expression):
+ """
+ This function factorizes term in a coefficent free
+ of powers of the target variables, and a factor with
+ that powers.
+ """
+ coeffs = []
+ powers = []
+ # First, split factors on those which are powers of the variables
+ # and the rest.
+ if term.is_free(target_pat, evaluation):
+ coeffs.append(term)
+ elif (
+ term.is_symbol()
+ or term.has_form("Power", 2)
+ or term.has_form("Sqrt", 1)
+ ):
+ powers.append(term)
+ elif term.has_form("Times", None):
+ for factor in term.leaves:
+ if factor.is_free(target_pat, evaluation):
+ coeffs.append(factor)
+ elif match(factor, target_pat, evaluation):
+ powers.append(factor)
+ elif (
+ factor.has_form("Power", 2) or factor.has_form("Sqrt", 1)
+ ) and match(factor._leaves[0], target_pat, evaluation):
+ powers.append(factor)
+ else:
+ coeffs.append(factor)
+ else:
+ coeffs.append(term)
+ # Now, rebuild both factors
+ if len(coeffs) == 0:
+ coeffs = None
+ elif len(coeffs) == 1:
+ coeffs = coeffs[0]
+ else:
+ coeffs = Expression("Times", *coeffs)
+ if len(powers) == 0:
+ powers = None
+ elif len(powers) == 1:
+ powers = powers[0]
+ else:
+ powers = Expression("Times", *sorted(powers))
+ return coeffs, powers
+
+ if expr.is_free(target_pat, evaluation):
+ if filt:
+ return Expression(filt, expr).evaluate(evaluation)
+ else:
+ return expr
+ elif expr.is_symbol() or expr.has_form("Power", 2) or expr.has_form("Sqrt", 1):
+ if filt:
+ return Expression(
+ "Times", Expression(filt, Integer1).evaluate(evaluation), expr
+ )
+ else:
+ return expr
+ elif expr.has_form("Plus", None):
+ coeff_dict = {}
+ powers_dict = {}
+ powers_order = {}
+ for term in expr._leaves:
+ coeff, powers = split_coeff_pow(term)
+ pl = powers_list(powers)
+ key = str(pl)
+ if not key in powers_dict:
+ powers_dict[key] = powers
+ coeff_dict[key] = []
+ powers_order[key] = key_powers(pl)
+
+ coeff_dict[key].append(Integer1 if coeff is None else coeff)
+
+ terms = []
+ for key in sorted(
+ coeff_dict, key=lambda kv: powers_order[kv], reverse=False
+ ):
+ val = coeff_dict[key]
+ if len(val) == 0:
+ continue
+ elif len(val) == 1:
+ coeff = val[0]
+ else:
+ coeff = Expression("Plus", *val)
+ if filt:
+ coeff = Expression(filt, coeff).evaluate(evaluation)
+
+ powerfactor = powers_dict[key]
+ if powerfactor:
+ terms.append(Expression("Times", coeff, powerfactor))
+ else:
+ terms.append(coeff)
+
+ return Expression("Plus", *terms)
+ else:
+ if filt:
+ return Expression(filt, expr).evaluate(evaluation)
+ else:
+ return expr
+
+
+# tejimeto
diff --git a/mathics/builtin/patterns.py b/mathics/builtin/patterns.py
index 92e7a8e76f..d1fa852eba 100644
--- a/mathics/builtin/patterns.py
+++ b/mathics/builtin/patterns.py
@@ -630,7 +630,10 @@ class _StopGeneratorMatchQ(StopGenerator):
class Matcher(object):
def __init__(self, form):
- self.form = Pattern.create(form)
+ if isinstance(form, Pattern):
+ self.form = form
+ else:
+ self.form = Pattern.create(form)
def match(self, expr, evaluation):
def yield_func(vars, rest):
diff --git a/mathics/core/expression.py b/mathics/core/expression.py
index b93373a801..95007440fc 100644
--- a/mathics/core/expression.py
+++ b/mathics/core/expression.py
@@ -2258,10 +2258,9 @@ def __neg__(self) -> "Integer":
def is_zero(self) -> bool:
return self.value == 0
-
+Integer0 = Integer(0)
Integer1 = Integer(1)
-
class Rational(Number):
@lru_cache()
def __new__(cls, numerator, denominator=1) -> "Rational":
@@ -2355,6 +2354,7 @@ def is_zero(self) -> bool:
self.numerator().is_zero
) # (implicit) and not (self.denominator().is_zero)
+RationalOneHalf = Rational(1, 2)
class Real(Number):
def __new__(cls, value, p=None) -> "Real":