Skip to content

Commit

Permalink
transformations: Make loop folding consistent with mlir (#4037)
Browse files Browse the repository at this point in the history
There was a difference in logic between this pass in mlir and xdsl and
xdsl's version made it hard to flatten ptr loops. Made the logic
consistent.
  • Loading branch information
mamanain authored Mar 8, 2025
1 parent 4ea13ad commit abfb448
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
20 changes: 20 additions & 0 deletions tests/filecheck/transforms/scf_for_loop_range_folding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,23 @@ scf.for %j = %lb to %ub step %step {
"test.op"(%mult_idx) : (index) -> ()
}
}

// CHECK-NEXT: %15 = arith.muli %lb, %mul_shift : index
// CHECK-NEXT: %16 = arith.muli %ub, %mul_shift : index
// CHECK-NEXT: %17 = arith.muli %step, %mul_shift : index
// CHECK-NEXT: scf.for %i_3 = %15 to %16 step %17 {
// CHECK-NEXT: scf.for %j_1 = %lb to %ub step %step {
// CHECK-NEXT: %b = arith.addi %i_3, %j_1 : index
// CHECK-NEXT: %c = arith.muli %b, %mul_shift : index
// CHECK-NEXT: "test.op"(%c) : (index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }

scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
%a = arith.muli %i, %mul_shift : index
%b = arith.addi %a, %j : index
%c = arith.muli %b, %mul_shift : index
"test.op"(%c) : (index) -> ()
}
}
13 changes: 3 additions & 10 deletions xdsl/transforms/scf_for_loop_range_folding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from xdsl.context import Context
from xdsl.dialects import arith, builtin, scf
from xdsl.ir import BlockArgument, OpResult, SSAValue
from xdsl.ir import SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand All @@ -11,13 +11,7 @@


def is_foldable(val: SSAValue, for_op: scf.ForOp):
if isinstance(val, BlockArgument):
return True

if not isinstance(val, OpResult):
return False

return not for_op.is_ancestor(val.op)
return not for_op.is_ancestor(val.owner)


class ScfForLoopRangeFolding(RewritePattern):
Expand Down Expand Up @@ -79,6 +73,5 @@ class ScfForLoopRangeFoldingPass(ModulePass):

def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
ScfForLoopRangeFolding(),
apply_recursively=True,
ScfForLoopRangeFolding(), apply_recursively=False, walk_regions_first=True
).rewrite_module(op)

0 comments on commit abfb448

Please sign in to comment.