Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
mamanain committed Mar 8, 2025
1 parent 805ea3e commit 0d18d91
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 10 deletions.
81 changes: 81 additions & 0 deletions tests/filecheck/transforms/ptr_loop_folding.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// RUN: xdsl-opt -p convert-memref-to-ptr,convert-ptr-type-offsets,canonicalize,scf-for-loop-range-folding,canonicalize,scf-for-loop-flatten,canonicalize,scf-for-loop-range-folding,canonicalize %s

func.func @fill(%m: memref<10xi32>) {
%c0 = arith.constant 0 : index
%end = arith.constant 10 : index
%c1 = arith.constant 1 : index
%val = arith.constant 100 : i32
scf.for %i = %c0 to %end step %c1 {
memref.store %val, %m[%i] : memref<10xi32>
}
return
}

// CHECK: func.func @fill(%m : memref<10xi32>) {
// CHECK-NEXT: %val = arith.constant 100 : i32
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = arith.constant 40 : index
// CHECK-NEXT: %2 = arith.constant 4 : index
// CHECK-NEXT: scf.for %i = %0 to %1 step %2 {
// CHECK-NEXT: %3 = ptr_xdsl.to_ptr %m : memref<10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer = ptr_xdsl.ptradd %3, %i : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %val, %offset_pointer : i32, !ptr_xdsl.ptr
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }


func.func @fill2d(%m: memref<10x10xi32>) {
%c0 = arith.constant 0 : index
%end = arith.constant 10 : index
%c1 = arith.constant 1 : index
%val = arith.constant 100 : i32
scf.for %i = %c0 to %end step %c1 {
scf.for %j = %c0 to %end step %c1 {
memref.store %val, %m[%i, %j] : memref<10x10xi32>
}
}
return
}

// CHECK-NEXT: func.func @fill2d(%m : memref<10x10xi32>) {
// CHECK-NEXT: %val = arith.constant 100 : i32
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = arith.constant 400 : index
// CHECK-NEXT: %2 = arith.constant 4 : index
// CHECK-NEXT: scf.for %j = %0 to %1 step %2 {
// CHECK-NEXT: %3 = ptr_xdsl.to_ptr %m : memref<10x10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer = ptr_xdsl.ptradd %3, %j : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %val, %offset_pointer : i32, !ptr_xdsl.ptr
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }


func.func @fill3d(%m: memref<10x10x10xi32>) {
%c0 = arith.constant 0 : index
%end = arith.constant 10 : index
%c1 = arith.constant 1 : index
%val = arith.constant 100 : i32
scf.for %i = %c0 to %end step %c1 {
scf.for %j = %c0 to %end step %c1 {
scf.for %k = %c0 to %end step %c1 {
memref.store %val, %m[%i, %j, %k] : memref<10x10x10xi32>
}
}
}
return
}

// CHECK-NEXT: func.func @fill3d(%m : memref<10x10x10xi32>) {
// CHECK-NEXT: %val = arith.constant 100 : i32
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = arith.constant 4000 : index
// CHECK-NEXT: %2 = arith.constant 4 : index
// CHECK-NEXT: scf.for %k = %0 to %1 step %2 {
// CHECK-NEXT: %3 = ptr_xdsl.to_ptr %m : memref<10x10x10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer = ptr_xdsl.ptradd %3, %k : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %val, %offset_pointer : i32, !ptr_xdsl.ptr
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }
20 changes: 20 additions & 0 deletions tests/filecheck/transforms/scf_for_loop_range_folding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,23 @@ scf.for %j = %lb to %ub step %step {
"test.op"(%mult_idx) : (index) -> ()
}
}

// CHECK-NEXT: %15 = arith.muli %lb, %mul_shift : index
// CHECK-NEXT: %16 = arith.muli %ub, %mul_shift : index
// CHECK-NEXT: %17 = arith.muli %step, %mul_shift : index
// CHECK-NEXT: scf.for %i_3 = %15 to %16 step %17 {
// CHECK-NEXT: scf.for %j_1 = %lb to %ub step %step {
// CHECK-NEXT: %b = arith.addi %i_3, %j_1 : index
// CHECK-NEXT: %c = arith.muli %b, %mul_shift : index
// CHECK-NEXT: "test.op"(%c) : (index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }

scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
%a = arith.muli %i, %mul_shift : index
%b = arith.addi %a, %j : index
%c = arith.muli %b, %mul_shift : index
"test.op"(%c) : (index) -> ()
}
}
13 changes: 3 additions & 10 deletions xdsl/transforms/scf_for_loop_range_folding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from xdsl.context import Context
from xdsl.dialects import arith, builtin, scf
from xdsl.ir import BlockArgument, OpResult, SSAValue
from xdsl.ir import SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand All @@ -11,13 +11,7 @@


def is_foldable(val: SSAValue, for_op: scf.ForOp):
if isinstance(val, BlockArgument):
return True

if not isinstance(val, OpResult):
return False

return not for_op.is_ancestor(val.op)
return not for_op.is_ancestor(val.owner)


class ScfForLoopRangeFolding(RewritePattern):
Expand Down Expand Up @@ -79,6 +73,5 @@ class ScfForLoopRangeFoldingPass(ModulePass):

def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
ScfForLoopRangeFolding(),
apply_recursively=True,
ScfForLoopRangeFolding(), apply_recursively=False, walk_regions_first=True
).rewrite_module(op)

0 comments on commit 0d18d91

Please sign in to comment.