From e6c79eda2aee1e3fae6a8784e8deba17c00d1310 Mon Sep 17 00:00:00 2001 From: Zenithal Date: Sat, 21 Dec 2024 22:50:51 +0000 Subject: [PATCH] bgv-to-lattigo: lower client-interface and plain op --- .../Conversions/BGVToLattigo/BGVToLattigo.cpp | 91 +++++++-- .../Conversions/RlweToLattigo/RlweToLattigo.h | 177 +++++++++++++++++- lib/Dialect/Lattigo/IR/LattigoBGVOps.td | 2 +- .../ArithmeticPipelineRegistration.cpp | 33 ++++ .../ArithmeticPipelineRegistration.h | 2 + lib/Pipelines/BUILD | 1 + lib/Target/Lattigo/LattigoEmitter.cpp | 59 +++++- lib/Target/Lattigo/LattigoEmitter.h | 1 + lib/Utils/Utils.cpp | 10 + lib/Utils/Utils.h | 37 +++- .../bgv_to_lattigo/bgv_to_lattigo.mlir | 2 +- .../Lattigo/Emitters/emit_lattigo.mlir | 9 +- tests/Examples/lattigo/BUILD | 26 ++- tests/Examples/lattigo/binops.mlir | 2 +- tests/Examples/lattigo/binops_test.go | 55 +----- tests/Examples/lattigo/dot_product_8.mlir | 12 ++ tests/Examples/lattigo/dot_product_8_test.go | 58 ++++++ tools/heir-opt.cpp | 7 + 18 files changed, 494 insertions(+), 90 deletions(-) create mode 100644 tests/Examples/lattigo/dot_product_8.mlir create mode 100644 tests/Examples/lattigo/dot_product_8_test.go diff --git a/lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.cpp b/lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.cpp index fe4b216f53..02a6971815 100644 --- a/lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.cpp +++ b/lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.cpp @@ -21,6 +21,7 @@ #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project namespace mlir::heir::bgv { @@ -28,11 +29,18 @@ namespace mlir::heir::bgv { #include "lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.h.inc" using ConvertAddOp = - ConvertRlweBinOp; + ConvertRlweBinOp; using ConvertSubOp = - ConvertRlweBinOp; + ConvertRlweBinOp; using ConvertMulOp = - ConvertRlweBinOp; + ConvertRlweBinOp; +using ConvertAddPlainOp = ConvertRlwePlainOp; +using ConvertSubPlainOp = ConvertRlwePlainOp; +using ConvertMulPlainOp = ConvertRlwePlainOp; + using ConvertRelinOp = ConvertRlweUnaryOp; @@ -44,6 +52,33 @@ using ConvertModulusSwitchOp = using ConvertRotateOp = ConvertRlweRotateOp; +using ConvertEncryptOp = + ConvertRlweUnaryOp; +using ConvertDecryptOp = + ConvertRlweUnaryOp; +using ConvertEncodeOp = + ConvertRlweEncodeOp; +using ConvertDecodeOp = + ConvertRlweDecodeOp; + +struct ConvertLWEReinterpretUnderlyingType + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::ReinterpretUnderlyingTypeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // erase reinterpret underlying + rewriter.replaceOp(op, adaptor.getOperands()[0].getDefiningOp()); + return success(); + } +}; + struct BGVToLattigo : public impl::BGVToLattigoBase { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -53,30 +88,62 @@ struct BGVToLattigo : public impl::BGVToLattigoBase { ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalDialect(); - target.addIllegalOp(); + target + .addIllegalOp(); RewritePatternSet patterns(context); addStructuralConversionPatterns(typeConverter, patterns, target); target.addDynamicallyLegalOp([&](func::FuncOp op) { - bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 && - mlir::isa( - *op.getFunctionType().getInputs().begin()); + bool hasCryptoContextArg = + op.getFunctionType().getNumInputs() > 0 && + containsArgumentOfType< + lattigo::BGVEvaluatorType, lattigo::BGVEncoderType, + lattigo::RLWEEncryptorType, lattigo::RLWEDecryptorType>(op); + return typeConverter.isSignatureLegal(op.getFunctionType()) && typeConverter.isLegal(&op.getBody()) && (!containsDialects(op) || hasCryptoContextArg); }); - patterns.add, - ConvertAddOp, ConvertSubOp, ConvertMulOp, ConvertRelinOp, - ConvertModulusSwitchOp, ConvertRotateOp>(typeConverter, - context); + std::vector> evaluators; + + // param/encoder also needed for the main func + // as there might (not) be ct-pt operations + evaluators = { + {lattigo::BGVEvaluatorType::get(context), + containsDialects}, + {lattigo::BGVParameterType::get(context), + containsDialects}, + {lattigo::BGVEncoderType::get(context), + containsDialects}, + {lattigo::RLWEEncryptorType::get(context), + containsAnyOperations}, + {lattigo::RLWEDecryptorType::get(context), + containsAnyOperations}, + }; + + patterns.add(context, evaluators); + + patterns.add(typeConverter, context); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return signalPassFailure(); } + + // remove unused key args from function types + // in favor of encryptor/decryptor + RewritePatternSet postPatterns(context); + postPatterns.add>(context); + postPatterns.add>(context); + walkAndApplyPatterns(module, std::move(postPatterns)); } }; diff --git a/lib/Dialect/LWE/Conversions/RlweToLattigo/RlweToLattigo.h b/lib/Dialect/LWE/Conversions/RlweToLattigo/RlweToLattigo.h index 452c49ced0..4b7b166113 100644 --- a/lib/Dialect/LWE/Conversions/RlweToLattigo/RlweToLattigo.h +++ b/lib/Dialect/LWE/Conversions/RlweToLattigo/RlweToLattigo.h @@ -29,25 +29,36 @@ FailureOr getContextualEvaluator(Operation *op) { return result.value(); } -template struct AddEvaluatorArg : public OpConversionPattern { - AddEvaluatorArg(mlir::MLIRContext *context) - : OpConversionPattern(context, /* benefit= */ 2) {} + AddEvaluatorArg(mlir::MLIRContext *context, + const std::vector> &evaluators) + : OpConversionPattern(context, /* benefit= */ 2), + evaluators(evaluators) {} using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( func::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!containsDialects(op)) { - return failure(); + SmallVector selectedEvaluators; + + for (const auto &evaluator : evaluators) { + auto predicate = evaluator.second; + if (predicate(op)) { + selectedEvaluators.push_back(evaluator.first); + } + } + + if (selectedEvaluators.empty()) { + return success(); } - auto evaluatorType = EvaluatorType::get(getContext()); FunctionType originalType = op.getFunctionType(); llvm::SmallVector newTypes; - newTypes.reserve(originalType.getNumInputs() + 1); - newTypes.push_back(evaluatorType); + newTypes.reserve(originalType.getNumInputs() + selectedEvaluators.size()); + for (auto evaluatorType : selectedEvaluators) { + newTypes.push_back(evaluatorType); + } for (auto t : originalType.getInputs()) { newTypes.push_back(t); } @@ -57,10 +68,61 @@ struct AddEvaluatorArg : public OpConversionPattern { op.setType(newFuncType); Block &block = op.getBody().getBlocks().front(); - block.insertArgument(&block.getArguments().front(), evaluatorType, - op.getLoc()); + for (auto evaluatorType : llvm::reverse(selectedEvaluators)) { + block.insertArgument(&block.getArguments().front(), evaluatorType, + op.getLoc()); + } }); + return success(); + } + + private: + std::vector> evaluators; +}; + +template +struct RemoveKeyArg : public OpConversionPattern { + RemoveKeyArg(mlir::MLIRContext *context) + : OpConversionPattern(context, /* benefit= */ 2) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector keyArgIndices; + Block &block = op.getBody().getBlocks().front(); + for (auto arg : block.getArguments()) { + if (mlir::isa(arg.getType()) && arg.getUses().empty()) { + keyArgIndices.push_back(arg.getArgNumber()); + } + } + + if (keyArgIndices.empty()) { + return success(); + } + FunctionType originalType = op.getFunctionType(); + llvm::SmallVector newTypes; + newTypes.reserve(originalType.getNumInputs()); + for (auto arg : block.getArguments()) { + if (llvm::is_contained(keyArgIndices, arg.getArgNumber())) { + continue; + } + newTypes.push_back(arg.getType()); + } + auto newFuncType = + FunctionType::get(getContext(), newTypes, originalType.getResults()); + rewriter.modifyOpInPlace(op, [&] { + op.setType(newFuncType); + + Block &block = op.getBody().getBlocks().front(); + for (auto arg : block.getArguments()) { + if (llvm::is_contained(keyArgIndices, arg.getArgNumber())) { + block.eraseArgument(arg.getArgNumber()); + } + } + }); return success(); } }; @@ -105,6 +167,25 @@ struct ConvertRlweBinOp : public OpConversionPattern { } }; +template +struct ConvertRlwePlainOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + PlainOp op, typename PlainOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = + getContextualEvaluator(op.getOperation()); + if (failed(result)) return result; + + Value evaluator = result.value(); + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + evaluator, adaptor.getCiphertextInput(), adaptor.getPlaintextInput()); + return success(); + } +}; + template struct ConvertRlweRotateOp : public OpConversionPattern { @@ -130,6 +211,82 @@ struct ConvertRlweRotateOp : public OpConversionPattern { } }; +template +struct ConvertRlweEncodeOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + EncodeOp op, typename EncodeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = + getContextualEvaluator(op.getOperation()); + if (failed(result)) return result; + Value evaluator = result.value(); + + FailureOr result2 = + getContextualEvaluator(op.getOperation()); + if (failed(result2)) return result2; + Value params = result2.value(); + + auto alloc = rewriter.create( + op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()), + params); + + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + evaluator, adaptor.getInput(), alloc); + return success(); + } +}; + +template +struct ConvertRlweDecodeOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + DecodeOp op, typename DecodeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = + getContextualEvaluator(op.getOperation()); + if (failed(result)) return result; + Value evaluator = result.value(); + + auto outputType = op.getOutput().getType(); + RankedTensorType outputTensorType = dyn_cast(outputType); + bool isScalar = false; + if (!outputTensorType) { + isScalar = true; + outputTensorType = RankedTensorType::get({1}, outputType); + } + + APInt zero(getElementTypeOrSelf(outputType).getIntOrFloatBitWidth(), 0); + + auto constant = DenseElementsAttr::get(outputTensorType, zero); + + auto alloc = + rewriter.create(op.getLoc(), outputTensorType, constant); + + auto decodeOp = rewriter.create( + op.getLoc(), outputTensorType, evaluator, adaptor.getInput(), alloc); + + // TODO(#1174): the sin of lwe.reinterpret_underlying_type + if (isScalar) { + SmallVector indices; + auto index = rewriter.create(op.getLoc(), + rewriter.getIndexAttr(0)); + indices.push_back(index); + auto extract = rewriter.create( + op.getLoc(), decodeOp.getResult(), indices); + rewriter.replaceOp(op, extract.getResult()); + } else { + rewriter.replaceOp(op, decodeOp.getResult()); + } + return success(); + } +}; + } // namespace mlir::heir #endif // LIB_DIALECT_LWE_CONVERSIONS_RLWETOLATTIGOUTILS_RLWETOLATTIGO_H_ diff --git a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td index 47462b67fd..526390b06f 100644 --- a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td @@ -97,7 +97,7 @@ class Lattigo_BGVBinaryOp : let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, - Lattigo_RLWECiphertext:$rhs + AnyType:$rhs ); let results = (outs Lattigo_RLWECiphertext:$output); } diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index b91ec7ffba..91d119133c 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -4,6 +4,7 @@ #include #include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h" +#include "lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.h" #include "lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h" #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" #include "lib/Dialect/LWE/Transforms/AddClientInterface.h" @@ -244,6 +245,38 @@ RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(const RLWEScheme scheme) { }; } +RLWEPipelineBuilder mlirToLattigoRLWEPipelineBuilder(const RLWEScheme scheme) { + return [=](OpPassManager &pm, const MlirToRLWEPipelineOptions &options) { + // lower to RLWE scheme + MlirToRLWEPipelineOptions overrideOptions; + overrideOptions.entryFunction = options.entryFunction; + overrideOptions.ciphertextDegree = options.ciphertextDegree; + overrideOptions.modulusSwitchBeforeFirstMul = + options.modulusSwitchBeforeFirstMul; + // use simpler client interface for Lattigo + overrideOptions.usePublicKey = false; + overrideOptions.oneValuePerHelperFn = false; + mlirToRLWEPipeline(pm, overrideOptions, scheme); + + // Convert to (common trivial subset of) LWE + switch (scheme) { + case RLWEScheme::bgvScheme: { + // TODO (#1193): Replace `--bgv-to-lwe` with `--bgv-common-to-lwe` + pm.addPass(bgv::createBGVToLWE()); + pm.addPass(bgv::createBGVToLattigo()); + break; + } + default: + llvm::errs() << "Unsupported RLWE scheme: " << scheme; + exit(EXIT_FAILURE); + } + + // Simplify, in case the lowering revealed redundancy + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + }; +} + void registerTosaToArithPipeline() { PassPipelineRegistration<>( "tosa-to-arith", "Arithmetic modules to arith tfhe-rs pipeline.", diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.h b/lib/Pipelines/ArithmeticPipelineRegistration.h index dc6b33b13b..06ba75b924 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.h +++ b/lib/Pipelines/ArithmeticPipelineRegistration.h @@ -60,6 +60,8 @@ RLWEPipelineBuilder mlirToRLWEPipelineBuilder(RLWEScheme scheme); RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(RLWEScheme scheme); +RLWEPipelineBuilder mlirToLattigoRLWEPipelineBuilder(RLWEScheme scheme); + void registerTosaToArithPipeline(); } // namespace mlir::heir diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index f5a3663aa7..37aae7d0e2 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -11,6 +11,7 @@ cc_library( hdrs = ["PipelineRegistration.h"], deps = [ "@heir//lib/Dialect/BGV/Conversions/BGVToLWE", + "@heir//lib/Dialect/BGV/Conversions/BGVToLattigo", "@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial", "@heir//lib/Dialect/LinAlg/Conversions/LinalgToTensorExt", "@heir//lib/Dialect/ModArith/Conversions/ModArithToArith", diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index aaa86d78c9..0a6f1e3fb0 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -52,7 +52,10 @@ LogicalResult LattigoEmitter::translate(Operation &op) { // Func ops .Case( [&](auto op) { return printOperation(op); }) + // Arith ops .Case([&](auto op) { return printOperation(op); }) + // Tensor ops + .Case([&](auto op) { return printOperation(op); }) // Lattigo ops .Case(op.getPlaintext().getDefiningOp()); + if (!newPlaintextOp) { + return failure(); + } + auto maxSlotsName = getName(newPlaintextOp.getParams()) + ".MaxSlots()"; + + auto packedName = getName(op.getValue()) + "_packed"; + os << packedName << " := make([]int64, "; + os << maxSlotsName << ")\n"; + os << "for i := range " << packedName << " {\n"; + os.indent(); + os << packedName << "[i] = int64(" << getName(op.getValue()) << "[i \% len(" + << getName(op.getValue()) << ")])\n"; + os.unindent(); + os << "}\n"; + + os << getName(op.getEncoder()) << ".Encode("; + os << packedName << ", "; + os << getName(op.getPlaintext()) << ")\n"; + os << getName(op.getEncoded()) << " := " << getName(op.getPlaintext()) + << "\n"; + return success(); } LogicalResult LattigoEmitter::printOperation(BGVDecodeOp op) { - return printEvalInplaceMethod(op.getDecoded(), op.getEncoder(), - op.getPlaintext(), op.getValue(), "Decode", - false); + os << getName(op.getEncoder()) << ".Decode("; + os << getName(op.getPlaintext()) << ", "; + os << getName(op.getValue()) << ")\n"; + + // type conversion from value to decoded + auto convertedName = getName(op.getDecoded()) + "_converted"; + os << convertedName << " := make(" << convertType(op.getDecoded().getType()) + << ", len(" << getName(op.getValue()) << "))\n"; + os << "for i := range " << getName(op.getValue()) << " {\n"; + os.indent(); + os << convertedName + << "[i] = " << convertType(getElementTypeOrSelf(op.getDecoded().getType())) + << "(" << getName(op.getValue()) << "[i])\n"; + os.unindent(); + os << "}\n"; + os << getName(op.getDecoded()) << " := " << convertedName << "\n"; + return success(); } LogicalResult LattigoEmitter::printOperation(BGVAddOp op) { diff --git a/lib/Target/Lattigo/LattigoEmitter.h b/lib/Target/Lattigo/LattigoEmitter.h index b92ca19bdd..a2e8e36576 100644 --- a/lib/Target/Lattigo/LattigoEmitter.h +++ b/lib/Target/Lattigo/LattigoEmitter.h @@ -55,6 +55,7 @@ class LattigoEmitter { LogicalResult printOperation(::mlir::func::ReturnOp op); LogicalResult printOperation(::mlir::func::CallOp op); LogicalResult printOperation(::mlir::arith::ConstantOp op); + LogicalResult printOperation(::mlir::tensor::ExtractOp op); // Lattigo ops LogicalResult printOperation(RLWENewEncryptorOp op); LogicalResult printOperation(RLWENewDecryptorOp op); diff --git a/lib/Utils/Utils.cpp b/lib/Utils/Utils.cpp index c096d7c2eb..f3ec5266eb 100644 --- a/lib/Utils/Utils.cpp +++ b/lib/Utils/Utils.cpp @@ -53,5 +53,15 @@ LogicalResult walkAndValidateValues(Operation *op, IsValidValueFn isValidValue, return res; } +bool containsArgumentOfType(Operation *op, TypePredicate predicate) { + return llvm::any_of(op->getRegions(), [&](Region ®ion) { + return llvm::any_of(region.getBlocks(), [&](Block &block) { + return llvm::any_of(block.getArguments(), [&](BlockArgument arg) { + return predicate(arg.getType()); + }); + }); + }); +} + } // namespace heir } // namespace mlir diff --git a/lib/Utils/Utils.h b/lib/Utils/Utils.h index dfa5df10f4..f79d18f7bf 100644 --- a/lib/Utils/Utils.h +++ b/lib/Utils/Utils.h @@ -18,11 +18,37 @@ typedef std::function OpPredicate; typedef std::function IsValidTypeFn; typedef std::function IsValidValueFn; +typedef std::function TypePredicate; + +typedef std::function DialectPredicate; + +template +OpPredicate OpEqual() { + return [](Operation *op) { return mlir::isa(op); }; +} + +template +TypePredicate TypeEqual() { + return [](const Type &type) { return mlir::isa(type); }; +} + +template +DialectPredicate DialectEqual() { + return [](Dialect *dialect) { return mlir::isa(dialect); }; +} + // Walks the given op, applying the predicate to traversed ops until the // predicate returns true, then returns the operation that matched, or // nullptr if there were no matches. Operation *walkAndDetect(Operation *op, OpPredicate predicate); +// specialization for detecting a specific operation type +template +bool containsAnyOperations(Operation *op) { + Operation *foundOp = walkAndDetect(op, OpEqual()); + return foundOp != nullptr; +} + /// Apply isValidType to the operands and results, returning an appropriate /// logical result. LogicalResult validateTypes(Operation *op, IsValidTypeFn isValidType); @@ -61,11 +87,20 @@ LogicalResult walkAndValidateTypes( template bool containsDialects(Operation *op) { Operation *foundOp = walkAndDetect(op, [&](Operation *op) { - return llvm::isa(op->getDialect()); + return DialectEqual()(op->getDialect()); }); return foundOp != nullptr; } +// Returns true if the op contains argument values of the given type. +// NOTE: any_of instead of all_of +bool containsArgumentOfType(Operation *op, TypePredicate predicate); + +template +bool containsArgumentOfType(Operation *op) { + return containsArgumentOfType(op, TypeEqual()); +} + } // namespace heir } // namespace mlir diff --git a/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir b/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir index e62a9eb66f..cdb8422ce1 100644 --- a/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir +++ b/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --mlir-print-local-scope --bgv-to-lattigo %s | FileCheck %s +// RUN: heir-opt --mlir-print-local-scope --bgv-to-lwe --bgv-to-lattigo %s | FileCheck %s !Z1032955396097_i64_ = !mod_arith.int<1032955396097 : i64> diff --git a/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir b/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir index 00f0086bd5..56a2aa611b 100644 --- a/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir +++ b/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir @@ -71,9 +71,11 @@ module { // CHECK: [[value2:v.*]] := []int64 // CHECK: [[pt1:v.*]] := bgv.NewPlaintext([[param]], [[param]].MaxLevel()) // CHECK: [[pt2:v.*]] := bgv.NewPlaintext([[param]], [[param]].MaxLevel()) - // CHECK: [[encoder]].Encode([[value1]], [[pt1]]) + // CHECK: [[value1Packed:v.*]][i] = int64([[value1]][i % len([[value1]])]) + // CHECK: [[encoder]].Encode([[value1Packed]], [[pt1]]) // CHECK: [[pt3:v.*]] := [[pt1]] - // CHECK: [[encoder]].Encode([[value2]], [[pt2]]) + // CHECK: [[value2Packed:v.*]][i] = int64([[value2]][i % len([[value2]])]) + // CHECK: [[encoder]].Encode([[value2Packed]], [[pt2]]) // CHECK: [[pt4:v.*]] := [[pt2]] // CHECK: [[ct1:v.*]], [[err:.*]] := [[enc]].EncryptNew([[pt3]]) // CHECK: [[ct2:v.*]], [[err:.*]] := [[enc]].EncryptNew([[pt4]]) @@ -81,7 +83,8 @@ module { // CHECK: [[pt5:v.*]] := [[dec]].DecryptNew([[res]]) // CHECK: [[value3:v.*]] := []int64 // CHECK: [[encoder]].Decode([[pt5]], [[value3]]) - // CHECK: [[value4:v.*]] := [[value3]] + // CHECK: [[value3Converted:v.*]][i] = int64([[value3]][i]) + // CHECK: [[value4:v.*]] := [[value3Converted]] func.func @test_basic_emitter() -> () { %param = lattigo.bgv.new_parameters_from_literal {paramsLiteral = #paramsLiteral} : () -> !params %encoder = lattigo.bgv.new_encoder %param : (!params) -> !encoder diff --git a/tests/Examples/lattigo/BUILD b/tests/Examples/lattigo/BUILD index 9bf5d9e540..9bb4c0e511 100644 --- a/tests/Examples/lattigo/BUILD +++ b/tests/Examples/lattigo/BUILD @@ -8,17 +8,19 @@ package(default_applicable_licenses = ["@heir//:license"]) heir_lattigo_lib( name = "binops", heir_opt_flags = [ - "--secretize", - "--mlir-to-secret-arithmetic", - "--secret-insert-mgmt-bgv", - "--secret-distribute-generic", - "--secret-to-bgv=poly-mod-degree=4", - "--bgv-to-lattigo", - "--cse", + "--mlir-to-lattigo-bgv=entry-function=add ciphertext-degree=4", ], mlir_src = "binops.mlir", ) +heir_lattigo_lib( + name = "dot_product_8", + heir_opt_flags = [ + "--mlir-to-lattigo-bgv=entry-function=dot_product ciphertext-degree=8", + ], + mlir_src = "dot_product_8.mlir", +) + # For Google-internal reasons we must separate the go_test rules from the macro # above. go_test( @@ -30,3 +32,13 @@ go_test( "@lattigo//schemes/bgv", ], ) + +go_test( + name = "dot_product_8_test", + srcs = ["dot_product_8_test.go"], + embed = [":dot_product_8"], + deps = [ + "@lattigo//core/rlwe", + "@lattigo//schemes/bgv", + ], +) diff --git a/tests/Examples/lattigo/binops.mlir b/tests/Examples/lattigo/binops.mlir index 69d5e3293a..788841f7a5 100644 --- a/tests/Examples/lattigo/binops.mlir +++ b/tests/Examples/lattigo/binops.mlir @@ -1,6 +1,6 @@ // From https://github.com/google/heir/pull/1182 -func.func @add(%arg0: tensor<4xi16>, %arg1: tensor<4xi16>) -> tensor<4xi16> { +func.func @add(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> tensor<4xi16> { %0 = arith.addi %arg0, %arg1 : tensor<4xi16> %1 = arith.muli %0, %arg1 : tensor<4xi16> %c1 = arith.constant 1 : index diff --git a/tests/Examples/lattigo/binops_test.go b/tests/Examples/lattigo/binops_test.go index c4916c3a6d..e4c2057f93 100644 --- a/tests/Examples/lattigo/binops_test.go +++ b/tests/Examples/lattigo/binops_test.go @@ -33,62 +33,19 @@ func TestBinops(t *testing.T) { evalKeys := rlwe.NewMemEvaluationKeySet(relinKeys, galKey) evaluator := bgv.NewEvaluator(params, evalKeys /*scaleInvariant=*/, false) - T := params.MaxSlots() // Vector of plaintext values // 0, 1, 2, 3 - arg0 := make([]uint64, T) + arg0 := []int16{0, 1, 2, 3} // 1, 2, 3, 4 - arg1 := make([]uint64, T) + arg1 := []int16{1, 2, 3, 4} - expected := make([]uint64, T) - result := make([]uint64, T) + expected := []int16{6, 15, 28, 1} - // Hack until we have packing system: replicate values cycling every 4 - dataSize := 4 - for i := range arg0 { - arg0[i] = uint64(i % dataSize) - arg1[i] = (uint64(i%dataSize) + 1) - expected[i] = (arg0[i] + arg1[i]) * arg1[i] - } - - // Rotate by 1 - tmp := expected[0] - for i := range T - 1 { - expected[i] = expected[i+1] - } - expected[T-1] = tmp - - // Allocates a plaintext at the max level. - // Default rlwe.MetaData: - // - IsBatched = true (slots encoding) - // - Scale = params.DefaultScale() - pt0 := bgv.NewPlaintext(params, params.MaxLevel()) - pt1 := bgv.NewPlaintext(params, params.MaxLevel()) + ct0, ct1 := add__encrypt(evaluator, params, ecd, enc, arg0, arg1) - // Encodes the vector of plaintext values - if err = ecd.Encode(arg0, pt0); err != nil { - panic(err) - } - if err = ecd.Encode(arg1, pt1); err != nil { - panic(err) - } + resultCt := add(evaluator, params, ecd, ct0, ct1) - // Encrypts the vector of plaintext values - var ct0 *rlwe.Ciphertext - var ct1 *rlwe.Ciphertext - if ct0, err = enc.EncryptNew(pt0); err != nil { - panic(err) - } - if ct1, err = enc.EncryptNew(pt1); err != nil { - panic(err) - } - - resultCt := add(evaluator, ct0, ct1) - resultEncoded := dec.DecryptNew(resultCt) - err = ecd.Decode(resultEncoded, result) - if err != nil { - panic(err) - } + result := add__decrypt(evaluator, params, ecd, dec, resultCt) for i := range 4 { if result[i] != expected[i] { diff --git a/tests/Examples/lattigo/dot_product_8.mlir b/tests/Examples/lattigo/dot_product_8.mlir new file mode 100644 index 0000000000..2096b20f63 --- /dev/null +++ b/tests/Examples/lattigo/dot_product_8.mlir @@ -0,0 +1,12 @@ +func.func @dot_product(%arg0: tensor<8xi16> {secret.secret}, %arg1: tensor<8xi16> {secret.secret}) -> i16 { + %c0 = arith.constant 0 : index + %c0_si16 = arith.constant 0 : i16 + %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_si16) -> (i16) { + %1 = tensor.extract %arg0[%arg2] : tensor<8xi16> + %2 = tensor.extract %arg1[%arg2] : tensor<8xi16> + %3 = arith.muli %1, %2 : i16 + %4 = arith.addi %iter, %3 : i16 + affine.yield %4 : i16 + } + return %0 : i16 +} diff --git a/tests/Examples/lattigo/dot_product_8_test.go b/tests/Examples/lattigo/dot_product_8_test.go new file mode 100644 index 0000000000..ea296729e5 --- /dev/null +++ b/tests/Examples/lattigo/dot_product_8_test.go @@ -0,0 +1,58 @@ +package dot_product_8 + +import ( + "testing" + + "github.com/tuneinsight/lattigo/v6/core/rlwe" + "github.com/tuneinsight/lattigo/v6/schemes/bgv" +) + +func TestBinops(t *testing.T) { + var err error + var params bgv.Parameters + + // 128-bit secure parameters enabling depth-7 circuits. + // LogN:14, LogQP: 431. + if params, err = bgv.NewParametersFromLiteral( + bgv.ParametersLiteral{ + LogN: 14, // log2(ring degree) + LogQ: []int{55, 45, 45, 45, 45, 45, 45, 45}, // log2(primes Q) (ciphertext modulus) + LogP: []int{61}, // log2(primes P) (auxiliary modulus) + PlaintextModulus: 0x10001, // log2(scale) + }); err != nil { + panic(err) + } + + kgen := rlwe.NewKeyGenerator(params) + sk := kgen.GenSecretKeyNew() + ecd := bgv.NewEncoder(params) + enc := rlwe.NewEncryptor(params, sk) + dec := rlwe.NewDecryptor(params, sk) + relinKeys := kgen.GenRelinearizationKeyNew(sk) + // 5^7 % (2^14) = 12589 + galKey12589 := kgen.GenGaloisKeyNew(12589, sk) + // 5^4 = 625 + galKey625 := kgen.GenGaloisKeyNew(625, sk) + // 5^2 = 25 + galKey25 := kgen.GenGaloisKeyNew(25, sk) + // 5^1 = 5 + galKey5 := kgen.GenGaloisKeyNew(5, sk) + evalKeys := rlwe.NewMemEvaluationKeySet(relinKeys, galKey5, galKey25, galKey625, galKey12589) + evaluator := bgv.NewEvaluator(params, evalKeys /*scaleInvariant=*/, false) + + // Vector of plaintext values + arg0 := []int16{1, 2, 3, 4, 5, 6, 7, 8} + arg1 := []int16{2, 3, 4, 5, 6, 7, 8, 9} + + expected := int16(240) + + ct0, ct1 := dot_product__encrypt(evaluator, params, ecd, enc, arg0, arg1) + + resultCt := dot_product(evaluator, params, ecd, ct0, ct1) + + result := dot_product__decrypt(evaluator, params, ecd, dec, resultCt) + + if result != expected { + t.Errorf("Decryption error %d != %d", result, expected) + } +} diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 496fabf996..92e39aeee6 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -369,6 +369,13 @@ int main(int argc, char **argv) { "to OpenFHE C++ code.", mlirToOpenFheRLWEPipelineBuilder(mlir::heir::RLWEScheme::bgvScheme)); + PassPipelineRegistration( + "mlir-to-lattigo-bgv", + "Convert a func using standard MLIR dialects to FHE using BGV and " + "export " + "to Lattigo GO code.", + mlirToLattigoRLWEPipelineBuilder(mlir::heir::RLWEScheme::bgvScheme)); + PassPipelineRegistration( "mlir-to-ckks", "Convert a func using standard MLIR dialects to FHE using "