From fd9867b34304001eea8e9a9bb6388f8c0d2824db Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Tue, 19 Mar 2024 11:36:56 -0700 Subject: [PATCH] secret-to-bgv: add lowering patterns for rotate and add hamming distance example PiperOrigin-RevId: 617242491 --- include/Dialect/BGV/IR/BGVOps.td | 25 ++++++++++- lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp | 5 +-- lib/Conversion/SecretToBGV/BUILD | 2 + lib/Conversion/SecretToBGV/SecretToBGV.cpp | 41 +++++++++++-------- .../secret_to_bgv/hamming_distance_1024.mlir | 33 +++++++++++++++ tests/secret_to_bgv/invalid.mlir | 11 +---- 6 files changed, 84 insertions(+), 33 deletions(-) create mode 100644 tests/secret_to_bgv/hamming_distance_1024.mlir diff --git a/include/Dialect/BGV/IR/BGVOps.td b/include/Dialect/BGV/IR/BGVOps.td index fabf54ef27..7184885a89 100644 --- a/include/Dialect/BGV/IR/BGVOps.td +++ b/include/Dialect/BGV/IR/BGVOps.td @@ -102,12 +102,12 @@ def BGV_MulPlainOp : BGV_CiphertextPlaintextOp<"mul_plain"> { let summary = "Multiplication operation between ciphertext-plaintext."; } -def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> { +def BGV_Rotate : BGV_Op<"rotate", [AllTypesMatch<["x", "output"]>]> { let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism."; let arguments = (ins RLWECiphertext:$x, - I64Attr:$offset + SignlessIntegerLike:$offset ); let results = (outs @@ -117,6 +117,27 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> { let hasVerifier = 1; } +def BGV_ExtractOp : BGV_Op<"extract", [AllTypesMatch<["x", "output"]>]> { + let summary = "Extract the i-th element of a ciphertext."; + + let description = [{ + While this operation is costly to compute in FHE, we represent it so we can + implement efficient lowerings and folders. + + This op can be implemented as a plaintext multiplication with a one-hot + vector and a rotate. + }]; + + let arguments = (ins + RLWECiphertext:$x, + SignlessIntegerLike:$offset + ); + + let results = (outs + RLWECiphertext:$output + ); +} + def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> { let summary = "Negate the coefficients of the ciphertext."; diff --git a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp index 4857ec746b..51cb60629c 100644 --- a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp +++ b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp @@ -148,9 +148,8 @@ struct ConvertRotateOp : public OpConversionPattern { if (failed(result)) return result; Value cryptoContext = result.value(); - auto offsetValue = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adaptor.getOffset())); + auto offsetValue = + rewriter.create(op.getLoc(), adaptor.getOffset()); rewriter.replaceOp( op, rewriter.create(op.getLoc(), cryptoContext, adaptor.getX(), offsetValue)); diff --git a/lib/Conversion/SecretToBGV/BUILD b/lib/Conversion/SecretToBGV/BUILD index 3f0494ee0e..983bc56999 100644 --- a/lib/Conversion/SecretToBGV/BUILD +++ b/lib/Conversion/SecretToBGV/BUILD @@ -18,11 +18,13 @@ cc_library( "@heir//lib/Dialect/Polynomial/IR:Polynomial", "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", "@heir//lib/Dialect/Secret/IR:Dialect", + "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], ) diff --git a/lib/Conversion/SecretToBGV/SecretToBGV.cpp b/lib/Conversion/SecretToBGV/SecretToBGV.cpp index 39c95e26f5..067bad444d 100644 --- a/lib/Conversion/SecretToBGV/SecretToBGV.cpp +++ b/lib/Conversion/SecretToBGV/SecretToBGV.cpp @@ -14,16 +14,18 @@ #include "include/Dialect/Secret/IR/SecretDialect.h" #include "include/Dialect/Secret/IR/SecretOps.h" #include "include/Dialect/Secret/IR/SecretTypes.h" +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" #include "lib/Conversion/Utils.h" -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::heir { @@ -64,12 +66,14 @@ class SecretToBGVTypeConverter : public TypeConverter { // Convert secret types to BGV ciphertext types addConversion([ctx, this](secret::SecretType type) -> Type { - RankedTensorType tensorTy = cast(type.getValueType()); + int bitWidth = + llvm::TypeSwitch(type.getValueType()) + .Case( + [&](auto ty) -> int { return ty.getElementTypeBitWidth(); }) + .Case([&](auto ty) -> int { return ty.getWidth(); }); return lwe::RLWECiphertextType::get( ctx, - lwe::PolynomialEvaluationEncodingAttr::get( - ctx, tensorTy.getElementTypeBitWidth(), - tensorTy.getElementTypeBitWidth()), + lwe::PolynomialEvaluationEncodingAttr::get(ctx, bitWidth, bitWidth), lwe::RLWEParamsAttr::get(ctx, 2, ring_)); }); @@ -108,9 +112,7 @@ class SecretGenericOpConversion inputs.push_back( adaptor.getODSOperands(0)[secretArg->getOperandNumber()]); } else { - return rewriter.notifyMatchFailure( - op->getLoc(), - "Plaintext-ciphertext operations are not yet supported."); + inputs.push_back(operand.get()); } } @@ -158,7 +160,7 @@ struct SecretToBGV : public impl::SecretToBGVBase { for (auto value : op->getOperands()) { if (auto secretTy = dyn_cast(value.getType())) { auto tensorTy = dyn_cast(secretTy.getValueType()); - if (!tensorTy || + if (tensorTy && tensorTy.getShape() != ArrayRef{rlweRing.value().getIdeal().getDegree()}) { return WalkResult::interrupt(); @@ -169,7 +171,7 @@ struct SecretToBGV : public impl::SecretToBGVBase { }); if (compatibleTensors.wasInterrupted()) { module->emitError( - "expected secret types to be tensors with dimension " + "expected batched secret types to be tensors with dimension " "matching ring parameter"); return signalPassFailure(); } @@ -183,6 +185,9 @@ struct SecretToBGV : public impl::SecretToBGVBase { addStructuralConversionPatterns(typeConverter, patterns, target); patterns.add, + SecretGenericOpConversion, + SecretGenericOpConversion, + SecretGenericOpConversion, SecretGenericOpMulConversion>(typeConverter, context); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { diff --git a/tests/secret_to_bgv/hamming_distance_1024.mlir b/tests/secret_to_bgv/hamming_distance_1024.mlir new file mode 100644 index 0000000000..a7deffb009 --- /dev/null +++ b/tests/secret_to_bgv/hamming_distance_1024.mlir @@ -0,0 +1,33 @@ +// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic \ +// RUN: --canonicalize --cse --heir-simd-vectorizer \ +// RUN: --secret-distribute-generic --secret-to-bgv \ +// RUN: %s | FileCheck %s + +// CHECK-LABEL: @hamming +// CHECK: bgv.sub +// CHECK-NEXT: bgv.mul +// CHECK-NEXT: bgv.relinearize +// CHECK-NEXT: bgv.rotate +// CHECK-NEXT: bgv.add +// CHECK-NEXT: bgv.rotate +// CHECK-NEXT: bgv.add +// CHECK-NEXT: bgv.add +// CHECK-NEXT: bgv.extract +// CHECK-NEXT: return + +func.func @hamming(%arg0: tensor<1024xi16>, %arg1: tensor<1024xi16>) -> i16 { + %c0 = arith.constant 0 : index + %c0_si16 = arith.constant 0 : i16 + %0 = affine.for %arg2 = 0 to 1024 iter_args(%arg6 = %c0_si16) -> i16 { + %1 = tensor.extract %arg0[%arg2] : tensor<1024xi16> + %2 = tensor.extract %arg1[%arg2] : tensor<1024xi16> + %3 = arith.subi %1, %2 : i16 + %4 = tensor.extract %arg0[%arg2] : tensor<1024xi16> + %5 = tensor.extract %arg1[%arg2] : tensor<1024xi16> + %6 = arith.subi %4, %5 : i16 + %7 = arith.muli %3, %6 : i16 + %8 = arith.addi %arg6, %7 : i16 + affine.yield %8 : i16 + } + return %0 : i16 +} diff --git a/tests/secret_to_bgv/invalid.mlir b/tests/secret_to_bgv/invalid.mlir index bcd52e69b2..6fb3f136a5 100644 --- a/tests/secret_to_bgv/invalid.mlir +++ b/tests/secret_to_bgv/invalid.mlir @@ -2,16 +2,7 @@ // Tests invalid secret types -// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}} -module { - func.func @test_not_tensor(%arg0 : !secret.secret) -> (!secret.secret) { - return %arg0 : !secret.secret - } -} - -// ----- - -// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}} +// expected-error@below {{expected batched secret types to be tensors with dimension matching ring parameter}} module { func.func @test_invalid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { return %arg0 : !secret.secret>