Skip to content

Commit

Permalink
Arith to ModArith Conversion pass and Mac Transformer pass
Browse files Browse the repository at this point in the history
Wouter Legiest committed Dec 12, 2024
1 parent 854a878 commit 836c764
Showing 25 changed files with 595 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ repos:
rev: "v2.2.5"
hooks:
- id: codespell
args: ["-L", "crate"]
args: ["-L", "crate, fpt"]


# Changes tabs to spaces
124 changes: 124 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h"

#include "lib/Dialect/ModArith/IR/ModArithAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Utils/ConversionUtils/ConversionUtils.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "arith-to-mod-arith"

namespace mlir {
namespace heir {
namespace arith {

#define GEN_PASS_DEF_ARITHTOMODARITH
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

static mod_arith::ModArithType convertArithType(Type type) {
auto modulusBitSize = (long)type.getIntOrFloatBitWidth();
auto modulus = (1L << (modulusBitSize - 1L));
auto newType = mlir::IntegerType::get(type.getContext(), modulusBitSize + 1);

return mod_arith::ModArithType::get(newType.getContext(),
mlir::IntegerAttr::get(newType, modulus));
}

static Type convertArithLikeType(ShapedType type) {
if (auto arithType = llvm::dyn_cast<IntegerType>(type.getElementType())) {
return type.cloneWith(type.getShape(), convertArithType(arithType));
}
return type;
}

class ArithToModArithTypeConverter : public TypeConverter {
public:
ArithToModArithTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([](IntegerType type) -> mod_arith::ModArithType {
return convertArithType(type);
});
addConversion(
[](ShapedType type) -> Type { return convertArithLikeType(type); });
}
};

struct ConvertConstant : public OpConversionPattern<mlir::arith::ConstantOp> {
ConvertConstant(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ConstantOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
::mlir::arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result = b.create<mod_arith::ConstantOp>(mod_arith::ModArithAttr::get(
convertArithType(op.getType()),
cast<IntegerAttr>(op.getValue()).getValue().getSExtValue()));

rewriter.replaceOp(op, result);
return success();
}
};

template <typename SourceArithOp, typename TargetModArithOp>
struct ConvertBinOp : public OpConversionPattern<SourceArithOp> {
ConvertBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}

using OpConversionPattern<SourceArithOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto result =
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
};

struct ArithToModArith : impl::ArithToModArithBase<ArithToModArith> {
using ArithToModArithBase::ArithToModArithBase;

void runOnOperation() override;
};

void ArithToModArith::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
ArithToModArithTypeConverter typeConverter(context);

ConversionTarget target(*context);
target.addLegalDialect<mod_arith::ModArithDialect>();
target.addIllegalDialect<mlir::arith::ArithDialect>();

RewritePatternSet patterns(context);
patterns
.add<ConvertConstant, ConvertBinOp<mlir::arith::AddIOp, mod_arith::AddOp>,
ConvertBinOp<mlir::arith::SubIOp, mod_arith::SubOp>,
ConvertBinOp<mlir::arith::MulIOp, mod_arith::MulOp>>(typeConverter,
context);

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace arith
} // namespace heir
} // namespace mlir
20 changes: 20 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_
#define LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace arith {

#define GEN_PASS_DECL
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc"

} // namespace arith
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_
29 changes: 29 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_
#define LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_

include "lib/Utils/DRR/Utils.td"
include "lib/Dialect/ModArith/IR/ModArithOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Pass/PassBase.td"

def ArithToModArith : Pass<"arith-to-mod-arith", "ModuleOp"> {
let summary = "Lower standard `arith` to `mod-arith`.";

let description = [{
This pass lowers the `arith` dialect to their `mod-arith` equivalents.

The arith-to-mod-arith pass is required to lower a neural network TOSA
model to a CGGI backend. This pass will transform the operations to the
mod-arith dialect, where the find-mac pass can be used to convert
consecutive multiply addition operations into a single operation. In a
later pass, these large precision MAC operations (typically
64 or 32-bit) will be lowered into small precision (8 or 4b) operations
that can be mapped to CGGI operations. }];

let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::mod_arith::ModArithDialect",
];
}

#endif // LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_
44 changes: 44 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "ArithToModArith",
srcs = ["ArithToModArith.cpp"],
hdrs = [
"ArithToModArith.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Utils/ConversionUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ArithToModArith",
],
"ArithToModArith.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ArithToModArith.td",
deps = [
"@heir//lib/Dialect/ModArith/IR:ops_inc_gen",
"@heir//lib/Dialect/ModArith/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
25 changes: 25 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
add_heir_pass(ArithToModArith PATTERNS)

add_mlir_conversion_library(HEIRArithToModArith
ArithToModArith.cpp

DEPENDS
HEIRArithToModArithIncGen

LINK_LIBS PUBLIC
HEIRModArith

LINK_LIBS PUBLIC

LLVMSupport

MLIRArithDialect
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRSupport
MLIRTransforms
MLIRTransformUtils
)
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Conversions)
add_subdirectory(IR)
add_subdirectory(Transforms)
53 changes: 53 additions & 0 deletions lib/Dialect/ModArith/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = ["Passes.h"],
deps = [
":ConvertToMac",
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "ConvertToMac",
srcs = ["ConvertToMac.cpp"],
hdrs = ["ConvertToMac.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ModArith",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"ModArithPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
66 changes: 66 additions & 0 deletions lib/Dialect/ModArith/Transforms/ConvertToMac.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "lib/Dialect/ModArith/Transforms/ConvertToMac.h"

#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "mod-arith-mac"

namespace mlir {
namespace heir {
namespace mod_arith {

#define GEN_PASS_DEF_CONVERTTOMAC
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc"

struct FindMac : public OpRewritePattern<mod_arith::AddOp> {
FindMac(mlir::MLIRContext *context)
: OpRewritePattern<mod_arith::AddOp>(context) {}

LogicalResult matchAndRewrite(mod_arith::AddOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

// Assume that we have a form a x b + rhs
auto parent = op.getLhs().getDefiningOp<mod_arith::MulOp>();
auto addOperand = op.getRhs();

if (!parent) {
auto parentRhs = op.getRhs().getDefiningOp<mod_arith::MulOp>();
if (!parentRhs) {
return failure();
}
// Find we have a form of lhs + a x b
parent = parentRhs;
addOperand = op.getLhs();
}

auto result = b.create<MacOp>(parent.getLhs(), parent.getRhs(), addOperand);

rewriter.replaceOp(op, result);

if (parent.use_empty()) {
rewriter.eraseOp(parent);
}

return success();
}
};

struct ConvertToMac : impl::ConvertToMacBase<ConvertToMac> {
using ConvertToMacBase::ConvertToMacBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

patterns.add<FindMac>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace mod_arith
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Dialect/ModArith/Transforms/ConvertToMac.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_
#define LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace mod_arith {

#define GEN_PASS_DECL_CONVERTTOMAC
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc"

} // namespace mod_arith
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_
Loading

0 comments on commit 836c764

Please sign in to comment.