Skip to content

Commit

Permalink
Arith to ModArith Conversion pass and Mac Transformer pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Wouter Legiest committed Dec 9, 2024
1 parent 12efce7 commit bc8feac
Show file tree
Hide file tree
Showing 25 changed files with 587 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ repos:
rev: "v2.2.5"
hooks:
- id: codespell
args: ["-L", "crate"]
args: ["-L", "crate, fpt"]


# Changes tabs to spaces
Expand Down
122 changes: 122 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h"

#include "lib/Dialect/ModArith/IR/ModArithAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Utils/ConversionUtils/ConversionUtils.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "arith-to-mod-arith"

namespace mlir {
namespace heir {
namespace arith {

#define GEN_PASS_DEF_ARITHTOMODARITH
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

static mod_arith::ModArithType convertArithType(Type type) {
auto modulusBitSize = type.getIntOrFloatBitWidth();
auto modulus = (int)(1 << (modulusBitSize - 1)) - 1;
return mod_arith::ModArithType::get(type.getContext(),
mlir::IntegerAttr::get(type, modulus));
}

static Type convertArithLikeType(ShapedType type) {
if (auto arithType = llvm::dyn_cast<IntegerType>(type.getElementType())) {
return type.cloneWith(type.getShape(), convertArithType(arithType));
}
return type;
}

class ArithToModArithTypeConverter : public TypeConverter {
public:
ArithToModArithTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([](IntegerType type) -> mod_arith::ModArithType {
return convertArithType(type);
});
addConversion(
[](ShapedType type) -> Type { return convertArithLikeType(type); });
}
};

struct ConvertConstant : public OpConversionPattern<mlir::arith::ConstantOp> {
ConvertConstant(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ConstantOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
::mlir::arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result = b.create<mod_arith::ConstantOp>(mod_arith::ModArithAttr::get(
convertArithType(op.getType()),
cast<IntegerAttr>(op.getValue()).getValue().getSExtValue()));

rewriter.replaceOp(op, result);
return success();
}
};

template <typename SourceArithOp, typename TargetModArithOp>
struct ConvertBinOp : public OpConversionPattern<SourceArithOp> {
ConvertBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}

using OpConversionPattern<SourceArithOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result =
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
};

struct ArithToModArith : impl::ArithToModArithBase<ArithToModArith> {
using ArithToModArithBase::ArithToModArithBase;

void runOnOperation() override;
};

void ArithToModArith::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
ArithToModArithTypeConverter typeConverter(context);

ConversionTarget target(*context);
target.addLegalDialect<mod_arith::ModArithDialect>();
target.addIllegalDialect<mlir::arith::ArithDialect>();

RewritePatternSet patterns(context);
patterns
.add<ConvertConstant, ConvertBinOp<mlir::arith::AddIOp, mod_arith::AddOp>,
ConvertBinOp<mlir::arith::SubIOp, mod_arith::SubOp>,
ConvertBinOp<mlir::arith::MulIOp, mod_arith::MulOp>>(typeConverter,
context);

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace arith
} // namespace heir
} // namespace mlir
20 changes: 20 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_
#define LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace arith {

#define GEN_PASS_DECL
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

} // namespace arith
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_
29 changes: 29 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_
#define LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_

include "lib/Utils/DRR/Utils.td"
include "lib/Dialect/ModArith/IR/ModArithOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Pass/PassBase.td"

def ArithToModArith : Pass<"arith-to-mod-arith", "ModuleOp"> {
let summary = "Lower standard `arith` to `mod-arith`.";

let description = [{
This pass lowers the `arith` dialect to their `mod-arith` equivalents.

The arith-to-mod-arith pass is required to lower a neural network TOSA
model to a CGGI backend. This pass will transform the operations to the
mod-arith dialect, where the find-mac pass can be used to convert
consecutive multiply addition operations into a single operation. In a
later pass, these large precision MAC operations (typically
64 or 32-bit) will be lowered into small precision (8 or 4b) operations
that can be mapped to CGGI operations. }];

let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::mod_arith::ModArithDialect",
];
}

#endif // LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_
44 changes: 44 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "ArithToModArith",
srcs = ["ArithToModArith.cpp"],
hdrs = [
"ArithToModArith.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Utils/ConversionUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ArithToModArith",
],
"ArithToModArith.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ArithToModArith.td",
deps = [
"@heir//lib/Dialect/ModArith/IR:ops_inc_gen",
"@heir//lib/Dialect/ModArith/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
25 changes: 25 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
add_heir_pass(ArithToModArith PATTERNS)

add_mlir_conversion_library(HEIRArithToModArith
ArithToModArith.cpp

DEPENDS
HEIRArithToModArithIncGen

LINK_LIBS PUBLIC
HEIRModArith

LINK_LIBS PUBLIC

LLVMSupport

MLIRArithDialect
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRSupport
MLIRTransforms
MLIRTransformUtils
)
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Conversions)
add_subdirectory(IR)
add_subdirectory(Transforms)
53 changes: 53 additions & 0 deletions lib/Dialect/ModArith/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = ["Passes.h"],
deps = [
":ConvertToMac",
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "ConvertToMac",
srcs = ["ConvertToMac.cpp"],
hdrs = ["ConvertToMac.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ModArith",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"ModArithPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
62 changes: 62 additions & 0 deletions lib/Dialect/ModArith/Transforms/ConvertToMac.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "lib/Dialect/ModArith/Transforms/ConvertToMac.h"

#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "mod-arith-mac"

namespace mlir {
namespace heir {
namespace mod_arith {

#define GEN_PASS_DEF_CONVERTTOMAC
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc"

struct FindMac : public OpRewritePattern<mod_arith::AddOp> {
FindMac(mlir::MLIRContext *context)
: OpRewritePattern<mod_arith::AddOp>(context) {}

LogicalResult matchAndRewrite(mod_arith::AddOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto parent = op.getLhs().getDefiningOp<mod_arith::MulOp>();
if (!parent) {
auto parentRhs = op.getRhs().getDefiningOp<mod_arith::MulOp>();
if (!parentRhs) {
return failure();
}
parent = parentRhs;
}

auto result =
b.create<MacOp>(parent.getLhs(), parent.getRhs(), op.getRhs());

rewriter.replaceOp(op, result);

if (parent.use_empty()) {
rewriter.eraseOp(parent);
}

return success();
}
};

struct ConvertToMac : impl::ConvertToMacBase<ConvertToMac> {
using ConvertToMacBase::ConvertToMacBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

patterns.add<FindMac>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace mod_arith
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Dialect/ModArith/Transforms/ConvertToMac.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_
#define LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace mod_arith {

#define GEN_PASS_DECL_CONVERTTOMAC
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc"

} // namespace mod_arith
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_
Loading

0 comments on commit bc8feac

Please sign in to comment.