From 836c764fc4281e6baf6437392a4c3801139a53d6 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Tue, 19 Nov 2024 23:26:13 +0000 Subject: [PATCH] Arith to ModArith Conversion pass and Mac Transformer pass --- .pre-commit-config.yaml | 2 +- .../ArithToModArith/ArithToModArith.cpp | 124 ++++++++++++++++++ .../ArithToModArith/ArithToModArith.h | 20 +++ .../ArithToModArith/ArithToModArith.td | 29 ++++ .../Arith/Conversions/ArithToModArith/BUILD | 44 +++++++ .../ArithToModArith/CMakeLists.txt | 25 ++++ lib/Dialect/ModArith/CMakeLists.txt | 1 + lib/Dialect/ModArith/Transforms/BUILD | 53 ++++++++ .../ModArith/Transforms/ConvertToMac.cpp | 66 ++++++++++ .../ModArith/Transforms/ConvertToMac.h | 17 +++ lib/Dialect/ModArith/Transforms/Passes.h | 18 +++ lib/Dialect/ModArith/Transforms/Passes.td | 16 +++ lib/Dialect/Utils.h | 3 +- .../ArithmeticPipelineRegistration.cpp | 52 +++++++- .../ArithmeticPipelineRegistration.h | 4 + lib/Pipelines/BUILD | 7 + lib/Pipelines/BooleanPipelineRegistration.cpp | 6 +- lib/Utils/ConversionUtils/ConversionUtils.cpp | 2 +- lib/Utils/ConversionUtils/ConversionUtils.h | 3 +- .../Arith/Conversions/ArithToModArith/BUILD | 10 ++ .../ArithToModArith/arith-to-mod-arith.mlir | 55 ++++++++ tests/Dialect/ModArith/Transforms/BUILD | 10 ++ .../Dialect/ModArith/Transforms/find_mac.mlir | 23 ++++ tools/BUILD | 3 + tools/heir-opt.cpp | 14 +- 25 files changed, 595 insertions(+), 12 deletions(-) create mode 100644 lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp create mode 100644 lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h create mode 100644 lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td create mode 100644 lib/Dialect/Arith/Conversions/ArithToModArith/BUILD create mode 100644 lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt create mode 100644 lib/Dialect/ModArith/Transforms/BUILD create mode 100644 lib/Dialect/ModArith/Transforms/ConvertToMac.cpp create mode 100644 lib/Dialect/ModArith/Transforms/ConvertToMac.h create mode 100644 lib/Dialect/ModArith/Transforms/Passes.h create mode 100644 lib/Dialect/ModArith/Transforms/Passes.td create mode 100644 tests/Dialect/Arith/Conversions/ArithToModArith/BUILD create mode 100644 tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir create mode 100644 tests/Dialect/ModArith/Transforms/BUILD create mode 100644 tests/Dialect/ModArith/Transforms/find_mac.mlir diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df021c6ad..79e932234 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: rev: "v2.2.5" hooks: - id: codespell - args: ["-L", "crate"] + args: ["-L", "crate, fpt"] # Changes tabs to spaces diff --git a/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp new file mode 100644 index 000000000..e60b37349 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp @@ -0,0 +1,124 @@ +#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h" + +#include "lib/Dialect/ModArith/IR/ModArithAttributes.h" +#include "lib/Dialect/ModArith/IR/ModArithOps.h" +#include "lib/Dialect/ModArith/IR/ModArithTypes.h" +#include "lib/Utils/ConversionUtils/ConversionUtils.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +#define DEBUG_TYPE "arith-to-mod-arith" + +namespace mlir { +namespace heir { +namespace arith { + +#define GEN_PASS_DEF_ARITHTOMODARITH +#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" + +static mod_arith::ModArithType convertArithType(Type type) { + auto modulusBitSize = (long)type.getIntOrFloatBitWidth(); + auto modulus = (1L << (modulusBitSize - 1L)); + auto newType = mlir::IntegerType::get(type.getContext(), modulusBitSize + 1); + + return mod_arith::ModArithType::get(newType.getContext(), + mlir::IntegerAttr::get(newType, modulus)); +} + +static Type convertArithLikeType(ShapedType type) { + if (auto arithType = llvm::dyn_cast(type.getElementType())) { + return type.cloneWith(type.getShape(), convertArithType(arithType)); + } + return type; +} + +class ArithToModArithTypeConverter : public TypeConverter { + public: + ArithToModArithTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + addConversion([](IntegerType type) -> mod_arith::ModArithType { + return convertArithType(type); + }); + addConversion( + [](ShapedType type) -> Type { return convertArithLikeType(type); }); + } +}; + +struct ConvertConstant : public OpConversionPattern { + ConvertConstant(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ::mlir::arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto result = b.create(mod_arith::ModArithAttr::get( + convertArithType(op.getType()), + cast(op.getValue()).getValue().getSExtValue())); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +template +struct ConvertBinOp : public OpConversionPattern { + ConvertBinOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SourceArithOp op, typename SourceArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto result = + b.create(adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithToModArith : impl::ArithToModArithBase { + using ArithToModArithBase::ArithToModArithBase; + + void runOnOperation() override; +}; + +void ArithToModArith::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + ArithToModArithTypeConverter typeConverter(context); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(context); + patterns + .add, + ConvertBinOp, + ConvertBinOp>(typeConverter, + context); + + addStructuralConversionPatterns(typeConverter, patterns, target); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace arith +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h new file mode 100644 index 000000000..921c3ffd2 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h @@ -0,0 +1,20 @@ +#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ +#define LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace arith { + +#define GEN_PASS_DECL +#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h.inc" + +} // namespace arith +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_MODARITH_TRANSFORMS_ARITHTOMODARITH_H_ diff --git a/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td new file mode 100644 index 000000000..c5ec8d795 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.td @@ -0,0 +1,29 @@ +#ifndef LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ +#define LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ + +include "lib/Utils/DRR/Utils.td" +include "lib/Dialect/ModArith/IR/ModArithOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "mlir/Pass/PassBase.td" + +def ArithToModArith : Pass<"arith-to-mod-arith", "ModuleOp"> { + let summary = "Lower standard `arith` to `mod-arith`."; + + let description = [{ + This pass lowers the `arith` dialect to their `mod-arith` equivalents. + + The arith-to-mod-arith pass is required to lower a neural network TOSA + model to a CGGI backend. This pass will transform the operations to the + mod-arith dialect, where the find-mac pass can be used to convert + consecutive multiply addition operations into a single operation. In a + later pass, these large precision MAC operations (typically + 64 or 32-bit) will be lowered into small precision (8 or 4b) operations + that can be mapped to CGGI operations. }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::heir::mod_arith::ModArithDialect", + ]; +} + +#endif // LIB_DIALECT_MODARITH_CONVERSIONS_ARITHTOMODARITH_ARITHTOMODARITH_TD_ diff --git a/lib/Dialect/Arith/Conversions/ArithToModArith/BUILD b/lib/Dialect/Arith/Conversions/ArithToModArith/BUILD new file mode 100644 index 000000000..c8cba2a39 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToModArith/BUILD @@ -0,0 +1,44 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ArithToModArith", + srcs = ["ArithToModArith.cpp"], + hdrs = [ + "ArithToModArith.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/ModArith/IR:Dialect", + "@heir//lib/Utils/ConversionUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ArithToModArith", + ], + "ArithToModArith.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ArithToModArith.td", + deps = [ + "@heir//lib/Dialect/ModArith/IR:ops_inc_gen", + "@heir//lib/Dialect/ModArith/IR:td_files", + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt b/lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt new file mode 100644 index 000000000..2482b3e3a --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToModArith/CMakeLists.txt @@ -0,0 +1,25 @@ +add_heir_pass(ArithToModArith PATTERNS) + +add_mlir_conversion_library(HEIRArithToModArith + ArithToModArith.cpp + + DEPENDS + HEIRArithToModArithIncGen + + LINK_LIBS PUBLIC + HEIRModArith + + LINK_LIBS PUBLIC + + LLVMSupport + + MLIRArithDialect + MLIRDialect + MLIRInferTypeOpInterface + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRSupport + MLIRTransforms + MLIRTransformUtils +) diff --git a/lib/Dialect/ModArith/CMakeLists.txt b/lib/Dialect/ModArith/CMakeLists.txt index 626f79e80..74d0af5b3 100644 --- a/lib/Dialect/ModArith/CMakeLists.txt +++ b/lib/Dialect/ModArith/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Conversions) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/Dialect/ModArith/Transforms/BUILD b/lib/Dialect/ModArith/Transforms/BUILD new file mode 100644 index 000000000..2123c86d1 --- /dev/null +++ b/lib/Dialect/ModArith/Transforms/BUILD @@ -0,0 +1,53 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Transforms", + hdrs = ["Passes.h"], + deps = [ + ":ConvertToMac", + ":pass_inc_gen", + "@heir//lib/Dialect/ModArith/IR:Dialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "ConvertToMac", + srcs = ["ConvertToMac.cpp"], + hdrs = ["ConvertToMac.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/ModArith/IR:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ModArith", + ], + "Passes.h.inc", + ), + ( + ["-gen-pass-doc"], + "ModArithPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/lib/Dialect/ModArith/Transforms/ConvertToMac.cpp b/lib/Dialect/ModArith/Transforms/ConvertToMac.cpp new file mode 100644 index 000000000..f1c304dff --- /dev/null +++ b/lib/Dialect/ModArith/Transforms/ConvertToMac.cpp @@ -0,0 +1,66 @@ +#include "lib/Dialect/ModArith/Transforms/ConvertToMac.h" + +#include "lib/Dialect/ModArith/IR/ModArithOps.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +#define DEBUG_TYPE "mod-arith-mac" + +namespace mlir { +namespace heir { +namespace mod_arith { + +#define GEN_PASS_DEF_CONVERTTOMAC +#include "lib/Dialect/ModArith/Transforms/Passes.h.inc" + +struct FindMac : public OpRewritePattern { + FindMac(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(mod_arith::AddOp op, + PatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + // Assume that we have a form a x b + rhs + auto parent = op.getLhs().getDefiningOp(); + auto addOperand = op.getRhs(); + + if (!parent) { + auto parentRhs = op.getRhs().getDefiningOp(); + if (!parentRhs) { + return failure(); + } + // Find we have a form of lhs + a x b + parent = parentRhs; + addOperand = op.getLhs(); + } + + auto result = b.create(parent.getLhs(), parent.getRhs(), addOperand); + + rewriter.replaceOp(op, result); + + if (parent.use_empty()) { + rewriter.eraseOp(parent); + } + + return success(); + } +}; + +struct ConvertToMac : impl::ConvertToMacBase { + using ConvertToMacBase::ConvertToMacBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + patterns.add(context); + + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace mod_arith +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/ModArith/Transforms/ConvertToMac.h b/lib/Dialect/ModArith/Transforms/ConvertToMac.h new file mode 100644 index 000000000..693c4792e --- /dev/null +++ b/lib/Dialect/ModArith/Transforms/ConvertToMac.h @@ -0,0 +1,17 @@ +#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_ +#define LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace mod_arith { + +#define GEN_PASS_DECL_CONVERTTOMAC +#include "lib/Dialect/ModArith/Transforms/Passes.h.inc" + +} // namespace mod_arith +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_MODARITH_TRANSFORMS_CONVERTTOMAC_H_ diff --git a/lib/Dialect/ModArith/Transforms/Passes.h b/lib/Dialect/ModArith/Transforms/Passes.h new file mode 100644 index 000000000..a7ad4a982 --- /dev/null +++ b/lib/Dialect/ModArith/Transforms/Passes.h @@ -0,0 +1,18 @@ +#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_PASSES_H_ +#define LIB_DIALECT_MODARITH_TRANSFORMS_PASSES_H_ + +#include "lib/Dialect/ModArith/IR/ModArithDialect.h" +#include "lib/Dialect/ModArith/Transforms/ConvertToMac.h" + +namespace mlir { +namespace heir { +namespace mod_arith { + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/ModArith/Transforms/Passes.h.inc" + +} // namespace mod_arith +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_MODARITH_TRANSFORMS_PASSES_H_ diff --git a/lib/Dialect/ModArith/Transforms/Passes.td b/lib/Dialect/ModArith/Transforms/Passes.td new file mode 100644 index 000000000..01fdc8e9c --- /dev/null +++ b/lib/Dialect/ModArith/Transforms/Passes.td @@ -0,0 +1,16 @@ +#ifndef LIB_DIALECT_MODARITH_TRANSFORMS_PASSES_TD_ +#define LIB_DIALECT_MODARITH_TRANSFORMS_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def ConvertToMac : Pass<"mod-arith-to-mac"> { + let summary = "Finds consecutive ModArith mul and add operations and converts them to a Mac operation"; + let description = [{ + Walks over the programs to find Add operations, it checks if the any operands + originates from a mul operation. If so, it converts the Add operation to a + Mac operation and removes the mul operation. + }]; + let dependentDialects = ["mlir::heir::mod_arith::ModArithDialect"]; +} + +#endif // LIB_DIALECT_MODARITH_TRANSFORMS_PASSES_TD_ diff --git a/lib/Dialect/Utils.h b/lib/Dialect/Utils.h index 8c06765b4..72623fcde 100644 --- a/lib/Dialect/Utils.h +++ b/lib/Dialect/Utils.h @@ -28,7 +28,8 @@ FailureOr get1DExtractionIndex(Op op) { // -sccp before this pass to apply folding rules (use -sccp if you need to // fold constants through control flow). Value insertIndex = *insertIndices.begin(); - auto insertIndexConstOp = insertIndex.getDefiningOp(); + auto insertIndexConstOp = + insertIndex.getDefiningOp(); if (!insertIndexConstOp) return failure(); auto insertOffsetAttr = diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index f77b43019..29406ef3e 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -16,19 +16,28 @@ #include "lib/Dialect/TensorExt/Transforms/RotateAndReduce.h" #include "lib/Pipelines/PipelineRegistration.h" #include "lib/Transforms/ApplyFolders/ApplyFolders.h" +#include "lib/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" #include "lib/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h" +#include "lib/Transforms/MemrefToArith/MemrefToArith.h" #include "lib/Transforms/OperationBalancer/OperationBalancer.h" #include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h" #include "lib/Transforms/SecretInsertMgmt/Passes.h" #include "lib/Transforms/Secretize/Passes.h" -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "lib/Transforms/UnusedMemRef/UnusedMemRef.h" +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/Passes.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Linalg/Passes.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project +using mlir::func::FuncOp; + namespace mlir::heir { void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) { @@ -91,6 +100,41 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) { pm.addPass(createOperationBalancer()); } +void tosaToArithPipelineBuilder(OpPassManager &pm) { + // TOSA to linalg + ::mlir::heir::tosaToLinalg(pm); + + // Bufferize + ::mlir::heir::oneShotBufferize(pm); + + // Affine + pm.addNestedPass(createConvertLinalgToAffineLoopsPass()); + pm.addNestedPass(memref::createExpandStridedMetadataPass()); + pm.addNestedPass(affine::createAffineExpandIndexOpsPass()); + pm.addNestedPass(memref::createExpandOpsPass()); + pm.addPass(createExpandCopyPass()); + pm.addNestedPass(affine::createSimplifyAffineStructuresPass()); + pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); + pm.addPass(memref::createFoldMemRefAliasOpsPass()); + + // Affine loop optimizations + pm.addNestedPass( + affine::createLoopFusionPass(0, 0, true, affine::FusionMode::Greedy)); + pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); + pm.addPass(createForwardStoreToLoad()); + pm.addPass(affine::createAffineParallelizePass()); + pm.addPass(createFullLoopUnroll()); + pm.addPass(createForwardStoreToLoad()); + pm.addNestedPass(createRemoveUnusedMemRef()); + + // Cleanup + pm.addPass(createMemrefGlobalReplacePass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createSCCPPass()); + pm.addPass(createCSEPass()); + pm.addPass(createSymbolDCEPass()); +} + void mlirToRLWEPipeline(OpPassManager &pm, const MlirToRLWEPipelineOptions &options, const RLWEScheme scheme) { @@ -181,4 +225,10 @@ RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(const RLWEScheme scheme) { }; } +void registerTosaToArithPipeline() { + PassPipelineRegistration<>( + "tosa-to-arith", "Arithmetic modules to arith tfhe-rs pipeline.", + [](OpPassManager &pm) { tosaToArithPipelineBuilder(pm); }); +} + } // namespace mlir::heir diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.h b/lib/Pipelines/ArithmeticPipelineRegistration.h index e1b4008ab..dc6b33b13 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.h +++ b/lib/Pipelines/ArithmeticPipelineRegistration.h @@ -48,6 +48,8 @@ struct MlirToRLWEPipelineOptions using RLWEPipelineBuilder = std::function; +void tosaToArithPipelineBuilder(OpPassManager &pm); + void mlirToRLWEPipeline(OpPassManager &pm, const MlirToRLWEPipelineOptions &options, RLWEScheme scheme); @@ -58,6 +60,8 @@ RLWEPipelineBuilder mlirToRLWEPipelineBuilder(RLWEScheme scheme); RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(RLWEScheme scheme); +void registerTosaToArithPipeline(); + } // namespace mlir::heir #endif // LIB_PIPELINES_ARITHMETICPIPELINEREGISTRATION_H_ diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index 07c7c5476..2905b7685 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -101,6 +101,7 @@ cc_library( "@heir//lib/Dialect/TensorExt/Transforms:InsertRotate", "@heir//lib/Dialect/TensorExt/Transforms:RotateAndReduce", "@heir//lib/Transforms/ApplyFolders", + "@heir//lib/Transforms/ForwardStoreToLoad", "@heir//lib/Transforms/FullLoopUnroll", "@heir//lib/Transforms/LinalgCanonicalizations", "@heir//lib/Transforms/MemrefToArith:MemrefToArithRegistration", @@ -108,7 +109,13 @@ cc_library( "@heir//lib/Transforms/OptimizeRelinearization", "@heir//lib/Transforms/SecretInsertMgmt", "@heir//lib/Transforms/Secretize", + "@heir//lib/Transforms/UnusedMemRef", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineTransforms", + "@llvm-project//mlir:ArithTransforms", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", diff --git a/lib/Pipelines/BooleanPipelineRegistration.cpp b/lib/Pipelines/BooleanPipelineRegistration.cpp index b9e26145d..02811b9fe 100644 --- a/lib/Pipelines/BooleanPipelineRegistration.cpp +++ b/lib/Pipelines/BooleanPipelineRegistration.cpp @@ -93,9 +93,9 @@ void tosaToCGGIPipelineBuilder(OpPassManager &pm, pm.addPass(createCanonicalizerPass()); // Booleanize and Yosys Optimize - pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, options.abcFast, - options.unrollFactor, /*useSubmodules=*/true, - abcBooleanGates ? Mode::Boolean : Mode::LUT)); + pm.addPass(createYosysOptimizer( + yosysFilesPath, abcPath, options.abcFast, options.unrollFactor, + /*useSubmodules=*/true, abcBooleanGates ? Mode::Boolean : Mode::LUT)); // Cleanup pm.addPass(mlir::createCSEPass()); diff --git a/lib/Utils/ConversionUtils/ConversionUtils.cpp b/lib/Utils/ConversionUtils/ConversionUtils.cpp index f4bc891c2..2a1670719 100644 --- a/lib/Utils/ConversionUtils/ConversionUtils.cpp +++ b/lib/Utils/ConversionUtils/ConversionUtils.cpp @@ -193,7 +193,7 @@ struct ConvertFromElements newShape.append(shape.begin(), shape.end()); // Create a dense constant for targetShape - auto shapeOp = rewriter.create( + auto shapeOp = rewriter.create( op.getLoc(), RankedTensorType::get(newShape.size(), rewriter.getIndexType()), rewriter.getIndexTensorAttr(newShape)); diff --git a/lib/Utils/ConversionUtils/ConversionUtils.h b/lib/Utils/ConversionUtils/ConversionUtils.h index 5e57ce554..9e40930bd 100644 --- a/lib/Utils/ConversionUtils/ConversionUtils.h +++ b/lib/Utils/ConversionUtils/ConversionUtils.h @@ -331,7 +331,8 @@ class SecretGenericOpRotateConversion ConversionPatternRewriter &rewriter) const override { // Check that the offset is a constant. auto offset = inputs[1]; - auto constantOffset = dyn_cast(offset.getDefiningOp()); + auto constantOffset = + dyn_cast(offset.getDefiningOp()); if (!constantOffset) { op.emitError("expected constant offset for rotate"); } diff --git a/tests/Dialect/Arith/Conversions/ArithToModArith/BUILD b/tests/Dialect/Arith/Conversions/ArithToModArith/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/Dialect/Arith/Conversions/ArithToModArith/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir b/tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir new file mode 100644 index 000000000..2d395bbdb --- /dev/null +++ b/tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir @@ -0,0 +1,55 @@ +// RUN: heir-opt --arith-to-mod-arith --split-input-file %s | FileCheck %s --enable-var-scope + +// CHECK-LABEL: @test_lower_add +// CHECK-SAME: (%[[LHS:.*]]: !Z2147483648_i33_, %[[RHS:.*]]: !Z2147483648_i33_) -> [[T:.*]] { +func.func @test_lower_add(%lhs : i32, %rhs : i32) -> i32 { + // CHECK: %[[ADD:.*]] = mod_arith.add %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.addi %lhs, %rhs : i32 + return %res : i32 +} + +// CHECK-LABEL: @test_lower_add_vec +// CHECK-SAME: (%[[LHS:.*]]: tensor<4x!Z2147483648_i33_>, %[[RHS:.*]]: tensor<4x!Z2147483648_i33_>) -> [[T:.*]] { +func.func @test_lower_add_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> { + // CHECK: %[[ADD:.*]] = mod_arith.add %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.addi %lhs, %rhs : tensor<4xi32> + return %res : tensor<4xi32> +} + +// CHECK-LABEL: @test_lower_sub_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_sub_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> { + // CHECK: %[[ADD:.*]] = mod_arith.sub %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.subi %lhs, %rhs : tensor<4xi32> + return %res : tensor<4xi32> +} + +// CHECK-LABEL: @test_lower_sub +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_sub(%lhs : i32, %rhs : i32) -> i32 { + // CHECK: %[[ADD:.*]] = mod_arith.sub %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.subi %lhs, %rhs : i32 + return %res : i32 +} + +// CHECK-LABEL: @test_lower_mul +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_mul(%lhs : i32, %rhs : i32) -> i32 { + // CHECK: %[[ADD:.*]] = mod_arith.mul %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.muli %lhs, %rhs : i32 + return %res : i32 +} + +// CHECK-LABEL: @test_lower_mul_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_mul_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> { + // CHECK: %[[ADD:.*]] = mod_arith.mul %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.muli %lhs, %rhs : tensor<4xi32> + return %res : tensor<4xi32> +} diff --git a/tests/Dialect/ModArith/Transforms/BUILD b/tests/Dialect/ModArith/Transforms/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/Dialect/ModArith/Transforms/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/ModArith/Transforms/find_mac.mlir b/tests/Dialect/ModArith/Transforms/find_mac.mlir new file mode 100644 index 000000000..8cfeb7c70 --- /dev/null +++ b/tests/Dialect/ModArith/Transforms/find_mac.mlir @@ -0,0 +1,23 @@ +// RUN: heir-opt --arith-to-mod-arith --mod-arith-to-mac %s | FileCheck %s --enable-var-scope + +// CHECK-LABEL: @double_mac +// CHECK-SAME: (%[[ARG:.*]]: !Z32768_i17_) -> [[T:.*]] { +func.func @double_mac(%arg0: i16) -> i16 { + %c1 = arith.constant 1: i16 + %c2 = arith.constant 2 : i16 + %c3 = arith.constant 3: i16 + + // CHECK: %[[MAC1:.*]] = mod_arith.mac %[[ARG]], %{{.*}}, %{{.*}} : [[T]] + // CHECK: %[[ADD:.*]] = mod_arith.add %[[MAC1]], %{{.*}} : [[T]] + // CHECK: %[[MAC2:.*]] = mod_arith.mac %[[ADD]], %{{.*}}, %{{.*}} : [[T]] + // CHECK: return %[[MAC2]] : [[T]] + + %3 = arith.muli %arg0, %c2 : i16 + %4 = arith.addi %3, %c1 : i16 + %7 = arith.addi %4, %c3 : i16 + + %5 = arith.muli %7, %c3 : i16 + %6 = arith.addi %c2, %5 : i16 + + return %6 : i16 +} diff --git a/tools/BUILD b/tools/BUILD index 11c2bc3c4..715a78980 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -32,6 +32,7 @@ cc_binary( }), includes = ["include"], deps = [ + "@heir//lib/Dialect/Arith/Conversions/ArithToModArith", "@heir//lib/Dialect/BGV/Conversions/BGVToLWE", "@heir//lib/Dialect/BGV/Conversions/BGVToLattigo", "@heir//lib/Dialect/BGV/Conversions/BGVToOpenfhe", @@ -55,6 +56,8 @@ cc_binary( "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/ModArith/Conversions/ModArithToArith", "@heir//lib/Dialect/ModArith/IR:Dialect", + "@heir//lib/Dialect/ModArith/Transforms", + "@heir//lib/Dialect/ModArith/Transforms:ConvertToMac", "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Dialect/Openfhe/Transforms", "@heir//lib/Dialect/Openfhe/Transforms:ConfigureCryptoContext", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index cc127d984..e93d6123f 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -3,6 +3,7 @@ #include #include +#include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h" #include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h" #include "lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.h" #include "lib/Dialect/BGV/Conversions/BGVToOpenfhe/BGVToOpenfhe.h" @@ -24,6 +25,7 @@ #include "lib/Dialect/Mgmt/IR/MgmtDialect.h" #include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.h" #include "lib/Dialect/ModArith/IR/ModArithDialect.h" +#include "lib/Dialect/ModArith/Transforms/Passes.h" #include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" #include "lib/Dialect/Openfhe/Transforms/Passes.h" #include "lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.h" @@ -146,7 +148,7 @@ int main(int argc, char **argv) { registry.insert<::mlir::linalg::LinalgDialect>(); registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -162,7 +164,7 @@ int main(int argc, char **argv) { // Upstream passes used by HEIR // Converting to LLVM - arith::registerConvertArithToLLVMInterface(registry); + mlir::arith::registerConvertArithToLLVMInterface(registry); cf::registerConvertControlFlowToLLVMInterface(registry); func::registerAllExtensions(registry); index::registerConvertIndexToLLVMInterface(registry); @@ -202,8 +204,8 @@ int main(int argc, char **argv) { // Bufferization and external models bufferization::registerBufferizationPasses(); - arith::registerBufferizableOpInterfaceExternalModels(registry); - arith::registerBufferDeallocationOpInterfaceExternalModels(registry); + mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); + mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); cf::registerBufferizableOpInterfaceExternalModels(registry); @@ -263,8 +265,12 @@ int main(int argc, char **argv) { // Register internal pipeline #endif + registerTosaToArithPipeline(); + // Dialect conversion passes in HEIR mod_arith::registerModArithToArithPasses(); + heir::arith::registerArithToModArithPasses(); + mod_arith::registerConvertToMacPass(); bgv::registerBGVToLWEPasses(); bgv::registerBGVToLattigoPasses(); bgv::registerBGVToOpenfhePasses();