-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Arith to ModArith Conversion pass and Mac Transformer pass
- Loading branch information
Wouter Legiest
committed
Nov 28, 2024
1 parent
55c07f6
commit edf3902
Showing
30 changed files
with
551 additions
and
2,241 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 120 additions & 0 deletions
120
lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h" | ||
|
||
#include "lib/Dialect/ModArith/IR/ModArithAttributes.h" | ||
#include "lib/Dialect/ModArith/IR/ModArithOps.h" | ||
#include "lib/Dialect/ModArith/IR/ModArithTypes.h" | ||
#include "lib/Utils/ConversionUtils/ConversionUtils.h" | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "arith-to-mod-arith" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace mod_arith { | ||
|
||
#define GEN_PASS_DEF_ARITHTOMODARITH | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" | ||
|
||
static ModArithType convertArithType(Type type) { | ||
auto modulus = type.getIntOrFloatBitWidth(); | ||
return ModArithType::get(type.getContext(), | ||
mlir::IntegerAttr::get(type, modulus)); | ||
} | ||
|
||
static Type convertArithLikeType(ShapedType type) { | ||
if (auto arithType = llvm::dyn_cast<IntegerType>(type.getElementType())) { | ||
return type.cloneWith(type.getShape(), convertArithType(arithType)); | ||
} | ||
return type; | ||
} | ||
|
||
class ArithToModArithTypeConverter : public TypeConverter { | ||
public: | ||
ArithToModArithTypeConverter(MLIRContext *ctx) { | ||
addConversion([](Type type) { return type; }); | ||
addConversion([](IntegerType type) -> ModArithType { | ||
return convertArithType(type); | ||
}); | ||
addConversion( | ||
[](ShapedType type) -> Type { return convertArithLikeType(type); }); | ||
} | ||
}; | ||
|
||
struct ConvertConstant : public OpConversionPattern<::mlir::arith::ConstantOp> { | ||
ConvertConstant(mlir::MLIRContext *context) | ||
: OpConversionPattern<::mlir::arith::ConstantOp>(context) {} | ||
|
||
using OpConversionPattern::OpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
::mlir::arith::ConstantOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
|
||
auto result = b.create<mod_arith::ConstantOp>(ModArithAttr::get( | ||
convertArithType(op.getType()), | ||
cast<IntegerAttr>(op.getValue()).getValue().getSExtValue())); | ||
|
||
rewriter.replaceOp(op, result); | ||
return success(); | ||
return success(); | ||
} | ||
}; | ||
|
||
template <typename SourceArithOp, typename TargetModArithOp> | ||
struct ConvertBinOp : public OpConversionPattern<SourceArithOp> { | ||
ConvertBinOp(mlir::MLIRContext *context) | ||
: OpConversionPattern<SourceArithOp>(context) {} | ||
|
||
using OpConversionPattern<SourceArithOp>::OpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
SourceArithOp op, typename SourceArithOp::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
|
||
auto result = | ||
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs()); | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct ArithToModArith : impl::ArithToModArithBase<ArithToModArith> { | ||
using ArithToModArithBase::ArithToModArithBase; | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
void ArithToModArith::runOnOperation() { | ||
MLIRContext *context = &getContext(); | ||
ModuleOp module = getOperation(); | ||
ArithToModArithTypeConverter typeConverter(context); | ||
|
||
ConversionTarget target(*context); | ||
target.addLegalDialect<ModArithDialect>(); | ||
target.addIllegalDialect<arith::ArithDialect>(); | ||
|
||
RewritePatternSet patterns(context); | ||
patterns.add<ConvertConstant, ConvertBinOp<arith::AddIOp, AddOp>, | ||
ConvertBinOp<arith::SubIOp, SubOp>, | ||
ConvertBinOp<arith::MulIOp, MulOp>>(typeConverter, context); | ||
|
||
addStructuralConversionPatterns(typeConverter, patterns, target); | ||
|
||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
} // namespace mod_arith | ||
} // namespace heir | ||
} // namespace mlir |
20 changes: 20 additions & 0 deletions
20
lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ | ||
#define LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace mod_arith { | ||
|
||
#define GEN_PASS_DECL | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" | ||
|
||
} // namespace mod_arith | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ |
22 changes: 22 additions & 0 deletions
22
lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#ifndef LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ | ||
#define LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ | ||
|
||
include "lib/Utils/DRR/Utils.td" | ||
include "lib/Dialect/ModArith/IR/ModArithOps.td" | ||
include "mlir/Dialect/Arith/IR/ArithOps.td" | ||
include "mlir/Pass/PassBase.td" | ||
|
||
def ArithToModArith : Pass<"arith-to-mod-arith", "ModuleOp"> { | ||
let summary = "Lower standard `arith` to `mod-arith`."; | ||
|
||
let description = [{ | ||
This pass lowers the `arith` dialect to their `mod-arith` equivalents. | ||
}]; | ||
|
||
let dependentDialects = [ | ||
"mlir::arith::ArithDialect", | ||
"mlir::heir::mod_arith::ModArithDialect", | ||
]; | ||
} | ||
|
||
#endif // LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "ArithToModArith", | ||
srcs = ["ArithToModArith.cpp"], | ||
hdrs = [ | ||
"ArithToModArith.h", | ||
], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:Dialect", | ||
"@heir//lib/Utils/ConversionUtils", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=ArithToModArith", | ||
], | ||
"ArithToModArith.h.inc", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "ArithToModArith.td", | ||
deps = [ | ||
"@heir//lib/Dialect/ModArith/IR:ops_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:td_files", | ||
"@llvm-project//mlir:ArithOpsTdFiles", | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
25 changes: 25 additions & 0 deletions
25
lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
add_heir_pass(ArithToModArith PATTERNS) | ||
|
||
add_mlir_conversion_library(HEIRArithToModArith | ||
ArithToModArith.cpp | ||
|
||
DEPENDS | ||
HEIRArithToModArithIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
HEIRModArith | ||
|
||
LINK_LIBS PUBLIC | ||
|
||
LLVMSupport | ||
|
||
MLIRArithDialect | ||
MLIRDialect | ||
MLIRInferTypeOpInterface | ||
MLIRIR | ||
MLIRMemRefDialect | ||
MLIRPass | ||
MLIRSupport | ||
MLIRTransforms | ||
MLIRTransformUtils | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
add_subdirectory(Conversions) | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "Transforms", | ||
hdrs = ["Passes.h"], | ||
deps = [ | ||
":ConvertToMac", | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:Dialect", | ||
"@llvm-project//mlir:IR", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "ConvertToMac", | ||
srcs = ["ConvertToMac.cpp"], | ||
hdrs = ["ConvertToMac.h"], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:Dialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=ModArith", | ||
], | ||
"Passes.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"ModArithPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "Passes.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#include "lib/Dialect/ModArith/Transforms/ConvertToMac.h" | ||
|
||
#include "lib/Dialect/ModArith/IR/ModArithOps.h" | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "mod-arith-mac" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace mod_arith { | ||
|
||
#define GEN_PASS_DEF_CONVERTTOMAC | ||
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc" | ||
|
||
struct FindMac : public OpRewritePattern<mod_arith::AddOp> { | ||
FindMac(mlir::MLIRContext *context) | ||
: OpRewritePattern<mod_arith::AddOp>(context) {} | ||
|
||
LogicalResult matchAndRewrite(mod_arith::AddOp op, | ||
PatternRewriter &rewriter) const override { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
|
||
auto *parent = op.getLhs().getDefiningOp(); | ||
|
||
if (!parent || !isa<mod_arith::MulOp>(parent)) { | ||
return failure(); | ||
} | ||
|
||
auto mul_parent = cast<mod_arith::MulOp>(parent); | ||
|
||
auto result = | ||
b.create<MacOp>(mul_parent.getLhs(), mul_parent.getRhs(), op.getRhs()); | ||
|
||
rewriter.replaceOp(op, result); | ||
rewriter.eraseOp(mul_parent); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
struct ConvertToMac : impl::ConvertToMacBase<ConvertToMac> { | ||
using ConvertToMacBase::ConvertToMacBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
|
||
patterns.add<FindMac>(context); | ||
|
||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace mod_arith | ||
} // namespace heir | ||
} // namespace mlir |
Oops, something went wrong.