diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp new file mode 100644 index 0000000000..5e7c7ea5d8 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp @@ -0,0 +1,415 @@ +#include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h" + +#include + +#include + +#include "lib/Dialect/CGGI/IR/CGGIDialect.h" +#include "lib/Dialect/CGGI/IR/CGGIOps.h" +#include "lib/Dialect/LWE/IR/LWEOps.h" +#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Utils/ConversionUtils.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::heir::arith { + +#define GEN_PASS_DEF_ARITHTOCGGIQUART +#include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h.inc" + +static constexpr unsigned maxIntWidth = 16; + +// ToDo: General funcntion: build trivial Op +// Get maxIntWidth Type + +static lwe::LWECiphertextType convertArithToCGGIType(IntegerType type, + MLIRContext *ctx) { + return lwe::LWECiphertextType::get(ctx, + lwe::UnspecifiedBitFieldEncodingAttr::get( + ctx, type.getIntOrFloatBitWidth()), + lwe::LWEParamsAttr()); +} + +static std::optional convertArithToCGGIQuartType(IntegerType type, + MLIRContext *ctx) { + auto lweType = lwe::LWECiphertextType::get( + ctx, lwe::UnspecifiedBitFieldEncodingAttr::get(ctx, maxIntWidth), + lwe::LWEParamsAttr()); + + float width = type.getWidth(); + float realWidth = maxIntWidth >> 1; + + uint8_t nbChunks = ceil(width / realWidth); + + if (width > 64) return std::nullopt; + + return RankedTensorType::get({nbChunks}, lweType); +} + +static std::optional convertArithLikeToCGGIQuartType(ShapedType type, + MLIRContext *ctx) { + if (auto arithType = llvm::dyn_cast(type.getElementType())) { + float width = arithType.getWidth(); + float realWidth = maxIntWidth >> 1; + + uint8_t nbChunks = ceil(width / realWidth); + + if (width > 64) return std::nullopt; + + if (arithType.getIntOrFloatBitWidth() == maxIntWidth) + return convertArithToCGGIQuartType(arithType, ctx); + + auto newShape = to_vector(type.getShape()); + newShape.push_back(nbChunks); + + if (isa(type)) { + return RankedTensorType::get( + newShape, IntegerType::get(type.getContext(), maxIntWidth)); + } + + if (isa(type)) { + return MemRefType::get(newShape, + IntegerType::get(type.getContext(), maxIntWidth)); + } + } + return type; +} + +class ArithToCGGIQuartTypeConverter : public TypeConverter { + public: + ArithToCGGIQuartTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + + // Convert Integer types to LWE ciphertext types + addConversion([ctx](IntegerType type) -> std::optional { + return convertArithToCGGIQuartType(type, ctx); + }); + + addConversion([ctx](ShapedType type) -> std::optional { + return convertArithLikeToCGGIQuartType(type, ctx); + }); + } +}; + +static Value createTrivialOpMaxWidth(ImplicitLocOpBuilder b, int value) { + auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth >> 1); + auto intAttr = b.getIntegerAttr(maxWideIntType, value); + + auto encoding = + lwe::UnspecifiedBitFieldEncodingAttr::get(b.getContext(), maxIntWidth); + auto lweType = lwe::LWECiphertextType::get(b.getContext(), encoding, + lwe::LWEParamsAttr()); + + return b.create(lweType, intAttr); +} + +/// Extracts the `input` tensor slice with elements at the last dimension offset +/// by `lastOffset`. Returns a value of tensor type with the last dimension +/// reduced to x1 or fully scalarized, e.g.: +/// - tensor<2xi16> --> i16 +static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, + Location loc, Value input, + int64_t lastOffset) { + ArrayRef shape = cast(input.getType()).getShape(); + assert(lastOffset < shape.back() && "Offset out of bounds"); + + // Create index element + auto intAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), lastOffset); + auto constantOp = rewriter.create(loc, intAttr); + SmallVector indices; + indices.push_back(constantOp.getResult()); + + // Scalarize the result in case of 1D tensors. + if (shape.size() == 1) { + return rewriter.create(loc, input, indices); + } + + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + offsets.back() = rewriter.getIndexAttr(lastOffset); + SmallVector sizes(shape.size()); + sizes.back() = rewriter.getIndexAttr(1); + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + + return rewriter.create(loc, input, offsets, sizes, + strides); +} + +/// Extracts four tensor slices from the `input` whose type is `tensor<...x4T>`, +/// with the first element at offset 0, second element at offset 1 and so on. +static SmallVector extractLastDimHalves( + ConversionPatternRewriter &rewriter, Location loc, Value input) { + auto tenShape = cast(input.getType()).getShape(); + auto nbChunks = tenShape.back(); + SmallVector newTrivialOps; + + for (int i = 0; i < nbChunks; ++i) { + newTrivialOps.push_back(extractLastDimSlice(rewriter, loc, input, i)); + } + + return newTrivialOps; +}; + +static Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, + Type type, int64_t value) { + unsigned elementBitWidth = 0; + if (auto lweTy = dyn_cast(type)) + elementBitWidth = + cast(lweTy.getEncoding()) + .getCleartextBitwidth(); + else + elementBitWidth = maxIntWidth; + + auto apValue = APInt(elementBitWidth, value); + + auto maxWideIntType = + IntegerType::get(builder.getContext(), maxIntWidth >> 1); + auto intAttr = builder.getIntegerAttr(maxWideIntType, value); + + return builder.create(loc, type, intAttr); +} + +static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, + Location loc, Value source, Value dest, + int64_t lastOffset) { + ArrayRef shape = cast(dest.getType()).getShape(); + assert(lastOffset < shape.back() && "Offset out of bounds"); + + // // Handle scalar source. + auto intAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), lastOffset); + auto constantOp = rewriter.create(loc, intAttr); + SmallVector indices; + indices.push_back(constantOp.getResult()); + + return rewriter.create(loc, source, dest, indices); +} + +/// Constructs a new tensor of type `resultType` by creating a series of +/// insertions of `resultComponents`, each at the next offset of the last tensor +/// dimension. +/// When all `resultComponents` are scalars, the result type is `tensor`; +/// when `resultComponents` are `tensor<...x1xT>`s, the result type is +/// `tensor<...xNxT>`, where `N` is the number of `resultComponents`. +static Value constructResultTensor(ConversionPatternRewriter &rewriter, + Location loc, RankedTensorType resultType, + ValueRange resultComponents) { + Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0); + for (auto [i, component] : llvm::enumerate(resultComponents)) + resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i); + + return resultVec; +} + +struct ConvertQuartConstantOp + : public OpConversionPattern { + ConvertQuartConstantOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (isa(op.getValue().getType())) { + return failure(); + } + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + Type oldType = op.getType(); + auto newType = getTypeConverter()->convertType(oldType); + auto acutalBitWidth = maxIntWidth >> 1; + + if (!newType) + return rewriter.notifyMatchFailure( + op, llvm::formatv("unsupported type: {0}", op.getType())); + + Attribute oldValue = op.getValueAttr(); + auto tenShape = newType.getShape(); + auto nbChunks = tenShape.back(); + SmallVector newTrivialOps; + + if (auto intAttr = dyn_cast(oldValue)) { + for (uint8_t i = 0; i < nbChunks; i++) { + APInt intChunck = + intAttr.getValue().extractBits(acutalBitWidth, i * acutalBitWidth); + + auto encrypt = createTrivialOpMaxWidth(b, intChunck.getSExtValue()); + newTrivialOps.push_back(encrypt); + } + + Value resultVec = + constructResultTensor(rewriter, op.getLoc(), newType, newTrivialOps); + rewriter.replaceOp(op, resultVec); + + return success(); + } + } +}; + +template +struct ConvertQuartExt final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // Since each type inside the program is a tensor with 4 elements, we can + // simply return the input tensor as the result. The generated code will later + // be removed by the CSE pass. + + LogicalResult matchAndRewrite( + ArithExtOp op, typename ArithExtOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + + auto newResultTy = cast( + convertArithToCGGIQuartType(cast(op.getResult().getType()), + op.getContext()) + .value()); + auto newInTy = cast( + convertArithToCGGIQuartType(cast(op.getIn().getType()), + op.getContext()) + .value()); + + auto resultChunks = newResultTy.getShape().back(); + auto inChunks = newInTy.getShape().back(); + + if (resultChunks > inChunks) { + auto paddingFactor = resultChunks - inChunks; + + SmallVector low, high; + low.push_back(rewriter.getIndexAttr(0)); + high.push_back(rewriter.getIndexAttr(paddingFactor)); + + auto padValue = createTrivialOpMaxWidth(b, 0); + + auto resultVec = b.create(newResultTy, adaptor.getIn(), + low, high, padValue, + /*nofold=*/true); + + rewriter.replaceOp(op, resultVec); + return success(); + } + } +}; + +struct ConvertQuartAddI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + + auto newTy = + getTypeConverter()->convertType(op.getType()); + if (!newTy) + return rewriter.notifyMatchFailure( + loc, llvm::formatv("unsupported type: {0}", op.getType())); + + SmallVector splitLhs = + extractLastDimHalves(rewriter, loc, adaptor.getLhs()); + SmallVector splitRhs = + extractLastDimHalves(rewriter, loc, adaptor.getRhs()); + + assert(splitLhs.size() == splitRhs.size() && "Mismatched tensor sizes"); + + // Actual type of the underlying elements; we use half the width. + // Create Constant + auto intAttr = IntegerAttr::get(rewriter.getI8Type(), maxIntWidth >> 1); + + auto elemType = convertArithToCGGIType( + IntegerType::get(op->getContext(), maxIntWidth), op->getContext()); + auto realTy = convertArithToCGGIType( + IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext()); + + auto constantOp = b.create(intAttr); + + SmallVector carries; + SmallVector outputs; + + for (int i = 0; i < splitLhs.size(); ++i) { + auto lowSum = b.create(splitLhs[i], splitRhs[i]); + auto outputLsb = b.create(op.getLoc(), realTy, lowSum); + auto outputLsbHigh = + b.create(op.getLoc(), elemType, outputLsb); + + // Now all the outputs are 16b elements, wants presentation of 4x8b + if (i != splitLhs.size() - 1) { + auto carry = b.create(elemType, lowSum, constantOp); + carries.push_back(carry); + } + + if (i == 0) { + outputs.push_back(outputLsbHigh); + } else { + auto high = b.create(outputLsbHigh, carries[i - 1]); + outputs.push_back(high); + } + } + + Value resultVec = constructResultTensor(rewriter, loc, newTy, outputs); + rewriter.replaceOp(op, resultVec); + return success(); + } +}; + +struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + ArithToCGGIQuartTypeConverter typeConverter(context); + + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + auto opLegalCallback = [&typeConverter](Operation *op) { + return typeConverter.isLegal(op); + }; + + target.addDynamicallyLegalOp(opLegalCallback); + target.addDynamicallyLegalDialect(opLegalCallback); + + target.addDynamicallyLegalOp< + memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::LoadOp, + memref::SubViewOp, memref::CopyOp, affine::AffineLoadOp, + affine::AffineStoreOp, tensor::FromElementsOp, tensor::ExtractOp>( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + target.addDynamicallyLegalOp( + [](mlir::arith::ConstantOp op) { + // Allow use of constant if it is used to denote the size of a shift + bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + }); + return (isa(op.getValue().getType()) || (usedByShift)); + }); + + patterns.add< + ConvertQuartConstantOp, ConvertQuartExt, + ConvertQuartExt, ConvertQuartAddI, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny >( + typeConverter, context); + + addStructuralConversionPatterns(typeConverter, patterns, target); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace mlir::heir::arith diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h new file mode 100644 index 0000000000..1c634adec5 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h @@ -0,0 +1,16 @@ +#ifndef LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGIQUART_ARITHTOCGGIQUART_H_ +#define LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGIQUART_ARITHTOCGGIQUART_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir::arith { + +#define GEN_PASS_DECL +#include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h.inc" + +} // namespace mlir::heir::arith + +#endif // LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGIQUART_ARITHTOCGGIQUART_H_ diff --git a/lib/Dialect/Arith/Transforms/Passes.td b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.td similarity index 56% rename from lib/Dialect/Arith/Transforms/Passes.td rename to lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.td index 03f57accee..0ecdf00a36 100644 --- a/lib/Dialect/Arith/Transforms/Passes.td +++ b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.td @@ -1,16 +1,16 @@ -#ifndef LIB_DIALECT_ARITH_TRANSFORMS_PASSES_TD_ -#define LIB_DIALECT_ARITH_TRANSFORMS_PASSES_TD_ +#ifndef LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGIQUART_TD_ +#define LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGIQUART_TD_ include "mlir/Pass/PassBase.td" -def QuarterWideInt : Pass<"arith-quarter-wide-int"> { - - let summary = "Convert high precision arithmetic operations to a sequence of lower precision operations"; - let description = [{ +def ArithToCGGIQuart : Pass<"arith-to-cggi-quart"> { + let summary = "Lower `arith` to `cggi` dialect and divide each operation into smaller parts."; + let description = [{ This pass converts high precision arithmetic operations, i.e. operations on 32 bit integer, into a sequence of lower precision operations, i.e 8b operations. Currently, the pass splits the 32b integer into four 8b integers, using the tensor dialect. These smaller integers are stored in an 16b integer, so that we don't lose the carry information. + This pass converts the `arith` dialect to the `cggi` dialect. Based on the `arith-emulate-wide-int` pass from the MLIR arith dialect. @@ -18,8 +18,10 @@ def QuarterWideInt : Pass<"arith-quarter-wide-int"> { }]; let dependentDialects = [ "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", "mlir::tensor::TensorDialect", + "mlir::heir::cggi::CGGIDialect", ]; } -#endif // LIB_DIALECT_ARITH_TRANSFORMS_PASSES_TD_ +#endif // LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGIQUART_ARITHTOCGGIQUART_TD_ diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/BUILD b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/BUILD new file mode 100644 index 0000000000..a4f1b43e42 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/BUILD @@ -0,0 +1,46 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ArithToCGGIQuart", + srcs = ["ArithToCGGIQuart.cpp"], + hdrs = ["ArithToCGGIQuart.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/CGGI/IR:Dialect", + "@heir//lib/Utils:ConversionUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ArithToCGGIQuart", + ], + "ArithToCGGIQuart.h.inc", + ), + ( + ["-gen-pass-doc"], + "ArithToCGGIQuart.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ArithToCGGIQuart.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/lib/Dialect/Arith/Transforms/BUILD b/lib/Dialect/Arith/Transforms/BUILD deleted file mode 100644 index f74fc9c222..0000000000 --- a/lib/Dialect/Arith/Transforms/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") - -package( - default_applicable_licenses = ["@heir//:license"], - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "Transforms", - hdrs = ["Passes.h"], - deps = [ - ":QuarterWideInt", - ":pass_inc_gen", - ], -) - -cc_library( - name = "QuarterWideInt", - srcs = ["QuarterWideInt.cpp"], - hdrs = ["QuarterWideInt.h"], - deps = [ - ":pass_inc_gen", - "@heir//lib/Utils:ConversionUtils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncTransforms", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - ], -) - -add_heir_transforms( - header_filename = "Passes.h.inc", - pass_name = "Arith", - td_file = "Passes.td", -) diff --git a/lib/Dialect/Arith/Transforms/Passes.h b/lib/Dialect/Arith/Transforms/Passes.h deleted file mode 100644 index d5ff03ea36..0000000000 --- a/lib/Dialect/Arith/Transforms/Passes.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef LIB_DIALECT_ARITH_TRANSFORMS_PASSES_H_ -#define LIB_DIALECT_ARITH_TRANSFORMS_PASSES_H_ - -#include "lib/Dialect/Arith/Transforms/QuarterWideInt.h" - -namespace mlir { -namespace heir { -namespace arith { - -#define GEN_PASS_REGISTRATION -#include "lib/Dialect/Arith/Transforms/Passes.h.inc" - -} // namespace arith -} // namespace heir -} // namespace mlir - -#endif // LIB_DIALECT_ARITH_TRANSFORMS_PASSES_H_ diff --git a/lib/Dialect/Arith/Transforms/QuarterWideInt.cpp b/lib/Dialect/Arith/Transforms/QuarterWideInt.cpp deleted file mode 100644 index 1b6a923a01..0000000000 --- a/lib/Dialect/Arith/Transforms/QuarterWideInt.cpp +++ /dev/null @@ -1,490 +0,0 @@ -#include "lib/Dialect/Arith/Transforms/QuarterWideInt.h" - -#include - -#include "lib/Utils/ConversionUtils.h" -#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project -#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project - -namespace mlir { -namespace heir { -namespace arith { - -#define GEN_PASS_DEF_QUARTERWIDEINT -#include "lib/Dialect/Arith/Transforms/Passes.h.inc" - -static constexpr unsigned maxIntWidth = 16; - -class QuarterWideTypeConverter : public TypeConverter { - public: - QuarterWideTypeConverter(MLIRContext *ctx) { - // Allow unknown types. - addConversion([](Type ty) -> std::optional { return ty; }); - - // Scalar case. - addConversion([](IntegerType ty) -> std::optional { - unsigned width = ty.getWidth(); - if (width <= maxIntWidth) return ty; - - // i2N --> tensor<4xiN> - if (width == 2 * maxIntWidth) - return RankedTensorType::get( - 4, IntegerType::get(ty.getContext(), maxIntWidth)); - - return std::nullopt; - }); - - // tensor case. - addConversion([](ShapedType ty) -> std::optional { - auto intTy = dyn_cast(ty.getElementType()); - if (!intTy) return ty; - - unsigned width = intTy.getWidth(); - if (width <= maxIntWidth) return ty; - - // tensor<...xi2N> --> tensor<...x4xiN> - if (width == 2 * maxIntWidth) { - auto newShape = to_vector(ty.getShape()); - newShape.push_back(4); - return RankedTensorType::get( - newShape, IntegerType::get(ty.getContext(), maxIntWidth)); - } - return std::nullopt; - }); - } -}; - -//===----------------------------------------------------------------------===// -// Common Helper Functions -//===----------------------------------------------------------------------===// - -/// Returns the number divided into four chunks of N/2 bits from `value`, where -/// N = `newBitWidth/2`. Treats `value` as a 2*N bits-wide integer. The bottom -/// bits are returned in the first pair element, while the top bits in the -/// fourth one. -std::tuple getQuarters(const APInt &value, - unsigned newBitWidth) { - auto acutalBitWidth = newBitWidth >> 1; - - APInt low = value.extractBits(acutalBitWidth, 0); - APInt midLow = value.extractBits(acutalBitWidth, acutalBitWidth); - APInt midHigh = value.extractBits(acutalBitWidth, 2 * acutalBitWidth); - APInt high = value.extractBits(acutalBitWidth, 3 * acutalBitWidth); - return {std::move(low), std::move(midLow), std::move(midHigh), - std::move(high)}; -} - -/// Returns the type with the last (innermost) dimension reduced to x1. -/// Scalarizes 1D tensor inputs to match how we extract/insert tensor values, -/// e.g.: -/// - tensor<3x2xi16> --> tensor<3x1xi16> -/// - tensor<2xi16> --> i16 -Type reduceInnermostDim(RankedTensorType type) { - if (type.getShape().size() == 1) return type.getElementType(); - - auto newShape = to_vector(type.getShape()); - newShape.back() = 1; - return RankedTensorType::get(newShape, type.getElementType()); -} - -/// Extracts the `input` tensor slice with elements at the last dimension offset -/// by `lastOffset`. Returns a value of tensor type with the last dimension -/// reduced to x1 or fully scalarized, e.g.: -/// - tensor<2xi16> --> i16 -Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, - Value input, int64_t lastOffset) { - ArrayRef shape = cast(input.getType()).getShape(); - assert(lastOffset < shape.back() && "Offset out of bounds"); - - // Create index element - auto intAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), lastOffset); - auto constantOp = rewriter.create(loc, intAttr); - SmallVector indices; - indices.push_back(constantOp.getResult()); - - // Scalarize the result in case of 1D tensors. - if (shape.size() == 1) { - return rewriter.create(loc, input, indices); - } - - SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); - offsets.back() = rewriter.getIndexAttr(lastOffset); - SmallVector sizes(shape.size()); - sizes.back() = rewriter.getIndexAttr(1); - SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); - - return rewriter.create(loc, input, offsets, sizes, - strides); -} - -/// Extracts four tensor slices from the `input` whose type is `tensor<...x4T>`, -/// with the first element at offset 0, second element at offset 1 and so on. -std::tuple extractLastDimHalves( - ConversionPatternRewriter &rewriter, Location loc, Value input) { - return {extractLastDimSlice(rewriter, loc, input, 0), - extractLastDimSlice(rewriter, loc, input, 1), - extractLastDimSlice(rewriter, loc, input, 2), - extractLastDimSlice(rewriter, loc, input, 3)}; -} - -/// Inserts the `source` tensor slice into the `dest` tensor at offset -/// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is -/// a 1D tensor. -Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, - Value source, Value dest, int64_t lastOffset) { - ArrayRef shape = cast(dest.getType()).getShape(); - assert(lastOffset < shape.back() && "Offset out of bounds"); - - // Handle scalar source. - if (isa(source.getType())) { - auto intAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), lastOffset); - auto constantOp = rewriter.create(loc, intAttr); - SmallVector indices; - indices.push_back(constantOp.getResult()); - - return rewriter.create(loc, source, dest, indices); - } - - SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); - offsets.back() = rewriter.getIndexAttr(lastOffset); - SmallVector sizes(shape.size()); - sizes.back() = rewriter.getIndexAttr(1); - SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); - - return rewriter.create(loc, source, dest, offsets, - sizes, strides); -} - -Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, - int64_t value) { - unsigned elementBitWidth = 0; - if (auto intTy = dyn_cast(type)) - elementBitWidth = intTy.getWidth(); - else - elementBitWidth = cast(type).getElementTypeBitWidth(); - - auto apValue = APInt(elementBitWidth, value); - - TypedAttr attr; - if (isa(type)) { - attr = builder.getIntegerAttr(type, apValue); - } else { - auto vecTy = cast(type); - attr = SplatElementsAttr::get(vecTy, apValue); - } - - return builder.create(loc, attr); -} - -/// Constructs a new tensor of type `resultType` by creating a series of -/// insertions of `resultComponents`, each at the next offset of the last tensor -/// dimension. -/// When all `resultComponents` are scalars, the result type is `tensor`; -/// when `resultComponents` are `tensor<...x1xT>`s, the result type is -/// `tensor<...xNxT>`, where `N` is the number of `resultComponents`. -Value constructResultTensor(ConversionPatternRewriter &rewriter, Location loc, - RankedTensorType resultType, - ValueRange resultComponents) { - Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0); - for (auto [i, component] : llvm::enumerate(resultComponents)) - resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i); - - return resultVec; -} - -struct ConvertAddI final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::arith::AddIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - ImplicitLocOpBuilder b(loc, rewriter); - - auto newTy = - getTypeConverter()->convertType(op.getType()); - if (!newTy) - return rewriter.notifyMatchFailure( - loc, llvm::formatv("unsupported type: {0}", op.getType())); - - Type elemTy = reduceInnermostDim(newTy); - - auto [lhsElem0, lhsElem1, lhsElem2, lhsElem3] = - extractLastDimHalves(rewriter, loc, adaptor.getLhs()); - auto [rhsElem0, rhsElem1, rhsElem2, rhsElem3] = - extractLastDimHalves(rewriter, loc, adaptor.getRhs()); - - // Actual type of the underlying elements; we use half the width. - auto realTy = IntegerType::get(op.getContext(), maxIntWidth >> 1); - // Create Constant - auto intAttr = rewriter.getIntegerAttr(elemTy, maxIntWidth >> 1); - auto constantOp = b.create(intAttr); - - auto lowSum0 = b.create(lhsElem0, rhsElem0); - auto lowSum1 = b.create(lhsElem1, rhsElem1); - auto lowSum2 = b.create(lhsElem2, rhsElem2); - auto lowSum3 = b.create(lhsElem3, rhsElem3); - - auto output0Lsb = b.create(realTy, lowSum0); - auto output0LsbHigh = b.create(elemTy, output0Lsb); - - auto output1Lsb = b.create(realTy, lowSum1); - auto output1LsbHigh = b.create(elemTy, output1Lsb); - - auto output2Lsb = b.create(realTy, lowSum2); - auto output2LsbHigh = b.create(elemTy, output2Lsb); - - auto output3Lsb = b.create(realTy, lowSum3); - auto output3LsbHigh = b.create(elemTy, output3Lsb); - - // Now all the outputs are 16b elements, wants presentation of 4x8b - auto carry0 = - b.create(lowSum0, constantOp.getResult()); - auto carry1 = - b.create(lowSum1, constantOp.getResult()); - auto carry2 = - b.create(lowSum2, constantOp.getResult()); - - auto high1 = b.create(output1LsbHigh, carry0); - auto high2 = b.create(output2LsbHigh, carry1); - auto high3 = b.create(output3LsbHigh, carry2); - - Value resultVec = constructResultTensor( - rewriter, loc, newTy, {output0LsbHigh, high1, high2, high3}); - rewriter.replaceOp(op, resultVec); - return success(); - } -}; - -// Implemented using the Karatsuba algorithm -// https://en.wikipedia.org/wiki/Karatsuba_algorithm#Algorithm -struct ConvertMulI final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::arith::MulIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - ImplicitLocOpBuilder b(loc, rewriter); - - auto newTy = - getTypeConverter()->convertType(op.getType()); - if (!newTy) - return rewriter.notifyMatchFailure( - loc, llvm::formatv("unsupported type: {0}", op.getType())); - - auto elemTy = reduceInnermostDim(newTy); - // Actual type of the underlying elements; we use half the width. - auto realTy = IntegerType::get(op.getContext(), maxIntWidth >> 1); - - // Create Constant - auto intAttr = rewriter.getIntegerAttr(elemTy, maxIntWidth >> 1); - auto constantOp = b.create(intAttr); - - auto [lhsElem0, lhsElem1, lhsElem2, lhsElem3] = - extractLastDimHalves(rewriter, loc, adaptor.getLhs()); - auto [rhsElem0, rhsElem1, rhsElem2, rhsElem3] = - extractLastDimHalves(rewriter, loc, adaptor.getRhs()); - - // TODO: Implement the real Karatsuba algorithm for 4x4 multiplication. - // First part of Karatsuba algorithm - auto z00 = b.create(lhsElem0, rhsElem0); - auto z02 = b.create(lhsElem1, rhsElem1); - auto z01_p1 = b.create(lhsElem0, lhsElem1); - auto z01_p2 = b.create(rhsElem0, rhsElem1); - auto z01_m = b.create(z01_p1, z01_p2); - auto z01_s = b.create(z01_m, z00); - auto z01 = b.create(z01_s, z02); - - // Second part I of Karatsuba algorithm - auto z1a0 = b.create(lhsElem0, rhsElem2); - auto z1a2 = b.create(lhsElem1, rhsElem3); - auto z1a1_p1 = b.create(lhsElem0, lhsElem1); - auto z1a1_p2 = b.create(rhsElem2, rhsElem3); - auto z1a1_m = b.create(z1a1_p1, z1a1_p2); - auto z1a1_s = b.create(z1a1_m, z1a0); - auto z1a1 = b.create(z1a1_s, z1a2); - - // Second part II of Karatsuba algorithm - auto z1b0 = b.create(lhsElem2, rhsElem0); - auto z1b2 = b.create(lhsElem3, rhsElem1); - auto z1b1_p1 = b.create(lhsElem2, lhsElem3); - auto z1b1_p2 = b.create(rhsElem0, rhsElem1); - auto z1b1_m = b.create(z1b1_p1, z1b1_p2); - auto z1b1_s = b.create(z1b1_m, z1b0); - auto z1b1 = b.create(z1b1_s, z1b2); - - auto out2Kara = b.create(z1a0, z1b0); - auto out2Carry = b.create(out2Kara, z02); - auto out3Carry = b.create(z1a1, z1b1); - - // Output are now all 16b elements, wants presentation of 4x8b - auto output0Lsb = b.create(realTy, z00); - auto output0LsbHigh = b.create(elemTy, output0Lsb); - auto output0Msb = - b.create(z00, constantOp.getResult()); - - auto output1Lsb = b.create(realTy, z01); - auto output1LsbHigh = b.create(elemTy, output1Lsb); - auto output1Msb = - b.create(z01, constantOp.getResult()); - - auto output2Lsb = b.create(realTy, out2Carry); - auto output2LsbHigh = b.create(elemTy, output2Lsb); - auto output2Msb = - b.create(out2Carry, constantOp.getResult()); - - auto output3Lsb = b.create(realTy, out3Carry); - auto output3LsbHigh = b.create(elemTy, output3Lsb); - - auto output1 = b.create(output1LsbHigh, output0Msb); - auto output2 = b.create(output2LsbHigh, output1Msb); - auto output3 = b.create(output3LsbHigh, output2Msb); - - Value resultVec = constructResultTensor( - rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3}); - rewriter.replaceOp(op, resultVec); - return success(); - } -}; - -struct ConvertArithConstant final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::arith::ConstantOp op, OpAdaptor, - ConversionPatternRewriter &rewriter) const override { - Type oldType = op.getType(); - auto newType = getTypeConverter()->convertType(oldType); - - if (!newType) - return rewriter.notifyMatchFailure( - op, llvm::formatv("unsupported type: {0}", op.getType())); - - unsigned newBitWidth = newType.getElementTypeBitWidth(); - Attribute oldValue = op.getValueAttr(); - - if (auto intAttr = dyn_cast(oldValue)) { - auto [low, midLow, midHigh, high] = - getQuarters(intAttr.getValue(), newBitWidth); - auto newAttr = - DenseElementsAttr::get(newType, {low, midLow, midHigh, high}); - rewriter.replaceOpWithNewOp(op, newAttr); - return success(); - } - - if (auto splatAttr = dyn_cast(oldValue)) { - auto [low, midLow, midHigh, high] = - getQuarters(splatAttr.getSplatValue(), newBitWidth); - int64_t numSplatElems = splatAttr.getNumElements(); - SmallVector values; - values.reserve(numSplatElems * 4); - for (int64_t i = 0; i < numSplatElems; ++i) { - values.push_back(low); - values.push_back(midLow); - values.push_back(midHigh); - values.push_back(high); - } - - auto attr = DenseElementsAttr::get(newType, values); - rewriter.replaceOpWithNewOp(op, attr); - return success(); - } - - if (auto elemsAttr = dyn_cast(oldValue)) { - int64_t numElems = elemsAttr.getNumElements(); - SmallVector values; - values.reserve(numElems * 4); - for (const APInt &origVal : elemsAttr.getValues()) { - auto [low, midLow, midHigh, high] = getQuarters(origVal, newBitWidth); - values.push_back(std::move(low)); - values.push_back(std::move(midLow)); - values.push_back(std::move(midHigh)); - values.push_back(std::move(high)); - } - - auto attr = DenseElementsAttr::get(newType, values); - rewriter.replaceOpWithNewOp(op, attr); - return success(); - } - - return rewriter.notifyMatchFailure(op.getLoc(), - "unhandled constant attribute"); - } -}; - -struct ConvertExtUI final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::arith::ExtUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - - auto newTy = - getTypeConverter()->convertType(op.getType()); - - if (!newTy) - return rewriter.notifyMatchFailure( - loc, llvm::formatv("unsupported type: {0}", op.getType())); - - Value resultVec = constructResultTensor(rewriter, loc, newTy, {op.getIn()}); - rewriter.replaceOp(op, resultVec); - rewriter.replaceOp(op, resultVec); - return success(); - } -}; - -struct QuarterWideInt : impl::QuarterWideIntBase { - using QuarterWideIntBase::QuarterWideIntBase; - - void runOnOperation() override { - MLIRContext *context = &getContext(); - Operation *op = getOperation(); - RewritePatternSet patterns(context); - QuarterWideTypeConverter typeConverter(context); - - ConversionTarget target(*context); - target.addDynamicallyLegalOp([&typeConverter](Operation *op) { - return typeConverter.isLegal(cast(op).getFunctionType()); - }); - auto opLegalCallback = [&typeConverter](Operation *op) { - return typeConverter.isLegal(op); - }; - - target.addDynamicallyLegalOp(opLegalCallback); - target.addDynamicallyLegalDialect(opLegalCallback); - - addStructuralConversionPatterns(typeConverter, patterns, target); - - patterns.add( - typeConverter, context); - - if (failed(applyPartialConversion(op, target, std::move(patterns)))) - signalPassFailure(); - - // Remove the uncessary tensor ops between each converted arith operation. - OpPassManager pipeline("builtin.module"); - pipeline.addPass(createCSEPass()); - (void)runPipeline(pipeline, getOperation()); - } -}; - -} // namespace arith -} // namespace heir -} // namespace mlir diff --git a/lib/Dialect/Arith/Transforms/QuarterWideInt.h b/lib/Dialect/Arith/Transforms/QuarterWideInt.h deleted file mode 100644 index 2a9afea605..0000000000 --- a/lib/Dialect/Arith/Transforms/QuarterWideInt.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef LIB_DIALECT_ARITH_TRANSFORMS_QUARTERWIDEINT_H_ -#define LIB_DIALECT_ARITH_TRANSFORMS_QUARTERWIDEINT_H_ - -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace heir { -namespace arith { - -#define GEN_PASS_DECL_QUARTERWIDEINT -#include "lib/Dialect/Arith/Transforms/Passes.h.inc" - -} // namespace arith -} // namespace heir -} // namespace mlir - -#endif // LIB_DIALECT_ARITH_TRANSFORMS_QUARTERWIDEINT_H_ diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp index facc99d2b6..a43d3a2376 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -42,7 +42,7 @@ constexpr int kAndLut = 8; constexpr int kOrLut = 14; constexpr int kXorLut = 6; -Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) { +static Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) { // Only supporting unsigned types because the LWE dialect does not have a // notion of signedness. switch (width) { @@ -101,7 +101,7 @@ class CGGIToTfheRustTypeConverter : public TypeConverter { /// Returns the Value corresponding to a server key in the FuncOp containing /// this op. -FailureOr getContextualServerKey(Operation *op) { +static FailureOr getContextualServerKey(Operation *op) { Value serverKey = op->getParentOfType() .getBody() .getBlocks() @@ -156,6 +156,37 @@ struct AddServerKeyArg : public OpConversionPattern { } }; +struct AddServerKeyArgCall : public OpConversionPattern { + AddServerKeyArgCall(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr sk = getContextualServerKey(op.getOperation()); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + llvm::SmallVector newOperands; + newOperands.reserve(adaptor.getOperands().size() + 1); + newOperands.push_back(sk.value()); + for (auto t : adaptor.getOperands()) { + newOperands.push_back(t); + } + + // // Set the updated operand list on the operation + auto newCallOp = b.create( + op.getLoc(), adaptor.getCallee(), + getTypeConverter()->convertType(op.getResult(0).getType()), + newOperands); + rewriter.replaceOp(op, newCallOp); + + return success(); + } +}; + /// Convert a Lut3Op to: /// - generate_lookup_table /// - scalar_left_shift @@ -232,8 +263,9 @@ struct ConvertLut2Op : public OpConversionPattern { } }; -LogicalResult replaceBinaryGate(Operation *op, Value lhs, Value rhs, - ConversionPatternRewriter &rewriter, int lut) { +static LogicalResult replaceBinaryGate(Operation *op, Value lhs, Value rhs, + ConversionPatternRewriter &rewriter, + int lut) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); FailureOr result = getContextualServerKey(op); if (failed(result)) return result; diff --git a/tests/Dialect/Arith/Transforms/quarter_wide_int/BUILD b/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/BUILD similarity index 100% rename from tests/Dialect/Arith/Transforms/quarter_wide_int/BUILD rename to tests/Dialect/Arith/Conversions/ArithToCGGIQuart/BUILD diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir b/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir new file mode 100644 index 0000000000..363f50da95 --- /dev/null +++ b/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir @@ -0,0 +1,10 @@ +// RUN: heir-opt --arith-to-cggi-quart %s | FileCheck %s + +// CHECK: return %[[RET:.*]] tensor<4x!lwe.lwe_ciphertext> +func.func @test_simple_split2(%arg0: i32, %arg1: i16) -> i32 { + %2 = arith.constant 31 : i16 + %5 = arith.addi %arg1, %2 : i16 + %6 = arith.extui %5 : i16 to i32 + %7 = arith.addi %arg0, %6 : i32 + return %6 : i32 +} diff --git a/tests/Dialect/Arith/Transforms/quarter_wide_int/quarter_wide.mlir b/tests/Dialect/Arith/Transforms/quarter_wide_int/quarter_wide.mlir deleted file mode 100644 index 08cfcdc404..0000000000 --- a/tests/Dialect/Arith/Transforms/quarter_wide_int/quarter_wide.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: heir-opt --arith-quarter-wide-int %s | FileCheck %s - -// CHECK-LABEL: func @test_simple_split -// CHCK-COUNT-9: arith.muli -// CHCK-COUNT-7: arith.addi -// CHCK-COUNT-3: arith.shrui -// CHCK-COUNT-3: arith.addi -func.func @test_simple_split(%arg0: i32, %arg1: i32) -> i32 { - %1 = arith.constant 522067228: i32 // Hex 1f1e1d1c - %2 = arith.constant 31 : i8 - %3 = arith.extui %2 : i8 to i32 - %4 = arith.muli %1, %arg1 : i32 - %5 = arith.addi %arg0, %3 : i32 - return %4 : i32 -} diff --git a/tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean_small.mlir b/tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean_small.mlir new file mode 100644 index 0000000000..184b983ae5 --- /dev/null +++ b/tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean_small.mlir @@ -0,0 +1,15 @@ +// RUN: heir-opt --tosa-to-boolean-tfhe=abc-fast=true %s | FileCheck %s + +// A reduced dimension version of hello world to speed Yosys up. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + + func.func @main(%arg0: tensor<1x1xi8> {iree.identifier = "serving_default_dense_input:0", tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x3xi32> {iree.identifier = "StatefulPartitionedCall:0", tf_saved_model.index_path = ["dense_2"]}) attributes {tf_saved_model.exported_names = ["serving_default"]} { + %4 = "tosa.const"() {value = dense<[0, 0, 5438]> : tensor<3xi32>} : () -> tensor<3xi32> + %5 = "tosa.const"() {value = dense<[[9], [54], [57]]> : tensor<3x1xi8>} : () -> tensor<3x1xi8> + %6 = "tosa.fully_connected"(%arg0, %5, %4) {quantization_info = #tosa.conv_quant} : (tensor<1x1xi8>, tensor<3x1xi8>, tensor<3xi32>) -> tensor<1x3xi32> + // CHECK: return + return %6 : tensor<1x3xi32> + } +} diff --git a/tools/BUILD b/tools/BUILD index 32add36248..ac9346089e 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -33,9 +33,8 @@ cc_binary( includes = ["include"], deps = [ "@heir//lib/Dialect/Arith/Conversions/ArithToCGGI", + "@heir//lib/Dialect/Arith/Conversions/ArithToCGGIQuart", "@heir//lib/Dialect/Arith/Conversions/ArithToModArith", - "@heir//lib/Dialect/Arith/Transforms", - "@heir//lib/Dialect/Arith/Transforms:QuarterWideInt", "@heir//lib/Dialect/BGV/Conversions/BGVToLWE", "@heir//lib/Dialect/BGV/Conversions/BGVToLattigo", "@heir//lib/Dialect/BGV/IR:Dialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 496fabf996..6865f4012c 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -4,8 +4,8 @@ #include #include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h" +#include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h" #include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h" -#include "lib/Dialect/Arith/Transforms/Passes.h" #include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h" #include "lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.h" #include "lib/Dialect/BGV/IR/BGVDialect.h" @@ -249,7 +249,6 @@ int main(int argc, char **argv) { mlir::arith::registerConvertArithToLLVMInterface(registry); // Custom passes in HEIR - heir::arith::registerArithPasses(); cggi::registerCGGIPasses(); lwe::registerLWEPasses(); mgmt::registerMgmtPasses(); @@ -308,6 +307,7 @@ int main(int argc, char **argv) { mod_arith::registerModArithToArithPasses(); mlir::heir::arith::registerArithToModArithPasses(); mlir::heir::arith::registerArithToCGGIPasses(); + mlir::heir::arith::registerArithToCGGIQuartPasses(); mod_arith::registerConvertToMacPass(); bgv::registerBGVToLWEPasses(); bgv::registerBGVToLattigoPasses();