diff --git a/CHANGES.rst b/CHANGES.rst index 126cdb7c54..cc22d9c867 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,7 +14,7 @@ New builtins * ``Series``, ``O`` and ``SeriesData`` * ``StringReverse`` * Add all of the named colors, e.g. ``Brown`` or ``LighterMagenta``. - +* ``Collect`` Enhancements diff --git a/mathics/builtin/algebra.py b/mathics/builtin/algebra.py index 81929074dc..72d0c6af9f 100644 --- a/mathics/builtin/algebra.py +++ b/mathics/builtin/algebra.py @@ -10,7 +10,9 @@ Atom, Expression, Integer, + Integer0, Integer1, + RationalOneHalf, Number, Symbol, SymbolFalse, @@ -18,6 +20,7 @@ SymbolTrue, ) from mathics.core.convert import from_sympy, sympy_symbol_prefix +from mathics.core.rules import Pattern import sympy @@ -62,7 +65,6 @@ def _expand(expr): if kwargs["modulus"] is not None and kwargs["modulus"] <= 0: return Integer(0) - # A special case for trigonometric functions if "trig" in kwargs and kwargs["trig"]: if expr.has_form("Sin", 1): @@ -149,7 +151,6 @@ def unconvert_subexprs(expr): ) sympy_expr = convert_sympy(expr) - if deep: # thread over everything for (i, sub_expr,) in enumerate(sub_exprs): @@ -192,7 +193,6 @@ def unconvert_subexprs(expr): sympy_expr = sympy_expr.expand(**hints) result = from_sympy(sympy_expr) result = unconvert_subexprs(result) - return result @@ -1413,3 +1413,202 @@ def apply(self, expr, form, h, evaluation): return Expression( "List", *[Expression(h, *[i for i in s]) for s in exponents] ) + + +class Collect(Builtin): + """ +
+
'Collect[$expr$, $x$]' +
Expands $expr$ and collect together terms having the same power of $x$. +
'Collect[$expr$, {$x_1$, $x_2$, ...}]' +
Expands $expr$ and collect together terms having the same powers of + $x_1$, $x_2$, .... +
'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]' +
After collect the terms, applies $filter$ to each coefficient. +
+ + >> Collect[(x+y)^3, y] + = x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3 + >> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y] + = 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z] + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y] + = 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3 + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}] + = 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3 + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h] + = x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1] + """ + + rules = { + "Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]", + } + + def apply_var_filter(self, expr, varlst, filt, evaluation): + """Collect[expr_, varlst_, filt_]""" + from mathics.builtin.patterns import match + + if varlst.is_symbol(): + var_exprs = [varlst] + elif varlst.has_form("List", None): + var_exprs = varlst.get_leaves() + else: + var_exprs = [varlst] + + if len(var_exprs) > 1: + target_pat = Pattern.create(Expression("Alternatives", *var_exprs)) + var_pats = [Pattern.create(var) for var in var_exprs] + else: + target_pat = Pattern.create(varlst) + var_pats = [target_pat] + + expr = expand( + expr, + numer=True, + denom=False, + deep=False, + trig=False, + modulus=None, + target_pat=target_pat, + ) + if filt == Symbol("Identity"): + filt = None + + def key_powers(lst): + key = Expression("Plus", *lst) + key = key.evaluate(evaluation) + if key.is_numeric(): + return key.to_python() + return 0 + + def powers_list(pf): + powers = [Integer0 for i, p in enumerate(var_pats)] + if pf is None: + return powers + if pf.is_symbol(): + for i, pat in enumerate(var_pats): + if match(pf, pat, evaluation): + powers[i] = Integer(1) + return powers + if pf.has_form("Sqrt", 1): + for i, pat in enumerate(var_pats): + if match(pf._leaves[0], pat, evaluation): + powers[i] = RationalOneHalf + return powers + if pf.has_form("Power", 2): + for i, pat in enumerate(var_pats): + matchval = match(pf._leaves[0], pat, evaluation) + if matchval: + powers[i] = pf._leaves[1] + return powers + if pf.has_form("Times", None): + contrib = [powers_list(factor) for factor in pf._leaves] + for i in range(len(var_pats)): + powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate( + evaluation + ) + return powers + return powers + + def split_coeff_pow(term: Expression): + """ + This function factorizes term in a coefficent free + of powers of the target variables, and a factor with + that powers. + """ + coeffs = [] + powers = [] + # First, split factors on those which are powers of the variables + # and the rest. + if term.is_free(target_pat, evaluation): + coeffs.append(term) + elif ( + term.is_symbol() + or term.has_form("Power", 2) + or term.has_form("Sqrt", 1) + ): + powers.append(term) + elif term.has_form("Times", None): + for factor in term.leaves: + if factor.is_free(target_pat, evaluation): + coeffs.append(factor) + elif match(factor, target_pat, evaluation): + powers.append(factor) + elif ( + factor.has_form("Power", 2) or factor.has_form("Sqrt", 1) + ) and match(factor._leaves[0], target_pat, evaluation): + powers.append(factor) + else: + coeffs.append(factor) + else: + coeffs.append(term) + # Now, rebuild both factors + if len(coeffs) == 0: + coeffs = None + elif len(coeffs) == 1: + coeffs = coeffs[0] + else: + coeffs = Expression("Times", *coeffs) + if len(powers) == 0: + powers = None + elif len(powers) == 1: + powers = powers[0] + else: + powers = Expression("Times", *sorted(powers)) + return coeffs, powers + + if expr.is_free(target_pat, evaluation): + if filt: + return Expression(filt, expr).evaluate(evaluation) + else: + return expr + elif expr.is_symbol() or expr.has_form("Power", 2) or expr.has_form("Sqrt", 1): + if filt: + return Expression( + "Times", Expression(filt, Integer1).evaluate(evaluation), expr + ) + else: + return expr + elif expr.has_form("Plus", None): + coeff_dict = {} + powers_dict = {} + powers_order = {} + for term in expr._leaves: + coeff, powers = split_coeff_pow(term) + pl = powers_list(powers) + key = str(pl) + if not key in powers_dict: + powers_dict[key] = powers + coeff_dict[key] = [] + powers_order[key] = key_powers(pl) + + coeff_dict[key].append(Integer1 if coeff is None else coeff) + + terms = [] + for key in sorted( + coeff_dict, key=lambda kv: powers_order[kv], reverse=False + ): + val = coeff_dict[key] + if len(val) == 0: + continue + elif len(val) == 1: + coeff = val[0] + else: + coeff = Expression("Plus", *val) + if filt: + coeff = Expression(filt, coeff).evaluate(evaluation) + + powerfactor = powers_dict[key] + if powerfactor: + terms.append(Expression("Times", coeff, powerfactor)) + else: + terms.append(coeff) + + return Expression("Plus", *terms) + else: + if filt: + return Expression(filt, expr).evaluate(evaluation) + else: + return expr + + +# tejimeto diff --git a/mathics/builtin/patterns.py b/mathics/builtin/patterns.py index 92e7a8e76f..d1fa852eba 100644 --- a/mathics/builtin/patterns.py +++ b/mathics/builtin/patterns.py @@ -630,7 +630,10 @@ class _StopGeneratorMatchQ(StopGenerator): class Matcher(object): def __init__(self, form): - self.form = Pattern.create(form) + if isinstance(form, Pattern): + self.form = form + else: + self.form = Pattern.create(form) def match(self, expr, evaluation): def yield_func(vars, rest): diff --git a/mathics/core/expression.py b/mathics/core/expression.py index b93373a801..95007440fc 100644 --- a/mathics/core/expression.py +++ b/mathics/core/expression.py @@ -2258,10 +2258,9 @@ def __neg__(self) -> "Integer": def is_zero(self) -> bool: return self.value == 0 - +Integer0 = Integer(0) Integer1 = Integer(1) - class Rational(Number): @lru_cache() def __new__(cls, numerator, denominator=1) -> "Rational": @@ -2355,6 +2354,7 @@ def is_zero(self) -> bool: self.numerator().is_zero ) # (implicit) and not (self.denominator().is_zero) +RationalOneHalf = Rational(1, 2) class Real(Number): def __new__(cls, value, p=None) -> "Real":