Skip to content

Commit

Permalink
secret-to-bgv: add lowering patterns for rotate and add hamming dista…
Browse files Browse the repository at this point in the history
…nce example

PiperOrigin-RevId: 617242491
  • Loading branch information
asraa authored and copybara-github committed Mar 28, 2024
1 parent 8cb01ed commit fd9867b
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 33 deletions.
25 changes: 23 additions & 2 deletions include/Dialect/BGV/IR/BGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ def BGV_MulPlainOp : BGV_CiphertextPlaintextOp<"mul_plain"> {
let summary = "Multiplication operation between ciphertext-plaintext.";
}

def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
def BGV_Rotate : BGV_Op<"rotate", [AllTypesMatch<["x", "output"]>]> {
let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism.";

let arguments = (ins
RLWECiphertext:$x,
I64Attr:$offset
SignlessIntegerLike:$offset
);

let results = (outs
Expand All @@ -117,6 +117,27 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
let hasVerifier = 1;
}

def BGV_ExtractOp : BGV_Op<"extract", [AllTypesMatch<["x", "output"]>]> {
let summary = "Extract the i-th element of a ciphertext.";

let description = [{
While this operation is costly to compute in FHE, we represent it so we can
implement efficient lowerings and folders.

This op can be implemented as a plaintext multiplication with a one-hot
vector and a rotate.
}];

let arguments = (ins
RLWECiphertext:$x,
SignlessIntegerLike:$offset
);

let results = (outs
RLWECiphertext:$output
);
}

def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> {
let summary = "Negate the coefficients of the ciphertext.";

Expand Down
5 changes: 2 additions & 3 deletions lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,8 @@ struct ConvertRotateOp : public OpConversionPattern<Rotate> {
if (failed(result)) return result;

Value cryptoContext = result.value();
auto offsetValue = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adaptor.getOffset()));
auto offsetValue =
rewriter.create<arith::ConstantOp>(op.getLoc(), adaptor.getOffset());
rewriter.replaceOp(
op, rewriter.create<openfhe::RotOp>(op.getLoc(), cryptoContext,
adaptor.getX(), offsetValue));
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/SecretToBGV/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ cc_library(
"@heir//lib/Dialect/Polynomial/IR:Polynomial",
"@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
41 changes: 23 additions & 18 deletions lib/Conversion/SecretToBGV/SecretToBGV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/Secret/IR/SecretTypes.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Conversion/Utils.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir {
Expand Down Expand Up @@ -64,12 +66,14 @@ class SecretToBGVTypeConverter : public TypeConverter {

// Convert secret types to BGV ciphertext types
addConversion([ctx, this](secret::SecretType type) -> Type {
RankedTensorType tensorTy = cast<RankedTensorType>(type.getValueType());
int bitWidth =
llvm::TypeSwitch<Type, int>(type.getValueType())
.Case<RankedTensorType>(
[&](auto ty) -> int { return ty.getElementTypeBitWidth(); })
.Case<IntegerType>([&](auto ty) -> int { return ty.getWidth(); });
return lwe::RLWECiphertextType::get(
ctx,
lwe::PolynomialEvaluationEncodingAttr::get(
ctx, tensorTy.getElementTypeBitWidth(),
tensorTy.getElementTypeBitWidth()),
lwe::PolynomialEvaluationEncodingAttr::get(ctx, bitWidth, bitWidth),
lwe::RLWEParamsAttr::get(ctx, 2, ring_));
});

Expand Down Expand Up @@ -108,9 +112,7 @@ class SecretGenericOpConversion
inputs.push_back(
adaptor.getODSOperands(0)[secretArg->getOperandNumber()]);
} else {
return rewriter.notifyMatchFailure(
op->getLoc(),
"Plaintext-ciphertext operations are not yet supported.");
inputs.push_back(operand.get());
}
}

Expand Down Expand Up @@ -158,7 +160,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
for (auto value : op->getOperands()) {
if (auto secretTy = dyn_cast<secret::SecretType>(value.getType())) {
auto tensorTy = dyn_cast<RankedTensorType>(secretTy.getValueType());
if (!tensorTy ||
if (tensorTy &&
tensorTy.getShape() !=
ArrayRef<int64_t>{rlweRing.value().getIdeal().getDegree()}) {
return WalkResult::interrupt();
Expand All @@ -169,7 +171,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
});
if (compatibleTensors.wasInterrupted()) {
module->emitError(
"expected secret types to be tensors with dimension "
"expected batched secret types to be tensors with dimension "
"matching ring parameter");
return signalPassFailure();
}
Expand All @@ -183,6 +185,9 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {

addStructuralConversionPatterns(typeConverter, patterns, target);
patterns.add<SecretGenericOpConversion<arith::AddIOp, bgv::AddOp>,
SecretGenericOpConversion<arith::SubIOp, bgv::SubOp>,
SecretGenericOpConversion<tensor::ExtractOp, bgv::ExtractOp>,
SecretGenericOpConversion<tensor_ext::RotateOp, bgv::Rotate>,
SecretGenericOpMulConversion>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
Expand Down
33 changes: 33 additions & 0 deletions tests/secret_to_bgv/hamming_distance_1024.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic \
// RUN: --canonicalize --cse --heir-simd-vectorizer \
// RUN: --secret-distribute-generic --secret-to-bgv \
// RUN: %s | FileCheck %s

// CHECK-LABEL: @hamming
// CHECK: bgv.sub
// CHECK-NEXT: bgv.mul
// CHECK-NEXT: bgv.relinearize
// CHECK-NEXT: bgv.rotate
// CHECK-NEXT: bgv.add
// CHECK-NEXT: bgv.rotate
// CHECK-NEXT: bgv.add
// CHECK-NEXT: bgv.add
// CHECK-NEXT: bgv.extract
// CHECK-NEXT: return

func.func @hamming(%arg0: tensor<1024xi16>, %arg1: tensor<1024xi16>) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
%0 = affine.for %arg2 = 0 to 1024 iter_args(%arg6 = %c0_si16) -> i16 {
%1 = tensor.extract %arg0[%arg2] : tensor<1024xi16>
%2 = tensor.extract %arg1[%arg2] : tensor<1024xi16>
%3 = arith.subi %1, %2 : i16
%4 = tensor.extract %arg0[%arg2] : tensor<1024xi16>
%5 = tensor.extract %arg1[%arg2] : tensor<1024xi16>
%6 = arith.subi %4, %5 : i16
%7 = arith.muli %3, %6 : i16
%8 = arith.addi %arg6, %7 : i16
affine.yield %8 : i16
}
return %0 : i16
}
11 changes: 1 addition & 10 deletions tests/secret_to_bgv/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@

// Tests invalid secret types

// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}}
module {
func.func @test_not_tensor(%arg0 : !secret.secret<i1>) -> (!secret.secret<i1>) {
return %arg0 : !secret.secret<i1>
}
}

// -----

// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}}
// expected-error@below {{expected batched secret types to be tensors with dimension matching ring parameter}}
module {
func.func @test_invalid_dimension(%arg0 : !secret.secret<tensor<1000xi1>>) -> (!secret.secret<tensor<1000xi1>>) {
return %arg0 : !secret.secret<tensor<1000xi1>>
Expand Down

0 comments on commit fd9867b

Please sign in to comment.