-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Arith to ModArith Conversion pass and Mac Transformer pass
- Loading branch information
Wouter Legiest
committed
Dec 5, 2024
1 parent
4b63d65
commit baca68e
Showing
25 changed files
with
583 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h" | ||
|
||
#include "lib/Dialect/ModArith/IR/ModArithAttributes.h" | ||
#include "lib/Dialect/ModArith/IR/ModArithOps.h" | ||
#include "lib/Dialect/ModArith/IR/ModArithTypes.h" | ||
#include "lib/Utils/ConversionUtils/ConversionUtils.h" | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "arith-to-mod-arith" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace arith { | ||
|
||
#define GEN_PASS_DEF_ARITHTOMODARITH | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" | ||
|
||
static mod_arith::ModArithType convertArithType(Type type) { | ||
auto modulus = type.getIntOrFloatBitWidth(); | ||
return mod_arith::ModArithType::get(type.getContext(), | ||
mlir::IntegerAttr::get(type, modulus)); | ||
} | ||
|
||
static Type convertArithLikeType(ShapedType type) { | ||
if (auto arithType = llvm::dyn_cast<IntegerType>(type.getElementType())) { | ||
return type.cloneWith(type.getShape(), convertArithType(arithType)); | ||
} | ||
return type; | ||
} | ||
|
||
class ArithToModArithTypeConverter : public TypeConverter { | ||
public: | ||
ArithToModArithTypeConverter(MLIRContext *ctx) { | ||
addConversion([](Type type) { return type; }); | ||
addConversion([](IntegerType type) -> mod_arith::ModArithType { | ||
return convertArithType(type); | ||
}); | ||
addConversion( | ||
[](ShapedType type) -> Type { return convertArithLikeType(type); }); | ||
} | ||
}; | ||
|
||
struct ConvertConstant : public OpConversionPattern<mlir::arith::ConstantOp> { | ||
ConvertConstant(mlir::MLIRContext *context) | ||
: OpConversionPattern<mlir::arith::ConstantOp>(context) {} | ||
|
||
using OpConversionPattern::OpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
::mlir::arith::ConstantOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
|
||
auto result = b.create<mod_arith::ConstantOp>(mod_arith::ModArithAttr::get( | ||
convertArithType(op.getType()), | ||
cast<IntegerAttr>(op.getValue()).getValue().getSExtValue())); | ||
|
||
rewriter.replaceOp(op, result); | ||
return success(); | ||
return success(); | ||
} | ||
}; | ||
|
||
template <typename SourceArithOp, typename TargetModArithOp> | ||
struct ConvertBinOp : public OpConversionPattern<SourceArithOp> { | ||
ConvertBinOp(mlir::MLIRContext *context) | ||
: OpConversionPattern<SourceArithOp>(context) {} | ||
|
||
using OpConversionPattern<SourceArithOp>::OpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
SourceArithOp op, typename SourceArithOp::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
|
||
auto result = | ||
b.create<TargetModArithOp>(adaptor.getLhs(), adaptor.getRhs()); | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct ArithToModArith : impl::ArithToModArithBase<ArithToModArith> { | ||
using ArithToModArithBase::ArithToModArithBase; | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
void ArithToModArith::runOnOperation() { | ||
MLIRContext *context = &getContext(); | ||
ModuleOp module = getOperation(); | ||
ArithToModArithTypeConverter typeConverter(context); | ||
|
||
ConversionTarget target(*context); | ||
target.addLegalDialect<mod_arith::ModArithDialect>(); | ||
target.addIllegalDialect<mlir::arith::ArithDialect>(); | ||
|
||
RewritePatternSet patterns(context); | ||
patterns | ||
.add<ConvertConstant, ConvertBinOp<mlir::arith::AddIOp, mod_arith::AddOp>, | ||
ConvertBinOp<mlir::arith::SubIOp, mod_arith::SubOp>, | ||
ConvertBinOp<mlir::arith::MulIOp, mod_arith::MulOp>>(typeConverter, | ||
context); | ||
|
||
addStructuralConversionPatterns(typeConverter, patterns, target); | ||
|
||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
} // namespace arith | ||
} // namespace heir | ||
} // namespace mlir |
20 changes: 20 additions & 0 deletions
20
lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ | ||
#define LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace arith { | ||
|
||
#define GEN_PASS_DECL | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" | ||
|
||
} // namespace arith | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ |
29 changes: 29 additions & 0 deletions
29
lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#ifndef LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ | ||
#define LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ | ||
|
||
include "lib/Utils/DRR/Utils.td" | ||
include "lib/Dialect/ModArith/IR/ModArithOps.td" | ||
include "mlir/Dialect/Arith/IR/ArithOps.td" | ||
include "mlir/Pass/PassBase.td" | ||
|
||
def ArithToModArith : Pass<"arith-to-mod-arith", "ModuleOp"> { | ||
let summary = "Lower standard `arith` to `mod-arith`."; | ||
|
||
let description = [{ | ||
This pass lowers the `arith` dialect to their `mod-arith` equivalents. | ||
|
||
The arith-to-mod-arith pass is required to lower a neural network TOSA | ||
model to a CGGI backend. This pass will transform the operations to the | ||
mod-arith dialect, where the find-mac pass can be used to convert | ||
consecutive multiply addition operations into a single operation. In a | ||
later pass, these large precision MAC operations (typically | ||
64 or 32-bit) will be lowered into small precision (8 or 4b) operations | ||
that can be mapped to CGGI operations. }]; | ||
|
||
let dependentDialects = [ | ||
"mlir::arith::ArithDialect", | ||
"mlir::heir::mod_arith::ModArithDialect", | ||
]; | ||
} | ||
|
||
#endif // LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "ArithToModArith", | ||
srcs = ["ArithToModArith.cpp"], | ||
hdrs = [ | ||
"ArithToModArith.h", | ||
], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:Dialect", | ||
"@heir//lib/Utils/ConversionUtils", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=ArithToModArith", | ||
], | ||
"ArithToModArith.h.inc", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "ArithToModArith.td", | ||
deps = [ | ||
"@heir//lib/Dialect/ModArith/IR:ops_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:td_files", | ||
"@llvm-project//mlir:ArithOpsTdFiles", | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
25 changes: 25 additions & 0 deletions
25
lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
add_heir_pass(ArithToModArith PATTERNS) | ||
|
||
add_mlir_conversion_library(HEIRArithToModArith | ||
ArithToModArith.cpp | ||
|
||
DEPENDS | ||
HEIRArithToModArithIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
HEIRModArith | ||
|
||
LINK_LIBS PUBLIC | ||
|
||
LLVMSupport | ||
|
||
MLIRArithDialect | ||
MLIRDialect | ||
MLIRInferTypeOpInterface | ||
MLIRIR | ||
MLIRMemRefDialect | ||
MLIRPass | ||
MLIRSupport | ||
MLIRTransforms | ||
MLIRTransformUtils | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
add_subdirectory(Conversions) | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "Transforms", | ||
hdrs = ["Passes.h"], | ||
deps = [ | ||
":ConvertToMac", | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:Dialect", | ||
"@llvm-project//mlir:IR", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "ConvertToMac", | ||
srcs = ["ConvertToMac.cpp"], | ||
hdrs = ["ConvertToMac.h"], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ModArith/IR:Dialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=ModArith", | ||
], | ||
"Passes.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"ModArithPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "Passes.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#include "lib/Dialect/ModArith/Transforms/ConvertToMac.h" | ||
|
||
#include "lib/Dialect/ModArith/IR/ModArithOps.h" | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "mod-arith-mac" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace mod_arith { | ||
|
||
#define GEN_PASS_DEF_CONVERTTOMAC | ||
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc" | ||
|
||
struct FindMac : public OpRewritePattern<mod_arith::AddOp> { | ||
FindMac(mlir::MLIRContext *context) | ||
: OpRewritePattern<mod_arith::AddOp>(context) {} | ||
|
||
LogicalResult matchAndRewrite(mod_arith::AddOp op, | ||
PatternRewriter &rewriter) const override { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
|
||
auto *parent = op.getLhs().getDefiningOp(); | ||
|
||
if (!parent || !isa<mod_arith::MulOp>(parent)) { | ||
return failure(); | ||
} | ||
|
||
auto mul_parent = cast<mod_arith::MulOp>(parent); | ||
|
||
auto result = | ||
b.create<MacOp>(mul_parent.getLhs(), mul_parent.getRhs(), op.getRhs()); | ||
|
||
rewriter.replaceOp(op, result); | ||
rewriter.eraseOp(mul_parent); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
struct ConvertToMac : impl::ConvertToMacBase<ConvertToMac> { | ||
using ConvertToMacBase::ConvertToMacBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
|
||
patterns.add<FindMac>(context); | ||
|
||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace mod_arith | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_ | ||
#define LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace mod_arith { | ||
|
||
#define GEN_PASS_DECL_CONVERTTOMAC | ||
#include "lib/Dialect/ModArith/Transforms/Passes.h.inc" | ||
|
||
} // namespace mod_arith | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_ |
Oops, something went wrong.