Skip to content

Commit

Permalink
Arith to ModArith Conversion pass and Mac Transformer pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Wouter Legiest committed Nov 28, 2024
1 parent 55c07f6 commit edf3902
Show file tree
Hide file tree
Showing 30 changed files with 551 additions and 2,241 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ repos:
rev: "v2.2.5"
hooks:
- id: codespell
args: ["-L", "crate"]
args: ["-L", "crate, fpt"]


# Changes tabs to spaces
Expand Down
120 changes: 120 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h"

#include "lib/Dialect/ModArith/IR/ModArithAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Utils/ConversionUtils/ConversionUtils.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "arith-to-mod-arith"

namespace mlir {
namespace heir {
namespace mod_arith {

#define GEN_PASS_DEF_ARITHTOMODARITH
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

static ModArithType convertArithType(Type type) {
auto modulus = type.getIntOrFloatBitWidth();
return ModArithType::get(type.getContext(),
mlir::IntegerAttr::get(type, modulus));
}

static Type convertArithLikeType(ShapedType type) {
if (auto arithType = llvm::dyn_cast<IntegerType>(type.getElementType())) {
return type.cloneWith(type.getShape(), convertArithType(arithType));
}
return type;
}

class ArithToModArithTypeConverter : public TypeConverter {
public:
ArithToModArithTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([](IntegerType type) -> ModArithType {
return convertArithType(type);
});
addConversion(
[](ShapedType type) -> Type { return convertArithLikeType(type); });
}
};

struct ConvertConstant : public OpConversionPattern<::mlir::arith::ConstantOp> {
ConvertConstant(mlir::MLIRContext *context)
: OpConversionPattern<::mlir::arith::ConstantOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
::mlir::arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result = b.create<mod_arith::ConstantOp>(ModArithAttr::get(
convertArithType(op.getType()),
cast<IntegerAttr>(op.getValue()).getValue().getSExtValue()));

rewriter.replaceOp(op, result);
return success();
return success();
}
};

template <typename SourceArithOp, typename TargetModArithOp>
struct ConvertBinOp : public OpConversionPattern<SourceArithOp> {
ConvertBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}

using OpConversionPattern<SourceArithOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result =
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
};

struct ArithToModArith : impl::ArithToModArithBase<ArithToModArith> {
using ArithToModArithBase::ArithToModArithBase;

void runOnOperation() override;
};

void ArithToModArith::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
ArithToModArithTypeConverter typeConverter(context);

ConversionTarget target(*context);
target.addLegalDialect<ModArithDialect>();
target.addIllegalDialect<arith::ArithDialect>();

RewritePatternSet patterns(context);
patterns.add<ConvertConstant, ConvertBinOp<arith::AddIOp, AddOp>,
ConvertBinOp<arith::SubIOp, SubOp>,
ConvertBinOp<arith::MulIOp, MulOp>>(typeConverter, context);

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace mod_arith
} // namespace heir
} // namespace mlir
20 changes: 20 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_
#define LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace mod_arith {

#define GEN_PASS_DECL
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

} // namespace mod_arith
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_
22 changes: 22 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_
#define LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_

include "lib/Utils/DRR/Utils.td"
include "lib/Dialect/ModArith/IR/ModArithOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Pass/PassBase.td"

def ArithToModArith : Pass<"arith-to-mod-arith", "ModuleOp"> {
let summary = "Lower standard `arith` to `mod-arith`.";

let description = [{
This pass lowers the `arith` dialect to their `mod-arith` equivalents.
}];

let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::mod_arith::ModArithDialect",
];
}

#endif // LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_
44 changes: 44 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "ArithToModArith",
srcs = ["ArithToModArith.cpp"],
hdrs = [
"ArithToModArith.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Utils/ConversionUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ArithToModArith",
],
"ArithToModArith.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ArithToModArith.td",
deps = [
"@heir//lib/Dialect/ModArith/IR:ops_inc_gen",
"@heir//lib/Dialect/ModArith/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
25 changes: 25 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
add_heir_pass(ArithToModArith PATTERNS)

add_mlir_conversion_library(HEIRArithToModArith
ArithToModArith.cpp

DEPENDS
HEIRArithToModArithIncGen

LINK_LIBS PUBLIC
HEIRModArith

LINK_LIBS PUBLIC

LLVMSupport

MLIRArithDialect
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRSupport
MLIRTransforms
MLIRTransformUtils
)
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Conversions)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ namespace mod_arith {
#define GEN_PASS_DEF_MODARITHTOARITH
#include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.h.inc"

IntegerType convertModArithType(ModArithType type) {
static IntegerType convertModArithType(ModArithType type) {
APInt modulus = type.getModulus().getValue();
return IntegerType::get(type.getContext(), modulus.getBitWidth());
}

Type convertModArithLikeType(ShapedType type) {
static Type convertModArithLikeType(ShapedType type) {
if (auto modArithType = llvm::dyn_cast<ModArithType>(type.getElementType())) {
return type.cloneWith(type.getShape(), convertModArithType(modArithType));
}
Expand All @@ -58,7 +58,7 @@ class ModArithToArithTypeConverter : public TypeConverter {
// needed to represent the result of mod_arith op as an integer
// before applying a remainder operation
template <typename Op>
TypedAttr modulusAttr(Op op, bool mul = false) {
static TypedAttr modulusAttr(Op op, bool mul = false) {
auto type = op.getResult().getType();
auto modArithType = getResultModArithType(op);
APInt modulus = modArithType.getModulus().getValue();
Expand Down
53 changes: 53 additions & 0 deletions lib/Dialect/ModArith/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = ["Passes.h"],
deps = [
":ConvertToMac",
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "ConvertToMac",
srcs = ["ConvertToMac.cpp"],
hdrs = ["ConvertToMac.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ModArith",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"ModArithPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
58 changes: 58 additions & 0 deletions lib/Dialect/ModArith/Transforms/ConvertToMac.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "lib/Dialect/ModArith/Transforms/ConvertToMac.h"

#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "mod-arith-mac"

namespace mlir {
namespace heir {
namespace mod_arith {

#define GEN_PASS_DEF_CONVERTTOMAC
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc"

struct FindMac : public OpRewritePattern<mod_arith::AddOp> {
FindMac(mlir::MLIRContext *context)
: OpRewritePattern<mod_arith::AddOp>(context) {}

LogicalResult matchAndRewrite(mod_arith::AddOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto *parent = op.getLhs().getDefiningOp();

if (!parent || !isa<mod_arith::MulOp>(parent)) {
return failure();
}

auto mul_parent = cast<mod_arith::MulOp>(parent);

auto result =
b.create<MacOp>(mul_parent.getLhs(), mul_parent.getRhs(), op.getRhs());

rewriter.replaceOp(op, result);
rewriter.eraseOp(mul_parent);

return success();
}
};

struct ConvertToMac : impl::ConvertToMacBase<ConvertToMac> {
using ConvertToMacBase::ConvertToMacBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

patterns.add<FindMac>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace mod_arith
} // namespace heir
} // namespace mlir
Loading

0 comments on commit edf3902

Please sign in to comment.