-
-
Notifications
You must be signed in to change notification settings - Fork 205
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
209 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,14 +10,17 @@ | |
Atom, | ||
Expression, | ||
Integer, | ||
Integer0, | ||
Integer1, | ||
RationalOneHalf, | ||
Number, | ||
Symbol, | ||
SymbolFalse, | ||
SymbolNull, | ||
SymbolTrue, | ||
) | ||
from mathics.core.convert import from_sympy, sympy_symbol_prefix | ||
from mathics.core.rules import Pattern | ||
|
||
import sympy | ||
|
||
|
@@ -62,7 +65,6 @@ def _expand(expr): | |
|
||
if kwargs["modulus"] is not None and kwargs["modulus"] <= 0: | ||
return Integer(0) | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
rocky
Member
|
||
|
||
# A special case for trigonometric functions | ||
if "trig" in kwargs and kwargs["trig"]: | ||
if expr.has_form("Sin", 1): | ||
|
@@ -149,7 +151,6 @@ def unconvert_subexprs(expr): | |
) | ||
|
||
sympy_expr = convert_sympy(expr) | ||
|
||
if deep: | ||
# thread over everything | ||
for (i, sub_expr,) in enumerate(sub_exprs): | ||
|
@@ -192,7 +193,6 @@ def unconvert_subexprs(expr): | |
sympy_expr = sympy_expr.expand(**hints) | ||
result = from_sympy(sympy_expr) | ||
result = unconvert_subexprs(result) | ||
|
||
return result | ||
|
||
|
||
|
@@ -1413,3 +1413,202 @@ def apply(self, expr, form, h, evaluation): | |
return Expression( | ||
"List", *[Expression(h, *[i for i in s]) for s in exponents] | ||
) | ||
|
||
|
||
class Collect(Builtin): | ||
""" | ||
<dl> | ||
<dt>'Collect[$expr$, $x$]' | ||
<dd> Expands $expr$ and collect together terms having the same power of $x$. | ||
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}]' | ||
<dd> Expands $expr$ and collect together terms having the same powers of | ||
$x_1$, $x_2$, .... | ||
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]' | ||
<dd> After collect the terms, applies $filter$ to each coefficient. | ||
</dl> | ||
>> Collect[(x+y)^3, y] | ||
= x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3 | ||
>> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y] | ||
= 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z] | ||
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y] | ||
= 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3 | ||
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}] | ||
= 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3 | ||
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h] | ||
= x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1] | ||
""" | ||
|
||
rules = { | ||
"Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]", | ||
} | ||
|
||
def apply_var_filter(self, expr, varlst, filt, evaluation): | ||
"""Collect[expr_, varlst_, filt_]""" | ||
from mathics.builtin.patterns import match | ||
|
||
if varlst.is_symbol(): | ||
var_exprs = [varlst] | ||
elif varlst.has_form("List", None): | ||
var_exprs = varlst.get_leaves() | ||
else: | ||
var_exprs = [varlst] | ||
|
||
if len(var_exprs) > 1: | ||
target_pat = Pattern.create(Expression("Alternatives", *var_exprs)) | ||
var_pats = [Pattern.create(var) for var in var_exprs] | ||
else: | ||
target_pat = Pattern.create(varlst) | ||
var_pats = [target_pat] | ||
|
||
expr = expand( | ||
expr, | ||
numer=True, | ||
denom=False, | ||
deep=False, | ||
trig=False, | ||
modulus=None, | ||
target_pat=target_pat, | ||
) | ||
if filt == Symbol("Identity"): | ||
filt = None | ||
|
||
def key_powers(lst): | ||
key = Expression("Plus", *lst) | ||
key = key.evaluate(evaluation) | ||
if key.is_numeric(): | ||
return key.to_python() | ||
return 0 | ||
|
||
def powers_list(pf): | ||
powers = [Integer0 for i, p in enumerate(var_pats)] | ||
if pf is None: | ||
return powers | ||
if pf.is_symbol(): | ||
for i, pat in enumerate(var_pats): | ||
if match(pf, pat, evaluation): | ||
powers[i] = Integer(1) | ||
return powers | ||
if pf.has_form("Sqrt", 1): | ||
for i, pat in enumerate(var_pats): | ||
if match(pf._leaves[0], pat, evaluation): | ||
powers[i] = RationalOneHalf | ||
return powers | ||
if pf.has_form("Power", 2): | ||
for i, pat in enumerate(var_pats): | ||
matchval = match(pf._leaves[0], pat, evaluation) | ||
if matchval: | ||
powers[i] = pf._leaves[1] | ||
return powers | ||
if pf.has_form("Times", None): | ||
contrib = [powers_list(factor) for factor in pf._leaves] | ||
for i in range(len(var_pats)): | ||
powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate( | ||
evaluation | ||
) | ||
return powers | ||
return powers | ||
|
||
def split_coeff_pow(term: Expression): | ||
""" | ||
This function factorizes term in a coefficent free | ||
of powers of the target variables, and a factor with | ||
that powers. | ||
""" | ||
coeffs = [] | ||
powers = [] | ||
# First, split factors on those which are powers of the variables | ||
# and the rest. | ||
if term.is_free(target_pat, evaluation): | ||
coeffs.append(term) | ||
elif ( | ||
term.is_symbol() | ||
or term.has_form("Power", 2) | ||
or term.has_form("Sqrt", 1) | ||
): | ||
powers.append(term) | ||
elif term.has_form("Times", None): | ||
for factor in term.leaves: | ||
if factor.is_free(target_pat, evaluation): | ||
coeffs.append(factor) | ||
elif match(factor, target_pat, evaluation): | ||
powers.append(factor) | ||
elif ( | ||
factor.has_form("Power", 2) or factor.has_form("Sqrt", 1) | ||
) and match(factor._leaves[0], target_pat, evaluation): | ||
powers.append(factor) | ||
else: | ||
coeffs.append(factor) | ||
else: | ||
coeffs.append(term) | ||
# Now, rebuild both factors | ||
if len(coeffs) == 0: | ||
coeffs = None | ||
elif len(coeffs) == 1: | ||
coeffs = coeffs[0] | ||
else: | ||
coeffs = Expression("Times", *coeffs) | ||
if len(powers) == 0: | ||
powers = None | ||
elif len(powers) == 1: | ||
powers = powers[0] | ||
else: | ||
powers = Expression("Times", *sorted(powers)) | ||
return coeffs, powers | ||
|
||
if expr.is_free(target_pat, evaluation): | ||
if filt: | ||
return Expression(filt, expr).evaluate(evaluation) | ||
else: | ||
return expr | ||
elif expr.is_symbol() or expr.has_form("Power", 2) or expr.has_form("Sqrt", 1): | ||
if filt: | ||
return Expression( | ||
"Times", Expression(filt, Integer1).evaluate(evaluation), expr | ||
) | ||
else: | ||
return expr | ||
elif expr.has_form("Plus", None): | ||
coeff_dict = {} | ||
powers_dict = {} | ||
powers_order = {} | ||
for term in expr._leaves: | ||
coeff, powers = split_coeff_pow(term) | ||
pl = powers_list(powers) | ||
key = str(pl) | ||
if not key in powers_dict: | ||
powers_dict[key] = powers | ||
coeff_dict[key] = [] | ||
powers_order[key] = key_powers(pl) | ||
|
||
coeff_dict[key].append(Integer1 if coeff is None else coeff) | ||
|
||
terms = [] | ||
for key in sorted( | ||
coeff_dict, key=lambda kv: powers_order[kv], reverse=False | ||
): | ||
val = coeff_dict[key] | ||
if len(val) == 0: | ||
continue | ||
elif len(val) == 1: | ||
coeff = val[0] | ||
else: | ||
coeff = Expression("Plus", *val) | ||
if filt: | ||
coeff = Expression(filt, coeff).evaluate(evaluation) | ||
|
||
powerfactor = powers_dict[key] | ||
if powerfactor: | ||
terms.append(Expression("Times", coeff, powerfactor)) | ||
else: | ||
terms.append(coeff) | ||
|
||
return Expression("Plus", *terms) | ||
else: | ||
if filt: | ||
return Expression(filt, expr).evaluate(evaluation) | ||
else: | ||
return expr | ||
|
||
|
||
# tejimeto |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Integer0 now. (There were a number of places in my recent pull requests where I knew I could have done this, but due to the order of merges Integer0 wasn't around.