Skip to content

Commit

Permalink
bgv-to-lattigo: lower client-interface and plain op
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Jan 14, 2025
1 parent a94f4cb commit e6c79ed
Show file tree
Hide file tree
Showing 18 changed files with 494 additions and 90 deletions.
91 changes: 79 additions & 12 deletions lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,26 @@
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project

namespace mlir::heir::bgv {

#define GEN_PASS_DEF_BGVTOLATTIGO
#include "lib/Dialect/BGV/Conversions/BGVToLattigo/BGVToLattigo.h.inc"

using ConvertAddOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, AddOp, lattigo::BGVAddOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RAddOp, lattigo::BGVAddOp>;
using ConvertSubOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, SubOp, lattigo::BGVSubOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RSubOp, lattigo::BGVSubOp>;
using ConvertMulOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, MulOp, lattigo::BGVMulOp>;
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RMulOp, lattigo::BGVMulOp>;
using ConvertAddPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
AddPlainOp, lattigo::BGVAddOp>;
using ConvertSubPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
SubPlainOp, lattigo::BGVSubOp>;
using ConvertMulPlainOp = ConvertRlwePlainOp<lattigo::BGVEvaluatorType,
MulPlainOp, lattigo::BGVMulOp>;

using ConvertRelinOp =
ConvertRlweUnaryOp<lattigo::BGVEvaluatorType, RelinearizeOp,
lattigo::BGVRelinearizeOp>;
Expand All @@ -44,6 +52,33 @@ using ConvertModulusSwitchOp =
using ConvertRotateOp = ConvertRlweRotateOp<lattigo::BGVEvaluatorType, RotateOp,
lattigo::BGVRotateColumnsOp>;

using ConvertEncryptOp =
ConvertRlweUnaryOp<lattigo::RLWEEncryptorType, lwe::RLWEEncryptOp,
lattigo::RLWEEncryptOp>;
using ConvertDecryptOp =
ConvertRlweUnaryOp<lattigo::RLWEDecryptorType, lwe::RLWEDecryptOp,
lattigo::RLWEDecryptOp>;
using ConvertEncodeOp =
ConvertRlweEncodeOp<lattigo::BGVEncoderType, lattigo::BGVParameterType,
lwe::RLWEEncodeOp, lattigo::BGVEncodeOp,
lattigo::BGVNewPlaintextOp>;
using ConvertDecodeOp =
ConvertRlweDecodeOp<lattigo::BGVEncoderType, lwe::RLWEDecodeOp,
lattigo::BGVDecodeOp, arith::ConstantOp>;

struct ConvertLWEReinterpretUnderlyingType
: public OpConversionPattern<lwe::ReinterpretUnderlyingTypeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
lwe::ReinterpretUnderlyingTypeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// erase reinterpret underlying
rewriter.replaceOp(op, adaptor.getOperands()[0].getDefiningOp());
return success();
}
};

struct BGVToLattigo : public impl::BGVToLattigoBase<BGVToLattigo> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -53,30 +88,62 @@ struct BGVToLattigo : public impl::BGVToLattigoBase<BGVToLattigo> {
ConversionTarget target(*context);
target.addLegalDialect<lattigo::LattigoDialect>();
target.addIllegalDialect<bgv::BGVDialect>();
target.addIllegalOp<lwe::RLWEEncryptOp, lwe::RLWEDecryptOp,
lwe::RLWEEncodeOp>();
target
.addIllegalOp<lwe::RLWEEncryptOp, lwe::RLWEDecryptOp, lwe::RLWEEncodeOp,
lwe::RLWEDecodeOp, lwe::RAddOp, lwe::RSubOp, lwe::RMulOp,
lwe::ReinterpretUnderlyingTypeOp>();

RewritePatternSet patterns(context);
addStructuralConversionPatterns(typeConverter, patterns, target);

target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 &&
mlir::isa<lattigo::BGVEvaluatorType>(
*op.getFunctionType().getInputs().begin());
bool hasCryptoContextArg =
op.getFunctionType().getNumInputs() > 0 &&
containsArgumentOfType<
lattigo::BGVEvaluatorType, lattigo::BGVEncoderType,
lattigo::RLWEEncryptorType, lattigo::RLWEDecryptorType>(op);

return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody()) &&
(!containsDialects<lwe::LWEDialect, bgv::BGVDialect>(op) ||
hasCryptoContextArg);
});

patterns.add<AddEvaluatorArg<bgv::BGVDialect, lattigo::BGVEvaluatorType>,
ConvertAddOp, ConvertSubOp, ConvertMulOp, ConvertRelinOp,
ConvertModulusSwitchOp, ConvertRotateOp>(typeConverter,
context);
std::vector<std::pair<Type, OpPredicate>> evaluators;

// param/encoder also needed for the main func
// as there might (not) be ct-pt operations
evaluators = {
{lattigo::BGVEvaluatorType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::BGVParameterType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::BGVEncoderType::get(context),
containsDialects<lwe::LWEDialect, bgv::BGVDialect>},
{lattigo::RLWEEncryptorType::get(context),
containsAnyOperations<lwe::RLWEEncryptOp>},
{lattigo::RLWEDecryptorType::get(context),
containsAnyOperations<lwe::RLWEDecryptOp>},
};

patterns.add<AddEvaluatorArg>(context, evaluators);

patterns.add<ConvertAddOp, ConvertSubOp, ConvertMulOp, ConvertAddPlainOp,
ConvertSubPlainOp, ConvertMulPlainOp, ConvertRelinOp,
ConvertModulusSwitchOp, ConvertRotateOp, ConvertEncryptOp,
ConvertDecryptOp, ConvertEncodeOp, ConvertDecodeOp,
ConvertLWEReinterpretUnderlyingType>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
}

// remove unused key args from function types
// in favor of encryptor/decryptor
RewritePatternSet postPatterns(context);
postPatterns.add<RemoveKeyArg<lattigo::RLWESecretKeyType>>(context);
postPatterns.add<RemoveKeyArg<lattigo::RLWEPublicKeyType>>(context);
walkAndApplyPatterns(module, std::move(postPatterns));
}
};

Expand Down
177 changes: 167 additions & 10 deletions lib/Dialect/LWE/Conversions/RlweToLattigo/RlweToLattigo.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,36 @@ FailureOr<Value> getContextualEvaluator(Operation *op) {
return result.value();
}

template <typename Dialect, typename EvaluatorType>
struct AddEvaluatorArg : public OpConversionPattern<func::FuncOp> {
AddEvaluatorArg(mlir::MLIRContext *context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}
AddEvaluatorArg(mlir::MLIRContext *context,
const std::vector<std::pair<Type, OpPredicate>> &evaluators)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2),
evaluators(evaluators) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!containsDialects<lwe::LWEDialect, Dialect>(op)) {
return failure();
SmallVector<Type, 4> selectedEvaluators;

for (const auto &evaluator : evaluators) {
auto predicate = evaluator.second;
if (predicate(op)) {
selectedEvaluators.push_back(evaluator.first);
}
}

if (selectedEvaluators.empty()) {
return success();
}

auto evaluatorType = EvaluatorType::get(getContext());
FunctionType originalType = op.getFunctionType();
llvm::SmallVector<Type, 4> newTypes;
newTypes.reserve(originalType.getNumInputs() + 1);
newTypes.push_back(evaluatorType);
newTypes.reserve(originalType.getNumInputs() + selectedEvaluators.size());
for (auto evaluatorType : selectedEvaluators) {
newTypes.push_back(evaluatorType);
}
for (auto t : originalType.getInputs()) {
newTypes.push_back(t);
}
Expand All @@ -57,10 +68,61 @@ struct AddEvaluatorArg : public OpConversionPattern<func::FuncOp> {
op.setType(newFuncType);

Block &block = op.getBody().getBlocks().front();
block.insertArgument(&block.getArguments().front(), evaluatorType,
op.getLoc());
for (auto evaluatorType : llvm::reverse(selectedEvaluators)) {
block.insertArgument(&block.getArguments().front(), evaluatorType,
op.getLoc());
}
});
return success();
}

private:
std::vector<std::pair<Type, OpPredicate>> evaluators;
};

template <typename KeyType>
struct RemoveKeyArg : public OpConversionPattern<func::FuncOp> {
RemoveKeyArg(mlir::MLIRContext *context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<int, 1> keyArgIndices;
Block &block = op.getBody().getBlocks().front();
for (auto arg : block.getArguments()) {
if (mlir::isa<KeyType>(arg.getType()) && arg.getUses().empty()) {
keyArgIndices.push_back(arg.getArgNumber());
}
}

if (keyArgIndices.empty()) {
return success();
}

FunctionType originalType = op.getFunctionType();
llvm::SmallVector<Type, 4> newTypes;
newTypes.reserve(originalType.getNumInputs());
for (auto arg : block.getArguments()) {
if (llvm::is_contained(keyArgIndices, arg.getArgNumber())) {
continue;
}
newTypes.push_back(arg.getType());
}
auto newFuncType =
FunctionType::get(getContext(), newTypes, originalType.getResults());
rewriter.modifyOpInPlace(op, [&] {
op.setType(newFuncType);

Block &block = op.getBody().getBlocks().front();
for (auto arg : block.getArguments()) {
if (llvm::is_contained(keyArgIndices, arg.getArgNumber())) {
block.eraseArgument(arg.getArgNumber());
}
}
});
return success();
}
};
Expand Down Expand Up @@ -105,6 +167,25 @@ struct ConvertRlweBinOp : public OpConversionPattern<BinOp> {
}
};

template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp>
struct ConvertRlwePlainOp : public OpConversionPattern<PlainOp> {
using OpConversionPattern<PlainOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
PlainOp op, typename PlainOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;

Value evaluator = result.value();
rewriter.replaceOpWithNewOp<LattigoPlainOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getCiphertextInput(), adaptor.getPlaintextInput());
return success();
}
};

template <typename EvaluatorType, typename RlweRotateOp,
typename LattigoRotateOp>
struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
Expand All @@ -130,6 +211,82 @@ struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
}
};

template <typename EvaluatorType, typename ParamType, typename EncodeOp,
typename LattigoEncodeOp, typename AllocOp>
struct ConvertRlweEncodeOp : public OpConversionPattern<EncodeOp> {
using OpConversionPattern<EncodeOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
EncodeOp op, typename EncodeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;
Value evaluator = result.value();

FailureOr<Value> result2 =
getContextualEvaluator<ParamType>(op.getOperation());
if (failed(result2)) return result2;
Value params = result2.value();

auto alloc = rewriter.create<AllocOp>(
op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()),
params);

rewriter.replaceOpWithNewOp<LattigoEncodeOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getInput(), alloc);
return success();
}
};

template <typename EvaluatorType, typename DecodeOp, typename LattigoDecodeOp,
typename AllocOp>
struct ConvertRlweDecodeOp : public OpConversionPattern<DecodeOp> {
using OpConversionPattern<DecodeOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
DecodeOp op, typename DecodeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;
Value evaluator = result.value();

auto outputType = op.getOutput().getType();
RankedTensorType outputTensorType = dyn_cast<RankedTensorType>(outputType);
bool isScalar = false;
if (!outputTensorType) {
isScalar = true;
outputTensorType = RankedTensorType::get({1}, outputType);
}

APInt zero(getElementTypeOrSelf(outputType).getIntOrFloatBitWidth(), 0);

auto constant = DenseElementsAttr::get(outputTensorType, zero);

auto alloc =
rewriter.create<AllocOp>(op.getLoc(), outputTensorType, constant);

auto decodeOp = rewriter.create<LattigoDecodeOp>(
op.getLoc(), outputTensorType, evaluator, adaptor.getInput(), alloc);

// TODO(#1174): the sin of lwe.reinterpret_underlying_type
if (isScalar) {
SmallVector<Value, 1> indices;
auto index = rewriter.create<arith::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
indices.push_back(index);
auto extract = rewriter.create<tensor::ExtractOp>(
op.getLoc(), decodeOp.getResult(), indices);
rewriter.replaceOp(op, extract.getResult());
} else {
rewriter.replaceOp(op, decodeOp.getResult());
}
return success();
}
};

} // namespace mlir::heir

#endif // LIB_DIALECT_LWE_CONVERSIONS_RLWETOLATTIGOUTILS_RLWETOLATTIGO_H_
2 changes: 1 addition & 1 deletion lib/Dialect/Lattigo/IR/LattigoBGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Lattigo_BGVBinaryOp<string mnemonic> :
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertext:$rhs
AnyType:$rhs
);
let results = (outs Lattigo_RLWECiphertext:$output);
}
Expand Down
Loading

0 comments on commit e6c79ed

Please sign in to comment.