The rewrite log1pmexp_to_log1mexp is not applied #1476


Open
lciti opened this issue Jun 14, 2025 · 7 comments · May be fixed by #1483
Labels
bug Something isn't working


lciti commented Jun 14, 2025

Describe the issue:

While investigating PR #1452 I ran into the following issue.

The rewrite log1pmexp_to_log1mexp (part of the "stabilize" phase) is generally not applied because:

  • log1pmexp_to_log1mexp tries to match (log1p, (neg, (exp, "x"))), however
  • during the preceding "canonicalize" phase the neg is converted to a multiplication by -1.0, so the pattern above no longer matches

This happens both when compiling and when applying rewrite_graph.
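For context on why this rewrite sits in the "stabilize" phase: the naive expression log1p(-exp(x)) breaks down when x is close to 0, because exp(x) rounds to exactly 1. The following is a minimal pure-Python sketch of that effect, not PyTensor's implementation. The helper names and the switch point at -log(2) are assumptions taken from the commonly used formulation of log1mexp.

```python
import math


def log1mexp_naive(x):
    # Direct translation of log1p(-exp(x)); mathematically valid for x < 0.
    return math.log1p(-math.exp(x))


def log1mexp_stable(x):
    # Hypothetical helper: log(1 - exp(x)) for x < 0, switching between two
    # algebraically equivalent forms at -log(2) (an assumed threshold).
    if x > -math.log(2.0):
        return math.log(-math.expm1(x))
    return math.log1p(-math.exp(x))


x = -1e-20
try:
    naive = log1mexp_naive(x)
except ValueError:
    # exp(-1e-20) rounds to 1.0, so log1p(-1.0) is a math domain error.
    naive = None

print("naive: ", naive)
print("stable:", log1mexp_stable(x))
```

With NumPy-style arithmetic the naive form would typically give -inf rather than raise, but the loss of information is the same.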

See the MRE below, which gives:

Log1p [id A]
 └─ Neg [id B]
    └─ Exp [id C]
       └─ x [id D]
====================
rewriting: rewrite local_neg_to_mul replaces Neg.0 of Neg(Exp.0) with Mul.0 of Mul(-1.0, Exp.0)
rewriting: rewrite local_mul_specialize replaces Mul.0 of Mul(-1.0, Exp.0) with Neg.0 of Neg(Exp.0)
====================
Log1p [id A]
 └─ Neg [id B]
    └─ Exp [id C]
       └─ x [id D]

If one removes the "canonicalize" step from rewrite_graph, then the rewrite is correctly applied:

Log1p [id A]
 └─ Neg [id B]
    └─ Exp [id C]
       └─ x [id D]
====================
rewriting: rewrite e(Log1p, e(Neg, e(Exp, ~x))) -> e(Scalar_log1mexp, ~x) replaces Log1p.0 of Log1p(Neg.0) with Scalar_log1mexp.0 of Scalar_log1mexp(x)
====================
Scalar_log1mexp [id A]
 └─ x [id B]
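The two outputs above can be mimicked with a toy rewriter. This is only an illustrative sketch under stated assumptions: expressions are nested tuples, and the three phase functions are hypothetical stand-ins for "local_neg_to_mul", the log1pmexp_to_log1mexp pattern, and "local_mul_specialize". None of it uses PyTensor's API.

```python
def rewrite_bottom_up(expr, rule):
    """Apply `rule` to every sub-expression, innermost first."""
    if isinstance(expr, tuple):
        expr = (expr[0], *(rewrite_bottom_up(arg, rule) for arg in expr[1:]))
    return rule(expr)


def canonicalize(e):
    # Toy stand-in: neg(x) -> mul(-1.0, x)
    if isinstance(e, tuple) and e[0] == "neg":
        return ("mul", -1.0, e[1])
    return e


def stabilize(e):
    # Toy stand-in: matches only the literal pattern log1p(neg(exp(x))).
    if (
        isinstance(e, tuple) and e[0] == "log1p"
        and isinstance(e[1], tuple) and e[1][0] == "neg"
        and isinstance(e[1][1], tuple) and e[1][1][0] == "exp"
    ):
        return ("log1mexp", e[1][1][1])
    return e


def specialize(e):
    # Toy stand-in: mul(-1.0, x) -> neg(x)
    if isinstance(e, tuple) and e[0] == "mul" and len(e) == 3 and e[1] == -1.0:
        return ("neg", e[2])
    return e


def run(expr, phases):
    # Run each phase once over the whole expression, in order.
    for phase in phases:
        expr = rewrite_bottom_up(expr, phase)
    return expr


graph = ("log1p", ("neg", ("exp", "x")))
with_canon = run(graph, [canonicalize, stabilize, specialize])
without_canon = run(graph, [stabilize, specialize])
print(with_canon)
print(without_canon)
```

In this toy model the full pipeline hands back the original expression, because the neg is hidden from the stabilize pattern and only restored afterwards, while skipping the canonicalize phase lets the pattern fire.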

Reproducible code example:

import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.scalar("x")
out = pt.log1p(-pt.exp(x))

# Graph before rewriting
pytensor.dprint(out)
print("=" * 20)
with pytensor.config.change_flags(optimizer_verbose=True):
    # Compiling instead shows the same behaviour:
    # fn = pytensor.function([x], out, mode="FAST_RUN")
    rewritten = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))
print("=" * 20)
# Graph after rewriting
pytensor.dprint(rewritten)

Error message:

PyTensor version information:

PyTensor 2.31.3

Context for the issue:

No response

lciti added the bug label Jun 14, 2025

lciti commented Jun 14, 2025

A possible workaround seems to be to also register log1pmexp_to_log1mexp as a specialize rewrite:

register_specialize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")

but I have no idea whether this causes trouble elsewhere.

@ricardoV94
Member

If neg has become a multiplication by -1 by the time stabilize runs, we can change the rewrite to look for that form instead. I think the rewrite makes complete sense as a stabilize rewrite.


lciti commented Jun 15, 2025

I think I see your point and I have some questions to understand the current implementation to better assist with fixing the bug (if I can).

If I understand correctly, "canonicalize" should bring everything to a common form, and the specific implementation is that neg(x) becomes mul(-1.0, x) through "local_neg_to_mul" and reciprocal(x) becomes True_div(1.0, x) through "local_mul_canonizer" (*). Then the "stabilize" step takes over, assuming this canonical form. Finally "specialize" looks at the final form and (among other things) it "cleans up" any remaining mul(-1.0, x) and True_div(1.0, x) and converts them back to neg(x) and reciprocal(x) (in "local_mul_specialize" and "local_div_to_reciprocal").
However, for this to work, "stabilize" steps should only operate from and to the canonical representation, otherwise some stabilize steps may be missed. Instead, some stabilize steps output a "neg(...)" (for example "logsigm_to_softplus"), which may prevent further stabilizes. As far as I can see, at the moment the only stabilize matching a neg is "log1pmexp_to_log1mexp" so this would be the only one that would require a double pattern match to work reliably (unless we manage to enforce that no "stabilize" outputs a non "canonical" representation).
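To make the "logsigm_to_softplus" concern concrete: assuming that rewrite implements the identity log(sigmoid(x)) = -softplus(-x) (an assumption inferred from its name and from the neg it outputs), its result contains negations written as neg rather than mul(-1.0, ...). The identity itself can be checked numerically. The helpers below are illustrative pure-Python functions, not PyTensor ops.

```python
import math


def sigmoid(x):
    # Naive logistic function; overflows for very negative x.
    return 1.0 / (1.0 + math.exp(-x))


def softplus(x):
    # log(1 + exp(x)), arranged so that exp never overflows.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# The two sides agree wherever the naive left-hand side is computable.
for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    lhs = math.log(sigmoid(x))
    rhs = -softplus(-x)
    assert math.isclose(lhs, rhs, rel_tol=1e-12, abs_tol=1e-12)

# For x = -800 the naive form fails (exp(800) overflows in sigmoid),
# while the softplus form still returns a finite value.
print(-softplus(800.0))
```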

Apologies for the long comment, I hope it makes sense and it can help clarify the issue and facilitate the debugging process. My goal is to understand the current implementation and provide assistance, not to criticize the architecture.

(*) However, this appears to be a bit "brittle" since this rewrite occurs because "local_mul_canonizer" happens to come before "local_reciprocal_canon". If one uses exclude=("local_mul_canonizer",) then reciprocal(x) becomes pow(x, -1.0) due to "local_reciprocal_canon". A consequence of this is that, for example, reciprocal(1 + exp(-x)) is not rewritten as a sigmoid(x).


ricardoV94 commented Jun 15, 2025

The long comment is great, and yes, you're touching on one of the weaknesses of the current system: rewrite ordering. If you're curious, we are also exploring alternatives that are robust to rewrite ordering (at the cost of rewrite explosion xD), namely egraphs.

I'll reply to your comment details tomorrow.


lciti commented Jun 16, 2025

Thanks for the link. It's a really interesting approach.
But I guess fully switching to that approach is a major undertaking and will take time.

I think your current implementation partially addresses the issue of rewrite ordering by having separate stabilize and specialize steps, with the former happening first. So in a sense you make sure that more important rewrites happen first (or happen at all).
You partially address the problem of equivalent expressions by canonicalizing first, so that each expression of the initial computation is replaced with the representative member of the equivalence class of expressions it belongs to (e.g. [neg(x), mul(-1,x)]). Then stabilize steps should assume a canonicalized expression or the result of another stabilize (for example "log1pexp_to_softplus" matches the pattern (log1p, (exp, "x")), where "log1p" is the output of another stabilize rewrite "local_log1p").

Long story short, to fix this issue we may need to change "log1pmexp_to_log1mexp" to use the canonical form "mul(-1.0, exp(x))" (as you suggest) and also fix "logsigm_to_softplus" and "local_log1p" (and others if there are) to avoid producing a "neg" as output.
However, what about specializes? Should "expm1(log1mexp(x)) -> -exp(x)" within "local_exp_log_nan_switch" also avoid outputting a "neg"?


ricardoV94 commented Jun 16, 2025

> But I guess fully switching to that approach is a major undertaking and will take time.

I wasn't suggesting we do it, just keeping it in the back of my mind while thinking about how to bring in some of those ideas. I shared it because knowing an alternative can help build intuition about the approach we've been taking and where the pain points are.

> Long story short, to fix this issue we may need to change "log1pmexp_to_log1mexp" to use the canonical form "mul(-1.0, exp(x))" (as you suggest) and also fix "logsigm_to_softplus" and "local_log1p" (and others if there are) to avoid producing a "neg" as output.

Sounds like a plan

> However, what about specializes? Should "expm1(log1mexp(x)) -> -exp(x)" within "local_exp_log_nan_switch" also avoid outputting a "neg"?

I think in specialize we convert -1 * x back to neg, but would need to check. Specialize is more about performance, but I doubt it ever matters much which form we use.

lciti pushed a commit to lciti/pytensor that referenced this issue Jun 17, 2025

lciti commented Jun 17, 2025

I found the function is_neg and realised it was easier to just re-implement log1pmexp_to_log1mexp using that one. I have submitted #1483 using that approach.
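The idea of matching through a helper that accepts either negation form can be sketched on a toy tuple representation. This is not the code submitted in #1483 and not PyTensor's is_neg. The names match_neg and toy_log1pmexp_to_log1mexp are hypothetical, and the sketch assumes the helper returns the negated operand or None.

```python
def match_neg(e):
    """Return x if `e` is neg(x) or a two-operand mul by -1.0, else None."""
    if not isinstance(e, tuple):
        return None
    if e[0] == "neg" and len(e) == 2:
        return e[1]
    if e[0] == "mul" and len(e) == 3:
        a, b = e[1], e[2]
        if a == -1.0:
            return b
        if b == -1.0:
            return a
    return None


def toy_log1pmexp_to_log1mexp(e):
    """Rewrite log1p(-exp(x)) to log1mexp(x) for any negation form above."""
    if isinstance(e, tuple) and e[0] == "log1p" and len(e) == 2:
        inner = match_neg(e[1])
        if isinstance(inner, tuple) and inner[0] == "exp":
            return ("log1mexp", inner[1])
    return e


# Both the original and the canonicalized form are now rewritten.
print(toy_log1pmexp_to_log1mexp(("log1p", ("neg", ("exp", "x")))))
print(toy_log1pmexp_to_log1mexp(("log1p", ("mul", -1.0, ("exp", "x")))))
```

The design point is that the rewrite no longer depends on which phase ran before it, since the negation check is factored out of the pattern.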
