Skip to content

Commit

Permalink
relax constraints on ctxt dimension in add/sub/mul
Browse files Browse the repository at this point in the history
Removes constraints that ctxt dimension must match from
`bgv.add`,`bgv.sub`,`bgv.mul`,
`ckks.add`,`ckks.sub`,`ckks.mul`,
`lwe.radd`,`lwe.rsub`,`lwe.rmul`,
`openfhe.add`,`openfhe.sub`,`openfhe.mul_no_relin`.
  • Loading branch information
AlexanderViand-Intel committed Jan 30, 2025
1 parent ad59e8f commit 7983559
Show file tree
Hide file tree
Showing 22 changed files with 148 additions and 50 deletions.
12 changes: 12 additions & 0 deletions lib/Dialect/BGV/IR/BGVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ LogicalResult ModulusSwitchOp::verify() {
return verifyModulusSwitchOrRescaleOp(this);
}

LogicalResult AddOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, AddOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
return inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes);
}

LogicalResult SubOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, SubOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
return inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes);
}

LogicalResult MulOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, MulOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
Expand Down
10 changes: 3 additions & 7 deletions lib/Dialect/BGV/IR/BGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class BGV_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
);
}

def BGV_AddOp : BGV_Op<"add", [Pure, Commutative, SameOperandsAndResultType]> {
def BGV_AddOp : BGV_Op<"add", [Pure, Commutative, SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Addition operation between ciphertexts.";

let arguments = (ins
Expand All @@ -60,15 +60,13 @@ def BGV_AddOp : BGV_Op<"add", [Pure, Commutative, SameOperandsAndResultType]> {
let results = (outs
NewLWECiphertext:$output
);

let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ;
}

def BGV_AddPlainOp : BGV_CiphertextPlaintextOp<"add_plain"> {
let summary = "Addition operation between ciphertext-plaintext.";
}

def BGV_SubOp : BGV_Op<"sub", [Pure, SameOperandsAndResultType]> {
def BGV_SubOp : BGV_Op<"sub", [Pure, SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Subtraction operation between ciphertexts.";

let arguments = (ins
Expand All @@ -79,15 +77,13 @@ def BGV_SubOp : BGV_Op<"sub", [Pure, SameOperandsAndResultType]> {
let results = (outs
NewLWECiphertext:$output
);

let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ;
}

def BGV_SubPlainOp : BGV_CiphertextPlaintextOp<"sub_plain"> {
let summary = "Subtraction operation between ciphertext-plaintext.";
}

def BGV_MulOp : BGV_Op<"mul", [Pure, Commutative, SameOperandsAndResultRings, SameTypeOperands, InferTypeOpAdaptor]> {
def BGV_MulOp : BGV_Op<"mul", [Pure, Commutative, SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Multiplication operation between ciphertexts.";

let arguments = (ins
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/CKKS/IR/CKKSDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ LogicalResult RescaleOp::verify() {
return verifyModulusSwitchOrRescaleOp(this);
}

LogicalResult AddOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, AddOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
return inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes);
}

LogicalResult SubOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, SubOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
return inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes);
}

LogicalResult MulOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, MulOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
Expand Down
10 changes: 3 additions & 7 deletions lib/Dialect/CKKS/IR/CKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class CKKS_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
);
}

def CKKS_AddOp : CKKS_Op<"add", [Pure, Commutative, SameOperandsAndResultType]> {
def CKKS_AddOp : CKKS_Op<"add", [Pure, Commutative, SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Addition operation between ciphertexts.";

let arguments = (ins
Expand All @@ -59,15 +59,13 @@ def CKKS_AddOp : CKKS_Op<"add", [Pure, Commutative, SameOperandsAndResultType]>
let results = (outs
NewLWECiphertext:$output
);

let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ;
}

def CKKS_AddPlainOp : CKKS_CiphertextPlaintextOp<"add_plain"> {
let summary = "Addition operation between ciphertext-plaintext.";
}

def CKKS_SubOp : CKKS_Op<"sub", [SameOperandsAndResultType]> {
def CKKS_SubOp : CKKS_Op<"sub", [Pure, SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Subtraction operation between ciphertexts.";

let arguments = (ins
Expand All @@ -78,15 +76,13 @@ def CKKS_SubOp : CKKS_Op<"sub", [SameOperandsAndResultType]> {
let results = (outs
NewLWECiphertext:$output
);

let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ;
}

def CKKS_SubPlainOp : CKKS_CiphertextPlaintextOp<"sub_plain"> {
let summary = "Subtraction operation between ciphertext-plaintext.";
}

def CKKS_MulOp : CKKS_Op<"mul", [Pure, Commutative, SameOperandsAndResultRings, SameTypeOperands, InferTypeOpAdaptor]> {
def CKKS_MulOp : CKKS_Op<"mul", [Pure, Commutative, SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Multiplication operation between ciphertexts.";

let arguments = (ins
Expand Down
18 changes: 18 additions & 0 deletions lib/Dialect/FHEHelpers.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LIB_DIALECT_FHEHELPERS_H_
#define LIB_DIALECT_FHEHELPERS_H_

#include <algorithm>
#include <cstddef>

#include "lib/Dialect/LWE/IR/LWEAttributes.h"
Expand Down Expand Up @@ -125,6 +126,23 @@ LogicalResult verifyModulusSwitchOrRescaleOp(Op* op) {
return success();
}

template <typename Adaptor>
LogicalResult inferAddOpReturnTypes(
MLIRContext* ctx, Adaptor adaptor,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto x = cast<lwe::NewLWECiphertextType>(adaptor.getLhs().getType());
auto y = cast<lwe::NewLWECiphertextType>(adaptor.getRhs().getType());
auto newDim = std::max(x.getCiphertextSpace().getSize(),
y.getCiphertextSpace().getSize());
inferredReturnTypes.push_back(lwe::NewLWECiphertextType::get(
ctx, x.getApplicationData(), x.getPlaintextSpace(),
lwe::CiphertextSpaceAttr::get(ctx, x.getCiphertextSpace().getRing(),
x.getCiphertextSpace().getEncryptionType(),
newDim),
x.getKey(), x.getModulusChain()));
return success();
}

template <typename Adaptor>
LogicalResult inferMulOpReturnTypes(
MLIRContext* ctx, Adaptor adaptor,
Expand Down
34 changes: 34 additions & 0 deletions lib/Dialect/LWE/IR/LWEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,40 @@ LogicalResult RMulOp::verify() {
return success();
}

LogicalResult RAddOp::inferReturnTypes(
MLIRContext* ctx, std::optional<Location>, RAddOp::Adaptor adaptor,
SmallVectorImpl<Type>& inferredReturnTypes) {
// NOT using FHEHelpers.h here because cyclic dependency
auto x = cast<lwe::NewLWECiphertextType>(adaptor.getLhs().getType());
auto y = cast<lwe::NewLWECiphertextType>(adaptor.getRhs().getType());
auto newDim = std::max(x.getCiphertextSpace().getSize(),
y.getCiphertextSpace().getSize());
inferredReturnTypes.push_back(lwe::NewLWECiphertextType::get(
ctx, x.getApplicationData(), x.getPlaintextSpace(),
lwe::CiphertextSpaceAttr::get(ctx, x.getCiphertextSpace().getRing(),
x.getCiphertextSpace().getEncryptionType(),
newDim),
x.getKey(), x.getModulusChain()));
return success();
}

LogicalResult RSubOp::inferReturnTypes(
MLIRContext* ctx, std::optional<Location>, RSubOp::Adaptor adaptor,
SmallVectorImpl<Type>& inferredReturnTypes) {
// NOT using FHEHelpers.h here because cyclic dependency
auto x = cast<lwe::NewLWECiphertextType>(adaptor.getLhs().getType());
auto y = cast<lwe::NewLWECiphertextType>(adaptor.getRhs().getType());
auto newDim = std::max(x.getCiphertextSpace().getSize(),
y.getCiphertextSpace().getSize());
inferredReturnTypes.push_back(lwe::NewLWECiphertextType::get(
ctx, x.getApplicationData(), x.getPlaintextSpace(),
lwe::CiphertextSpaceAttr::get(ctx, x.getCiphertextSpace().getRing(),
x.getCiphertextSpace().getEncryptionType(),
newDim),
x.getKey(), x.getModulusChain()));
return success();
}

LogicalResult RMulOp::inferReturnTypes(
MLIRContext* ctx, std::optional<Location>, RMulOp::Adaptor adaptor,
SmallVectorImpl<Type>& inferredReturnTypes) {
Expand Down
12 changes: 9 additions & 3 deletions lib/Dialect/LWE/IR/LWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,21 @@ def LWE_AddOp : LWE_BinOp<"add", [SameOperandsAndResultType,Commutative]> {
let summary = "Add two LWE ciphertexts";
}

def LWE_RAddOp : LWE_BinOp<"radd", [SameOperandsAndResultType,Commutative]> {
def LWE_RAddOp : LWE_BinOp<"radd", [SameOperandsAndResultRings, InferTypeOpAdaptor, Commutative]> {
let summary = "Add two RLWE ciphertexts";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];
}

def LWE_RSubOp : LWE_BinOp<"rsub", [SameOperandsAndResultType]> {
def LWE_RSubOp : LWE_BinOp<"rsub", [SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Subtract two RLWE ciphertexts";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];
}

def LWE_RMulOp : LWE_BinOp<"rmul", [SameTypeOperands,InferTypeOpAdaptor, Commutative]> {
def LWE_RMulOp : LWE_BinOp<"rmul", [SameOperandsAndResultRings, InferTypeOpAdaptor, Commutative]> {
let summary = "Multiplies two RLWE ciphertexts";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Openfhe/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ cc_library(
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect:FHEHelpers",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@llvm-project//llvm:Support",
Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/Openfhe/IR/OpenfheDialect.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h"

#include "lib/Dialect/FHEHelpers.h"
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/Openfhe/IR/OpenfheDialect.cpp.inc"
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
Expand Down Expand Up @@ -43,6 +44,18 @@ LogicalResult MakeCKKSPackedPlaintextOp::verify() {
return success();
}

LogicalResult AddOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, AddOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
return inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes);
}

LogicalResult SubOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, SubOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
return inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes);
}

} // namespace openfhe
} // namespace heir
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Dialect/Openfhe/IR/OpenfheOps.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_H_
#define LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_H_

#include "lib/Dialect/LWE/IR/LWETraits.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h"
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
Expand Down
21 changes: 15 additions & 6 deletions lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ include "OpenfheDialect.td"
include "OpenfheTypes.td"

include "lib/Dialect/LWE/IR/LWETypes.td"
include "lib/Dialect/LWE/IR/LWETraits.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

class Openfhe_Op<string mnemonic, list<Trait> traits = []> :
Expand Down Expand Up @@ -46,8 +48,7 @@ class Openfhe_UnaryOp<string mnemonic, list<Trait> traits = []>

class Openfhe_BinaryOp<string mnemonic, list<Trait> traits = []>
: Openfhe_Op<mnemonic, traits # [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>,
Pure
]>{
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Expand Down Expand Up @@ -153,8 +154,16 @@ def DecryptOp : Openfhe_Op<"decrypt", [Pure]> {
let results = (outs NewLWEPlaintext:$plaintext);
}

def AddOp : Openfhe_BinaryOp<"add"> { let summary = "OpenFHE add operation of two ciphertexts."; }
def SubOp : Openfhe_BinaryOp<"sub"> { let summary = "OpenFHE sub operation of two ciphertexts."; }
def AddOp : Openfhe_BinaryOp<"add",
[SameOperandsAndResultRings,
InferTypeOpAdaptor]> {
let summary = "OpenFHE add operation of two ciphertexts.";
}
def SubOp : Openfhe_BinaryOp<"sub",
[SameOperandsAndResultRings,
InferTypeOpAdaptor]> {
let summary = "OpenFHE sub operation of two ciphertexts.";
}

def AddPlainOp : Openfhe_Op<"add_plain",[
Pure,
Expand Down Expand Up @@ -182,9 +191,9 @@ def SubPlainOp : Openfhe_Op<"sub_plain",[
let results = (outs NewLWECiphertext:$output);
}

def MulOp : Openfhe_BinaryOp<"mul"> { let summary = "OpenFHE mul operation of two ciphertexts with relinearization."; }
def MulOp : Openfhe_BinaryOp<"mul", [AllTypesMatch<["lhs", "rhs", "output"]>]> { let summary = "OpenFHE mul operation of two ciphertexts with relinearization."; }

def MulNoRelinOp : Openfhe_Op<"mul_no_relin", [Pure, AllTypesMatch<["lhs", "rhs"]>]> {
def MulNoRelinOp : Openfhe_Op<"mul_no_relin", [Pure, SameOperandsAndResultRings]> {
let summary = "OpenFHE mul operation of two ciphertexts without relinearization.";
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module {
// CHECK-SAME: ([[C:%.+]]: [[S:.*evaluator]], [[X:%.+]]: [[T:!lattigo.rlwe.ciphertext]], [[Y:%.+]]: [[T]])
func.func @test_ops(%x : !ct, %y : !ct) {
// CHECK: %[[v1:.*]] = lattigo.bgv.add [[C]], %[[x:.*]], %[[y:.*]]: ([[S]], [[T]], [[T]]) -> [[T]]
%add = bgv.add %x, %y : !ct
%add = bgv.add %x, %y : (!ct, !ct) -> !ct
// CHECK: %[[mul:.*]] = lattigo.bgv.mul [[C]], %[[x]], %[[y]]: ([[S]], [[T]], [[T]]) -> [[T]]
%mul = bgv.mul %x, %y : (!ct, !ct) -> !ct1
// CHECK: %[[relin:.*]] = lattigo.bgv.relinearize [[C]], %[[mul]] : ([[S]], [[T]]) -> [[T]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ module {
// CHECK: %[[v1:.*]] = openfhe.negate [[C]], %[[x1:.*]] : ([[S]], [[T]]) -> [[T]]
%negate = bgv.negate %x : !ct
// CHECK: %[[v2:.*]] = openfhe.add [[C]], %[[x2:.*]], %[[y2:.*]]: ([[S]], [[T]], [[T]]) -> [[T]]
%add = bgv.add %x, %y : !ct
%add = bgv.add %x, %y : (!ct, !ct) -> !ct
// CHECK: %[[v3:.*]] = openfhe.sub [[C]], %[[x3:.*]], %[[y3:.*]]: ([[S]], [[T]], [[T]]) -> [[T]]
%sub = bgv.sub %x, %y : !ct
%sub = bgv.sub %x, %y : (!ct, !ct) -> !ct
// CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T2:.*]]
%mul = bgv.mul %x, %y : (!ct, !ct) -> !ct_D3
// CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]] {index = 4 : i64}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func.func @linear_polynomial(%arg0: !ct_ty, %arg1: !ct_ty, %arg2: !ct_ty, %arg3:
// CHECK: %[[v1:.*]] = openfhe.relin %[[cc]], %[[v0]]
%1 = bgv.relinearize %0 {from_basis = array<i32: 0, 1, 2>, to_basis = array<i32: 0, 1>} : !ct_sq_ty -> !ct_ty
// CHECK: %[[v2:.*]] = openfhe.sub %[[cc]], %[[arg3]], %[[v1]]
%2 = bgv.sub %arg3, %1 : !ct_ty
%2 = bgv.sub %arg3, %1 : (!ct_ty, !ct_ty) -> !ct_ty
// CHECK: %[[v3:.*]] = openfhe.sub %[[cc]], %[[v2]], %[[arg1]]
%3 = bgv.sub %2, %arg1 : !ct_ty
%3 = bgv.sub %2, %arg1 : (!ct_ty, !ct_ty) -> !ct_ty
// CHECK: return %[[v3]]
return %3 : !ct_ty
}
4 changes: 2 additions & 2 deletions tests/Dialect/BGV/IR/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
module {
// CHECK-LABEL: @test_multiply
func.func @test_multiply(%arg0 : !ct, %arg1: !ct) -> !ct {
%add = bgv.add %arg0, %arg1 : !ct
%sub = bgv.sub %arg0, %arg1 : !ct
%add = bgv.add %arg0, %arg1 : (!ct, !ct) -> !ct
%sub = bgv.sub %arg0, %arg1 : (!ct, !ct) -> !ct
%neg = bgv.negate %arg0 : !ct

%0 = bgv.mul %arg0, %arg1 : (!ct, !ct) -> !ct1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ module {
// CHECK: %[[v1:.*]] = openfhe.negate [[C]], %[[x1:.*]] : ([[S]], [[T]]) -> [[T]]
%negate = ckks.negate %x : !ct
// CHECK: %[[v2:.*]] = openfhe.add [[C]], %[[x2:.*]], %[[y2:.*]]: ([[S]], [[T]], [[T]]) -> [[T]]
%add = ckks.add %x, %y : !ct
%add = ckks.add %x, %y : (!ct, !ct) -> !ct
// CHECK: %[[v3:.*]] = openfhe.sub [[C]], %[[x3:.*]], %[[y3:.*]]: ([[S]], [[T]], [[T]]) -> [[T]]
%sub = ckks.sub %x, %y : !ct
%sub = ckks.sub %x, %y : (!ct, !ct) -> !ct
// CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T2:.*]]
%mul = ckks.mul %x, %y : (!ct, !ct) -> !ct_D3
// CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]] {index = 4 : i64}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func.func @linear_polynomial(%arg0: !ct_ty, %arg1: !ct_ty, %arg2: !ct_ty, %arg3:
// CHECK: %[[v1:.*]] = openfhe.relin %[[cc]], %[[v0]]
%1 = ckks.relinearize %0 {from_basis = array<i32: 0, 1, 2>, to_basis = array<i32: 0, 1>} : !ct_sq_ty -> !ct_ty
// CHECK: %[[v2:.*]] = openfhe.sub %[[cc]], %[[arg3]], %[[v1]]
%2 = ckks.sub %arg3, %1 : !ct_ty
%2 = ckks.sub %arg3, %1 : (!ct_ty, !ct_ty) -> !ct_ty
// CHECK: %[[v3:.*]] = openfhe.sub %[[cc]], %[[v2]], %[[arg1]]
%3 = ckks.sub %2, %arg1 : !ct_ty
%3 = ckks.sub %2, %arg1 : (!ct_ty, !ct_ty) -> !ct_ty
// CHECK: return %[[v3]]
return %3 : !ct_ty
}
4 changes: 2 additions & 2 deletions tests/Dialect/CKKS/IR/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
module {
// CHECK-LABEL: @test_multiply
func.func @test_multiply(%arg0 : !ct, %arg1: !ct) -> !ct {
%add = ckks.add %arg0, %arg1 : !ct
%sub = ckks.sub %arg0, %arg1 : !ct
%add = ckks.add %arg0, %arg1 : (!ct, !ct) -> !ct
%sub = ckks.sub %arg0, %arg1 : (!ct, !ct) -> !ct
%neg = ckks.negate %arg0 : !ct

// CHECK: ring = <coefficientType = !rns.rns<!mod_arith.int<1095233372161 : i64>, !mod_arith.int<1032955396097 : i64>>, polynomialModulus = <1 + x**1024>>
Expand Down
Loading

0 comments on commit 7983559

Please sign in to comment.