Skip to content

Commit

Permalink
Merge pull request #1126 from ZenithalHourlyRate:bgv-modreduce
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705133327
  • Loading branch information
copybara-github committed Dec 11, 2024
2 parents 81beb32 + 6ef5d0b commit 854a878
Show file tree
Hide file tree
Showing 55 changed files with 2,215 additions and 155 deletions.
144 changes: 86 additions & 58 deletions docs/content/en/docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,46 +165,63 @@ bazel run //tools:heir-opt -- \
$PWD/tests/Examples/openfhe/dot_product_8.mlir > output.mlir
```

This produces a file in the `openfhe` exit dialect (part of HEIR). The raw
output is rather verbose, and an abbreviated version is shown below.
This produces a file in the `openfhe` exit dialect (part of HEIR).

```mlir
!tensor_ct = !lwe.rlwe_ciphertext<..., underlying_type = tensor<8xi16>>
!scalar_ct = !lwe.rlwe_ciphertext<..., underlying_type = i16>
!mul_ct = !lwe.rlwe_ciphertext<..., underlying_type = tensor<8xi16>>
!tensor_plaintext = lwe.rlwe_plaintext<..., underlying_type = tensor<8xi16>>
!Z1005037682689_i64_ = !mod_arith.int<1005037682689 : i64>
!Z1032955396097_i64_ = !mod_arith.int<1032955396097 : i64>
!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64>
#polynomial_evaluation_encoding = #lwe.polynomial_evaluation_encoding<cleartext_start = 16, cleartext_bitwidth = 16>
!rns_L0_ = !rns.rns<!Z1095233372161_i64_>
!rns_L1_ = !rns.rns<!Z1095233372161_i64_, !Z1032955396097_i64_>
!rns_L2_ = !rns.rns<!Z1095233372161_i64_, !Z1032955396097_i64_, !Z1005037682689_i64_>
#ring_rns_L0_1_x8_ = #polynomial.ring<coefficientType = !rns_L0_, polynomialModulus = <1 + x**8>>
#ring_rns_L1_1_x8_ = #polynomial.ring<coefficientType = !rns_L1_, polynomialModulus = <1 + x**8>>
#ring_rns_L2_1_x8_ = #polynomial.ring<coefficientType = !rns_L2_, polynomialModulus = <1 + x**8>>
!rlwe_pt_L0_ = !lwe.rlwe_plaintext<encoding = #polynomial_evaluation_encoding, ring = #ring_rns_L0_1_x8_, underlying_type = i16>
!rlwe_pt_L1_ = !lwe.rlwe_plaintext<encoding = #polynomial_evaluation_encoding, ring = #ring_rns_L1_1_x8_, underlying_type = tensor<8xi16>>
!rlwe_pt_L2_ = !lwe.rlwe_plaintext<encoding = #polynomial_evaluation_encoding, ring = #ring_rns_L2_1_x8_, underlying_type = tensor<8xi16>>
#rlwe_params_L0_ = #lwe.rlwe_params<ring = #ring_rns_L0_1_x8_>
#rlwe_params_L1_ = #lwe.rlwe_params<ring = #ring_rns_L1_1_x8_>
#rlwe_params_L2_ = #lwe.rlwe_params<ring = #ring_rns_L2_1_x8_>
#rlwe_params_L2_D3_ = #lwe.rlwe_params<dimension = 3, ring = #ring_rns_L2_1_x8_>
!rlwe_ct_L0_ = !lwe.rlwe_ciphertext<encoding = #polynomial_evaluation_encoding, rlwe_params = #rlwe_params_L0_, underlying_type = i16>
!rlwe_ct_L1_ = !lwe.rlwe_ciphertext<encoding = #polynomial_evaluation_encoding, rlwe_params = #rlwe_params_L1_, underlying_type = tensor<8xi16>>
!rlwe_ct_L1_1 = !lwe.rlwe_ciphertext<encoding = #polynomial_evaluation_encoding, rlwe_params = #rlwe_params_L1_, underlying_type = i16>
!rlwe_ct_L2_ = !lwe.rlwe_ciphertext<encoding = #polynomial_evaluation_encoding, rlwe_params = #rlwe_params_L2_, underlying_type = tensor<8xi16>>
!rlwe_ct_L2_D3_ = !lwe.rlwe_ciphertext<encoding = #polynomial_evaluation_encoding, rlwe_params = #rlwe_params_L2_D3_, underlying_type = tensor<8xi16>>
module {
func.func @dot_product(%arg0: !openfhe.crypto_context, %arg1: !tensor_ct, %arg2: !tensor_ct) -> !scalar_ct {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c7 = arith.constant 7 : index
%0 = openfhe.mul_no_relin %arg0, %arg1, %arg2 : (!openfhe.crypto_context, !tensor_ct, !tensor_ct) -> !mul_ct
%1 = openfhe.relin %arg0, %0 : (!openfhe.crypto_context, !mul_ct) -> !tensor_ct
%2 = arith.index_cast %c4 : index to i64
%3 = openfhe.rot %arg0, %1, %2 : (!openfhe.crypto_context, !tensor_ct, i64) -> !tensor_ct
%4 = openfhe.add %arg0, %1, %3 : (!openfhe.crypto_context, !tensor_ct, !tensor_ct) -> !tensor_ct
%5 = arith.index_cast %c2 : index to i64
%6 = openfhe.rot %arg0, %4, %5 : (!openfhe.crypto_context, !tensor_ct, i64) -> !tensor_ct
%7 = openfhe.add %arg0, %4, %6 : (!openfhe.crypto_context, !tensor_ct, !tensor_ct) -> !tensor_ct
%8 = arith.index_cast %c1 : index to i64
%9 = openfhe.rot %arg0, %7, %8 : (!openfhe.crypto_context, !tensor_ct, i64) -> !tensor_ct
%10 = openfhe.add %arg0, %7, %9 : (!openfhe.crypto_context, !tensor_ct, !tensor_ct) -> !tensor_ct
%cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi16>
%11 = lwe.rlwe_encode %cst {encoding = #lwe.polynomial_evaluation_encoding<cleartext_start = 16, cleartext_bitwidth = 16>, ring = #_polynomial.ring<cmod=463187969, ideal=#_polynomial.polynomial<1 + x**8>>} : tensor<8xi16> -> !tensor_plaintext
%12 = openfhe.mul_plain %arg0, %10, %11 : (!openfhe.crypto_context, !tensor_ct, !tensor_plaintext) -> !tensor_ct
%13 = arith.index_cast %c7 : index to i64
%14 = openfhe.rot %arg0, %12, %13 : (!openfhe.crypto_context, !tensor_ct, i64) -> !tensor_ct
%15 = lwe.reinterpret_underlying_type %14 : !tensor_ct to !scalar_ct
return %15 : !scalar_ct
func.func @dot_product(%arg0: !openfhe.crypto_context, %arg1: !rlwe_ct_L2_, %arg2: !rlwe_ct_L2_) -> !rlwe_ct_L0_ {
%cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64>
%0 = openfhe.mul_no_relin %arg0, %arg1, %arg2 : (!openfhe.crypto_context, !rlwe_ct_L2_, !rlwe_ct_L2_) -> !rlwe_ct_L2_D3_
%1 = openfhe.relin %arg0, %0 : (!openfhe.crypto_context, !rlwe_ct_L2_D3_) -> !rlwe_ct_L2_
%2 = openfhe.rot %arg0, %1 {index = 4 : index} : (!openfhe.crypto_context, !rlwe_ct_L2_) -> !rlwe_ct_L2_
%3 = openfhe.add %arg0, %1, %2 : (!openfhe.crypto_context, !rlwe_ct_L2_, !rlwe_ct_L2_) -> !rlwe_ct_L2_
%4 = openfhe.rot %arg0, %3 {index = 2 : index} : (!openfhe.crypto_context, !rlwe_ct_L2_) -> !rlwe_ct_L2_
%5 = openfhe.add %arg0, %3, %4 : (!openfhe.crypto_context, !rlwe_ct_L2_, !rlwe_ct_L2_) -> !rlwe_ct_L2_
%6 = openfhe.rot %arg0, %5 {index = 1 : index} : (!openfhe.crypto_context, !rlwe_ct_L2_) -> !rlwe_ct_L2_
%7 = openfhe.add %arg0, %5, %6 : (!openfhe.crypto_context, !rlwe_ct_L2_, !rlwe_ct_L2_) -> !rlwe_ct_L2_
%8 = openfhe.mod_reduce %arg0, %7 : (!openfhe.crypto_context, !rlwe_ct_L2_) -> !rlwe_ct_L1_
%9 = openfhe.make_packed_plaintext %arg0, %cst : (!openfhe.crypto_context, tensor<8xi64>) -> !rlwe_pt_L1_
%10 = openfhe.mul_plain %arg0, %8, %9 : (!openfhe.crypto_context, !rlwe_ct_L1_, !rlwe_pt_L1_) -> !rlwe_ct_L1_
%11 = openfhe.rot %arg0, %10 {index = 7 : index} : (!openfhe.crypto_context, !rlwe_ct_L1_) -> !rlwe_ct_L1_
%12 = lwe.reinterpret_underlying_type %11 : !rlwe_ct_L1_ to !rlwe_ct_L1_1
%13 = openfhe.mod_reduce %arg0, %12 : (!openfhe.crypto_context, !rlwe_ct_L1_1) -> !rlwe_ct_L0_
return %13 : !rlwe_ct_L0_
}
func.func @dot_product__encrypt__arg0(%arg0: !openfhe.crypto_context, %arg1: tensor<8xi16>, %arg2: !openfhe.public_key) -> !tensor_ct
func.func @dot_product__encrypt__arg0(%arg0: !openfhe.crypto_context, %arg1: tensor<8xi16>, %arg2: !openfhe.public_key) -> !rlwe_ct_L2_ {
...
}
func.func @dot_product__encrypt__arg1(%arg0: !openfhe.crypto_context, %arg1: tensor<8xi16>, %arg2: !openfhe.public_key) -> !tensor_ct
func.func @dot_product__encrypt__arg1(%arg0: !openfhe.crypto_context, %arg1: tensor<8xi16>, %arg2: !openfhe.public_key) -> !rlwe_ct_L2_ {
...
}
func.func @dot_product__decrypt__result0(%arg0: !openfhe.crypto_context, %arg1: !scalar_ct, %arg2: !openfhe.private_key) -> i16 {
func.func @dot_product__decrypt__result0(%arg0: !openfhe.crypto_context, %arg1: !rlwe_ct_L0_, %arg2: !openfhe.private_key) -> i16 {
...
}
func.func @dot_product__generate_crypto_context() -> !openfhe.crypto_context {
...
}
func.func @dot_product__configure_crypto_context(%arg0: !openfhe.crypto_context, %arg1: !openfhe.private_key) -> !openfhe.crypto_context {
...
}
}
Expand All @@ -226,16 +243,20 @@ The results:

using namespace lbcrypto;
using CiphertextT = ConstCiphertext<DCRTPoly>;
using CCParamsT = CCParams<CryptoContextBGVRNS>;
using CryptoContextT = CryptoContext<DCRTPoly>;
using EvalKeyT = EvalKey<DCRTPoly>;
using PlaintextT = Plaintext;
using PrivateKeyT = PrivateKey<DCRTPoly>;
using PublicKeyT = PublicKey<DCRTPoly>;

CiphertextT dot_product(CryptoContextT v0, CiphertextT v1, CiphertextT v2);
CiphertextT dot_product__encrypt__arg0(CryptoContextT v24, std::vector<int16_t> v25, PublicKeyT v26);
CiphertextT dot_product__encrypt__arg1(CryptoContextT v29, std::vector<int16_t> v30, PublicKeyT v31);
int16_t dot_product__decrypt__result0(CryptoContextT v34, CiphertextT v35, PrivateKeyT v36);
CiphertextT dot_product__encrypt__arg0(CryptoContextT v18, std::vector<int16_t> v19, PublicKeyT v20);
CiphertextT dot_product__encrypt__arg1(CryptoContextT v24, std::vector<int16_t> v25, PublicKeyT v26);
int16_t dot_product__decrypt__result0(CryptoContextT v30, CiphertextT v31, PrivateKeyT v32);
CryptoContextT dot_product__generate_crypto_context();
CryptoContextT dot_product__configure_crypto_context(CryptoContextT v37, PrivateKeyT v38);


// heir_output.cpp
#include "src/pke/include/openfhe.h" // from @openfhe
Expand All @@ -249,29 +270,29 @@ using PrivateKeyT = PrivateKey<DCRTPoly>;
using PublicKeyT = PublicKey<DCRTPoly>;

CiphertextT dot_product(CryptoContextT v0, CiphertextT v1, CiphertextT v2) {
size_t v3 = 1;
size_t v4 = 2;
size_t v5 = 4;
size_t v6 = 7;
const auto& v7 = v0->EvalMultNoRelin(v1, v2);
const auto& v8 = v0->Relinearize(v7);
int64_t v9 = static_cast<int64_t>(v5);
const auto& v10 = v0->EvalRotate(v8, v9);
const auto& v11 = v0->EvalAdd(v8, v10);
int64_t v12 = static_cast<int64_t>(v4);
const auto& v13 = v0->EvalRotate(v11, v12);
const auto& v14 = v0->EvalAdd(v11, v13);
int64_t v15 = static_cast<int64_t>(v3);
const auto& v16 = v0->EvalRotate(v14, v15);
const auto& v17 = v0->EvalAdd(v14, v16);
std::vector<int16_t> v18 = {0, 0, 0, 0, 0, 0, 0, 1};
std::vector<int64_t> v18_cast(std::begin(v18), std::end(v18));
const auto& v19 = v0->MakePackedPlaintext(v18_cast);
const auto& v20 = v0->EvalMult(v17, v19);
int64_t v21 = static_cast<int64_t>(v6);
const auto& v22 = v0->EvalRotate(v20, v21);
const auto& v23 = v22;
return v23;
std::vector<int64_t> v3 = {0, 0, 0, 0, 0, 0, 0, 1};
const auto& v4 = v0->EvalMultNoRelin(v1, v2);
const auto& v5 = v0->Relinearize(v4);
const auto& v6 = v0->EvalRotate(v5, 4);
const auto& v7 = v0->EvalAdd(v5, v6);
const auto& v8 = v0->EvalRotate(v7, 2);
const auto& v9 = v0->EvalAdd(v7, v8);
const auto& v10 = v0->EvalRotate(v9, 1);
const auto& v11 = v0->EvalAdd(v9, v10);
const auto& v12 = v0->ModReduce(v11);
auto v3_filled_n = v0->GetCryptoParameters()->GetElementParams()->GetRingDimension() / 2;
auto v3_filled = v3;
v3_filled.clear();
v3_filled.reserve(v3_filled_n);
for (auto i = 0; i < v3_filled_n; ++i) {
v3_filled.push_back(v3[i % v3.size()]);
}
const auto& v13 = v0->MakePackedPlaintext(v3_filled);
const auto& v14 = v0->EvalMult(v12, v13);
const auto& v15 = v0->EvalRotate(v14, 7);
const auto& v16 = v15;
const auto& v17 = v0->ModReduce(v16);
return v17;
}
CiphertextT dot_product__encrypt__arg0(CryptoContextT v24, std::vector<int16_t> v25, PublicKeyT v26) {
...
Expand All @@ -282,6 +303,13 @@ CiphertextT dot_product__encrypt__arg1(CryptoContextT v29, std::vector<int16_t>
int16_t dot_product__decrypt__result0(CryptoContextT v34, CiphertextT v35, PrivateKeyT v36) {
...
}
CryptoContextT dot_product__generate_crypto_context() {
...
}
CryptoContextT dot_product__configure_crypto_context(CryptoContextT v37, PrivateKeyT v38) {
...
}

```
At this point we can compile the program as we would a normal OpenFHE program.
Expand Down
21 changes: 21 additions & 0 deletions lib/Analysis/DimensionAnalysis/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "DimensionAnalysis",
srcs = ["DimensionAnalysis.cpp"],
hdrs = ["DimensionAnalysis.h"],
deps = [
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)
127 changes: 127 additions & 0 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {

LogicalResult DimensionAnalysis::visitOperation(
Operation *op, ArrayRef<const DimensionLattice *> operands,
ArrayRef<DimensionLattice *> results) {
auto propagate = [&](Value value, const DimensionState &state) {
auto *lattice = getLatticeElement(value);
ChangeResult changed = lattice->join(state);
propagateIfChanged(lattice, changed);
};

auto ensureSecretness = [&](Operation *op, Value value) -> bool {
// create dependency on SecretnessAnalysis
auto *lattice =
getOrCreateFor<SecretnessLattice>(getProgramPointAfter(op), value);
if (!lattice->getValue().isInitialized()) {
return false;
}
return lattice->getValue().getSecretness();
};

llvm::TypeSwitch<Operation &>(*op)
.Case<secret::GenericOp>([&](auto genericOp) {
Block *body = genericOp.getBody();
for (auto i = 0; i != body->getNumArguments(); ++i) {
auto blockArg = body->getArgument(i);
propagate(blockArg, DimensionState(2));
}
})
.Case<mgmt::RelinearizeOp>([&](auto relinearizeOp) {
// implicitly ensure that the operand is secret
propagate(relinearizeOp.getResult(), DimensionState(2));
})
.Default([&](auto &op) {
if (op.getNumResults() == 0) {
return;
}

// condition on result secretness
auto secretness = ensureSecretness(&op, op.getResult(0));
if (!secretness) {
return;
}

auto isMul = false;
if (isa<arith::MulIOp, arith::MulFOp>(op)) {
isMul = true;
}

auto dimensionResult = 0;
auto operandSecretNum = 0;
for (const auto *operand : operands) {
auto secretness = ensureSecretness(&op, operand->getAnchor());
// pt/ct default
auto dimension = 2;
bool operandIsSecret = false;
if (secretness) {
if (!operand->getValue().isInitialized()) {
return;
}
// ct
operandIsSecret = true;
operandSecretNum += 1;
dimension = operand->getValue().getDimension();
}

if (isMul && operandIsSecret) {
dimensionResult += dimension;
} else {
dimensionResult = std::max(dimensionResult, dimension);
}
}
// tensor product
if (isMul && operandSecretNum == 2) {
dimensionResult -= 1;
}

for (auto result : op.getResults()) {
propagate(result, DimensionState(dimensionResult));
}
});
return success();
}

void annotateDimension(Operation *top, DataFlowSolver *solver) {
auto getIntegerAttr = [&](int dimension) {
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), dimension);
};

auto getDimension = [&](Value value) {
return solver->lookupState<DimensionLattice>(value)
->getValue()
.getDimension();
};

top->walk<WalkOrder::PreOrder>([&](secret::GenericOp genericOp) {
for (auto blockArg : genericOp.getBody()->getArguments()) {
genericOp.setArgAttr(blockArg.getArgNumber(), "dimension",
getIntegerAttr(getDimension(blockArg)));
}

genericOp.getBody()->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op->getNumResults() == 0) {
return;
}
op->setAttr("dimension", getIntegerAttr(getDimension(op->getResult(0))));
});
});
}

} // namespace heir
} // namespace mlir
Loading

0 comments on commit 854a878

Please sign in to comment.