Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bgv-to-lattigo: lower client-interface and plain op #1226

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 79 additions & 12 deletions lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,26 @@
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project

namespace mlir::heir::bgv {

#define GEN_PASS_DEF_BGVTOLATTIGO
#include "lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.h.inc"

using ConvertAddOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, AddOp, lattigo::BGVAddOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RAddOp, lattigo::BGVAddOp>;
using ConvertSubOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, SubOp, lattigo::BGVSubOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RSubOp, lattigo::BGVSubOp>;
using ConvertMulOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, MulOp, lattigo::BGVMulOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RMulOp, lattigo::BGVMulOp>;
using ConvertAddPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
AddPlainOp, lattigo::BGVAddOp>;
using ConvertSubPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
SubPlainOp, lattigo::BGVSubOp>;
using ConvertMulPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
MulPlainOp, lattigo::BGVMulOp>;

using ConvertRelinOp =
ConvertRlweUnaryOp<lattigo::BGVEvaluatorType, RelinearizeOp,
lattigo::BGVRelinearizeOp>;
Expand All @@ -44,6 +52,33 @@ using ConvertModulusSwitchOp =
using ConvertRotateOp = ConvertRlweRotateOp<lattigo::BGVEvaluatorType, RotateOp,
lattigo::BGVRotateColumnsOp>;

using ConvertEncryptOp =
ConvertRlweUnaryOp<lattigo::RLWEEncryptorType, lwe::RLWEEncryptOp,
lattigo::RLWEEncryptOp>;
using ConvertDecryptOp =
ConvertRlweUnaryOp<lattigo::RLWEDecryptorType, lwe::RLWEDecryptOp,
lattigo::RLWEDecryptOp>;
using ConvertEncodeOp =
ConvertRlweEncodeOp<lattigo::BGVEncoderType, lattigo::BGVParameterType,
lwe::RLWEEncodeOp, lattigo::BGVEncodeOp,
lattigo::BGVNewPlaintextOp>;
using ConvertDecodeOp =
ConvertRlweDecodeOp<lattigo::BGVEncoderType, lwe::RLWEDecodeOp,
lattigo::BGVDecodeOp, arith::ConstantOp>;

struct ConvertLWEReinterpretUnderlyingType
: public OpConversionPattern<lwe::ReinterpretUnderlyingTypeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
lwe::ReinterpretUnderlyingTypeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// erase reinterpret underlying
rewriter.replaceOp(op, adaptor.getOperands()[0].getDefiningOp());
return success();
}
};

struct BGVToLattigo : public impl::BGVToLattigoBase<BGVToLattigo> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -53,30 +88,62 @@ struct BGVToLattigo : public impl::BGVToLattigoBase<BGVToLattigo> {
ConversionTarget target(*context);
target.addLegalDialect<lattigo::LattigoDialect>();
target.addIllegalDialect<bgv::BGVDialect>();
target.addIllegalOp<lwe::RLWEEncryptOp, lwe::RLWEDecryptOp,
lwe::RLWEEncodeOp>();
target
.addIllegalOp<lwe::RLWEEncryptOp, lwe::RLWEDecryptOp, lwe::RLWEEncodeOp,
lwe::RLWEDecodeOp, lwe::RAddOp, lwe::RSubOp, lwe::RMulOp,
lwe::ReinterpretUnderlyingTypeOp>();

RewritePatternSet patterns(context);
addStructuralConversionPatterns(typeConverter, patterns, target);

target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 &&
mlir::isa<lattigo::BGVEvaluatorType>(
*op.getFunctionType().getInputs().begin());
bool hasCryptoContextArg =
op.getFunctionType().getNumInputs() > 0 &&
containsArgumentOfType<
lattigo::BGVEvaluatorType, lattigo::BGVEncoderType,
lattigo::RLWEEncryptorType, lattigo::RLWEDecryptorType>(op);

return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody()) &&
(!containsDialects<lwe::LWEDialect, bgv::BGVDialect>(op) ||
hasCryptoContextArg);
});

patterns.add<AddEvaluatorArg<bgv::BGVDialect, lattigo::BGVEvaluatorType>,
ConvertAddOp, ConvertSubOp, ConvertMulOp, ConvertRelinOp,
ConvertModulusSwitchOp, ConvertRotateOp>(typeConverter,
context);
std::vector<std::pair<Type, OpPredicate>> evaluators;

// param/encoder also needed for the main func
// as there might (not) be ct-pt operations
evaluators = {
{lattigo::BGVEvaluatorType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::BGVParameterType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::BGVEncoderType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::RLWEEncryptorType::get(context),
containsAnyOperations<lwe::RLWEEncryptOp>},
{lattigo::RLWEDecryptorType::get(context),
containsAnyOperations<lwe::RLWEDecryptOp>},
};

patterns.add<AddEvaluatorArg>(context, evaluators);

patterns.add<ConvertAddOp, ConvertSubOp, ConvertMulOp, ConvertAddPlainOp,
ConvertSubPlainOp, ConvertMulPlainOp, ConvertRelinOp,
ConvertModulusSwitchOp, ConvertRotateOp, ConvertEncryptOp,
ConvertDecryptOp, ConvertEncodeOp, ConvertDecodeOp,
ConvertLWEReinterpretUnderlyingType>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
}

// remove unused key args from function types
// in favor of encryptor/decryptor
RewritePatternSet postPatterns(context);
postPatterns.add<RemoveKeyArg<lattigo::RLWESecretKeyType>>(context);
postPatterns.add<RemoveKeyArg<lattigo::RLWEPublicKeyType>>(context);
walkAndApplyPatterns(module, std::move(postPatterns));
}
};

Expand Down
177 changes: 167 additions & 10 deletions lib/Dialect/LWE/Conversions/RlweToLattigo/RlweToLattigo.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,36 @@ FailureOr<Value> getContextualEvaluator(Operation *op) {
return result.value();
}

template <typename Dialect, typename EvaluatorType>
struct AddEvaluatorArg : public OpConversionPattern<func::FuncOp> {
AddEvaluatorArg(mlir::MLIRContext *context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}
AddEvaluatorArg(mlir::MLIRContext *context,
const std::vector<std::pair<Type, OpPredicate>> &evaluators)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2),
evaluators(evaluators) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!containsDialects<lwe::LWEDialect, Dialect>(op)) {
return failure();
SmallVector<Type, 4> selectedEvaluators;

for (const auto &evaluator : evaluators) {
auto predicate = evaluator.second;
if (predicate(op)) {
selectedEvaluators.push_back(evaluator.first);
}
}

if (selectedEvaluators.empty()) {
return success();
}

auto evaluatorType = EvaluatorType::get(getContext());
FunctionType originalType = op.getFunctionType();
llvm::SmallVector<Type, 4> newTypes;
newTypes.reserve(originalType.getNumInputs() + 1);
newTypes.push_back(evaluatorType);
newTypes.reserve(originalType.getNumInputs() + selectedEvaluators.size());
for (auto evaluatorType : selectedEvaluators) {
newTypes.push_back(evaluatorType);
}
for (auto t : originalType.getInputs()) {
newTypes.push_back(t);
}
Expand All @@ -57,10 +68,61 @@ struct AddEvaluatorArg : public OpConversionPattern<func::FuncOp> {
op.setType(newFuncType);

Block &block = op.getBody().getBlocks().front();
block.insertArgument(&block.getArguments().front(), evaluatorType,
op.getLoc());
for (auto evaluatorType : llvm::reverse(selectedEvaluators)) {
block.insertArgument(&block.getArguments().front(), evaluatorType,
op.getLoc());
}
});
return success();
}

private:
std::vector<std::pair<Type, OpPredicate>> evaluators;
};

template <typename KeyType>
struct RemoveKeyArg : public OpConversionPattern<func::FuncOp> {
RemoveKeyArg(mlir::MLIRContext *context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<int, 1> keyArgIndices;
Block &block = op.getBody().getBlocks().front();
for (auto arg : block.getArguments()) {
if (mlir::isa<KeyType>(arg.getType()) && arg.getUses().empty()) {
keyArgIndices.push_back(arg.getArgNumber());
}
}

if (keyArgIndices.empty()) {
return success();
}

FunctionType originalType = op.getFunctionType();
llvm::SmallVector<Type, 4> newTypes;
newTypes.reserve(originalType.getNumInputs());
for (auto arg : block.getArguments()) {
if (llvm::is_contained(keyArgIndices, arg.getArgNumber())) {
continue;
}
newTypes.push_back(arg.getType());
}
auto newFuncType =
FunctionType::get(getContext(), newTypes, originalType.getResults());
rewriter.modifyOpInPlace(op, [&] {
op.setType(newFuncType);

Block &block = op.getBody().getBlocks().front();
for (auto arg : block.getArguments()) {
if (llvm::is_contained(keyArgIndices, arg.getArgNumber())) {
block.eraseArgument(arg.getArgNumber());
}
}
});
return success();
}
};
Expand Down Expand Up @@ -105,6 +167,25 @@ struct ConvertRlweBinOp : public OpConversionPattern<BinOp> {
}
};

template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp>
struct ConvertRlwePlainOp : public OpConversionPattern<PlainOp> {
using OpConversionPattern<PlainOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
PlainOp op, typename PlainOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;

Value evaluator = result.value();
rewriter.replaceOpWithNewOp<LattigoPlainOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getCiphertextInput(), adaptor.getPlaintextInput());
return success();
}
};

template <typename EvaluatorType, typename RlweRotateOp,
typename LattigoRotateOp>
struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
Expand All @@ -130,6 +211,82 @@ struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
}
};

template <typename EvaluatorType, typename ParamType, typename EncodeOp,
typename LattigoEncodeOp, typename AllocOp>
struct ConvertRlweEncodeOp : public OpConversionPattern<EncodeOp> {
using OpConversionPattern<EncodeOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
EncodeOp op, typename EncodeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;
Value evaluator = result.value();

FailureOr<Value> result2 =
getContextualEvaluator<ParamType>(op.getOperation());
if (failed(result2)) return result2;
Value params = result2.value();

auto alloc = rewriter.create<AllocOp>(
op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()),
params);

rewriter.replaceOpWithNewOp<LattigoEncodeOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getInput(), alloc);
return success();
}
};

template <typename EvaluatorType, typename DecodeOp, typename LattigoDecodeOp,
typename AllocOp>
struct ConvertRlweDecodeOp : public OpConversionPattern<DecodeOp> {
using OpConversionPattern<DecodeOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
DecodeOp op, typename DecodeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;
Value evaluator = result.value();

auto outputType = op.getOutput().getType();
RankedTensorType outputTensorType = dyn_cast<RankedTensorType>(outputType);
bool isScalar = false;
if (!outputTensorType) {
isScalar = true;
outputTensorType = RankedTensorType::get({1}, outputType);
}

APInt zero(getElementTypeOrSelf(outputType).getIntOrFloatBitWidth(), 0);

auto constant = DenseElementsAttr::get(outputTensorType, zero);

auto alloc =
rewriter.create<AllocOp>(op.getLoc(), outputTensorType, constant);

auto decodeOp = rewriter.create<LattigoDecodeOp>(
op.getLoc(), outputTensorType, evaluator, adaptor.getInput(), alloc);

// TODO(#1174): the sin of lwe.reinterpret_underlying_type
if (isScalar) {
SmallVector<Value, 1> indices;
auto index = rewriter.create<arith::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
indices.push_back(index);
auto extract = rewriter.create<tensor::ExtractOp>(
op.getLoc(), decodeOp.getResult(), indices);
rewriter.replaceOp(op, extract.getResult());
} else {
rewriter.replaceOp(op, decodeOp.getResult());
}
return success();
}
};

} // namespace mlir::heir

#endif // LIB_DIALECT_LWE_CONVERSIONS_RLWETOLATTIGOUTILS_RLWETOLATTIGO_H_
2 changes: 1 addition & 1 deletion lib/Dialect/Lattigo/IR/LattigoBGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Lattigo_BGVBinaryOp<string mnemonic> :
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertext:$rhs
Lattigo_RLWECiphertextOrPlaintext:$rhs
);
let results = (outs Lattigo_RLWECiphertext:$output);
}
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoRLWETypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,6 @@ def Lattigo_RLWECiphertext : Lattigo_RLWEType<"Ciphertext", "ciphertext"> {
let nameSuggestion = "ct";
}

def Lattigo_RLWECiphertextOrPlaintext : AnyTypeOf<[Lattigo_RLWECiphertext, Lattigo_RLWEPlaintext]>;

#endif // LIB_DIALECT_LATTIGO_IR_LATTIGORLWETYPES_TD_
Loading
Loading