Skip to content

Commit

Permalink
Start on the ptxt compatibility. arith-to-cggi ok
Browse files Browse the repository at this point in the history
  • Loading branch information
WoutLegiest committed Jan 22, 2025
1 parent d83d267 commit f603b0f
Show file tree
Hide file tree
Showing 14 changed files with 323 additions and 147 deletions.
142 changes: 127 additions & 15 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h"

#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/LogicalResult.h>

#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
Expand All @@ -22,7 +25,6 @@ static lwe::LWECiphertextType convertArithToCGGIType(IntegerType type,
lwe::UnspecifiedBitFieldEncodingAttr::get(
ctx, type.getIntOrFloatBitWidth()),
lwe::LWEParamsAttr());
;
}

static Type convertArithLikeToCGGIType(ShapedType type, MLIRContext *ctx) {
Expand Down Expand Up @@ -118,6 +120,26 @@ struct ConvertExtUIOp : public OpConversionPattern<mlir::arith::ExtUIOp> {
}
};

// Lowers arith.extsi to a CGGI cast into the wider LWE ciphertext type.
// NOTE(review): cggi::CastOp is used for both extui and extsi here — confirm
// it performs sign extension (not zero extension) for the signed case.
struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
  ConvertExtSIOp(mlir::MLIRContext *context)
      : OpConversionPattern<mlir::arith::ExtSIOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::arith::ExtSIOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);

    // The result ciphertext type is derived from the (wider) integer result.
    auto resultTy = convertArithToCGGIType(
        cast<IntegerType>(op.getResult().getType()), op->getContext());
    rewriter.replaceOp(op, builder.create<cggi::CastOp>(op.getLoc(), resultTy,
                                                        adaptor.getIn()));
    return success();
  }
};

struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
ConvertShRUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
Expand Down Expand Up @@ -166,6 +188,43 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
}
};

// Lowers an arith binary op (add/mul/sub) to the corresponding CGGI op.
//
// When one operand is produced by arith.constant, the original (unconverted)
// constant is kept as the trailing operand and the converted ciphertext is
// placed first; otherwise both converted operands are used as-is.
template <typename SourceArithOp, typename TargetModArithOp>
struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
  ConvertArithBinOp(mlir::MLIRContext *context)
      : OpConversionPattern<SourceArithOp>(context) {}

  using OpConversionPattern<SourceArithOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // True iff the value is defined by an arith.constant op.
    auto isArithConstant = [](Value v) {
      auto *def = v.getDefiningOp();
      return def && isa<mlir::arith::ConstantOp>(def);
    };

    // Constant on the left: emit (ciphertext rhs, plaintext lhs).
    // NOTE(review): this swaps the operand order relative to the source op;
    // that is fine for commutative AddOp/MulOp, but for SubIOp it computes
    // rhs - lhs unless the target CGGI op defines (ciphertext, plaintext)
    // semantics differently — confirm against the cggi::SubOp definition.
    if (isArithConstant(op.getLhs())) {
      rewriter.replaceOp(
          op, b.create<TargetModArithOp>(adaptor.getRhs().getType(),
                                         adaptor.getRhs(), op.getLhs()));
      return success();
    }

    // Constant on the right: operand order is preserved.
    if (isArithConstant(op.getRhs())) {
      rewriter.replaceOp(
          op, b.create<TargetModArithOp>(adaptor.getLhs().getType(),
                                         adaptor.getLhs(), op.getRhs()));
      return success();
    }

    // Both operands are ciphertexts: use the converted values directly.
    rewriter.replaceOp(
        op, b.create<TargetModArithOp>(adaptor.getLhs().getType(),
                                       adaptor.getLhs(), adaptor.getRhs()));
    return success();
  }
};

struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -176,29 +235,82 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
ConversionTarget target(*context);
target.addLegalDialect<cggi::CGGIDialect>();
target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::arith::ConstantOp>();

target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
return isa<mlir::arith::ConstantOp>(defOp);
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
[](mlir::arith::ConstantOp op) {
// Allow use of constant if it is used to denote the size of a shift
return (isa<IndexType>(op.getValue().getType()));
target.addDynamicallyLegalOp<memref::SubViewOp, memref::CopyOp,
tensor::FromElementsOp, tensor::ExtractOp,
affine::AffineStoreOp, affine::AffineLoadOp>(
[&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});

target.addDynamicallyLegalOp<
memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::SubViewOp,
memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp,
affine::AffineStoreOp, affine::AffineLoadOp>([&](Operation *op) {
target.addDynamicallyLegalOp<memref::AllocOp>([&](Operation *op) {
        // Legal if every store into this alloc writes an arith.constant
        // (and at least one such store exists), i.e. the buffer holds only
        // plaintext constants; otherwise fall through to the LWE type check.
return (llvm::all_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return true;
}) &&
llvm::any_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return false;
})) ||
                 // The other case: the memref needs to be in LWE format
(typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes()));
});

target.addDynamicallyLegalOp<memref::StoreOp>([&](Operation *op) {
if (auto *defOp = cast<memref::StoreOp>(op).getValue().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(defOp)) {
return true;
}
}

return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});

// Convert LoadOp if memref comes from an argument
target.addDynamicallyLegalOp<memref::LoadOp>([&](Operation *op) {
if (typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes())) {
return true;
}
auto loadOp = dyn_cast<memref::LoadOp>(op);

return loadOp.getMemRef().getDefiningOp() != nullptr;
});

patterns.add<
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertShRUIOp,
ConvertBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::StoreOp>,
ConvertAny<memref::SubViewOp>, ConvertAny<memref::CopyOp>,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineLoadOp> >(
typeConverter, context);
Expand Down
49 changes: 25 additions & 24 deletions lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {
SmallVector<Value> outputs;

for (int i = 0; i < splitLhs.size(); ++i) {
auto lowSum = b.create<cggi::AddOp>(splitLhs[i], splitRhs[i]);
auto lowSum = b.create<cggi::AddOp>(elemType, splitLhs[i], splitRhs[i]);
auto outputLsb = b.create<cggi::CastOp>(op.getLoc(), realTy, lowSum);
auto outputLsbHigh =
b.create<cggi::CastOp>(op.getLoc(), elemType, outputLsb);
Expand All @@ -374,7 +374,8 @@ struct ConvertQuartAddI final : OpConversionPattern<mlir::arith::AddIOp> {
if (i == 0) {
outputs.push_back(outputLsbHigh);
} else {
auto high = b.create<cggi::AddOp>(outputLsbHigh, carries[i - 1]);
auto high =
b.create<cggi::AddOp>(elemType, outputLsbHigh, carries[i - 1]);
outputs.push_back(high);
}
}
Expand Down Expand Up @@ -422,35 +423,35 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {

// TODO: Implement the real Karatsuba algorithm for 4x4 multiplication.
// First part of Karatsuba algorithm
auto z00 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[0]);
auto z02 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[1]);
auto z01_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(z01_p1, z01_p2);
auto z00 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[0]);
auto z02 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[1]);
auto z01_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(elemTy, z01_p1, z01_p2);
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
auto z01 = b.create<cggi::SubOp>(z01_s, z02);

// Second part I of Karatsuba algorithm
auto z1a0 = b.create<cggi::MulOp>(splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(z1a1_p1, z1a1_p2);
auto z1a0 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(elemTy, z1a1_p1, z1a1_p2);
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);

// Second part II of Karatsuba algorithm
auto z1b0 = b.create<cggi::MulOp>(splitLhs[2], splitRhs[0]);
auto z1b2 = b.create<cggi::MulOp>(splitLhs[3], splitRhs[1]);
auto z1b1_p1 = b.create<cggi::AddOp>(splitLhs[2], splitLhs[3]);
auto z1b1_p2 = b.create<cggi::AddOp>(splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(z1b1_m, z1b0);
auto z1b0 = b.create<cggi::MulOp>(elemTy, splitLhs[2], splitRhs[0]);
auto z1b2 = b.create<cggi::MulOp>(elemTy, splitLhs[3], splitRhs[1]);
auto z1b1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[2], splitLhs[3]);
auto z1b1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(elemTy, z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(elemTy, z1b1_m, z1b0);
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);

auto out2Kara = b.create<cggi::AddOp>(z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(out2Kara, z02);
auto out3Carry = b.create<cggi::AddOp>(z1a1, z1b1);
auto out2Kara = b.create<cggi::AddOp>(elemTy, z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(elemTy, out2Kara, z02);
auto out3Carry = b.create<cggi::AddOp>(elemTy, z1a1, z1b1);

// Output are now all 16b elements, wants presentation of 4x8b
auto output0Lsb = b.create<cggi::CastOp>(realTy, z00);
Expand All @@ -471,9 +472,9 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto output3Lsb = b.create<cggi::CastOp>(realTy, out3Carry);
auto output3LsbHigh = b.create<cggi::CastOp>(elemTy, output3Lsb);

auto output1 = b.create<cggi::AddOp>(output1LsbHigh, output0Msb);
auto output2 = b.create<cggi::AddOp>(output2LsbHigh, output1Msb);
auto output3 = b.create<cggi::AddOp>(output3LsbHigh, output2Msb);
auto output1 = b.create<cggi::AddOp>(elemTy, output1LsbHigh, output0Msb);
auto output2 = b.create<cggi::AddOp>(elemTy, output2LsbHigh, output1Msb);
auto output3 = b.create<cggi::AddOp>(elemTy, output3LsbHigh, output2Msb);

Value resultVec = constructResultTensor(
rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3});
Expand Down
52 changes: 8 additions & 44 deletions lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,48 +42,6 @@ constexpr int kAndLut = 8;
constexpr int kOrLut = 14;
constexpr int kXorLut = 6;

// Maps an integer bit width to the matching tfhe_rust encrypted unsigned
// integer type. Only supporting unsigned types because the LWE dialect does
// not have a notion of signedness.
static Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) {
  if (width == 1) {
    // The minimum bit width of the integer tfhe_rust API is UInt2
    // https://docs.rs/tfhe/latest/tfhe/index.html#types
    // This may happen if there are no LUT or boolean gate operations that
    // require a minimum bit width (e.g. shuffling bits in a program that
    // multiplies by two).
    LLVM_DEBUG(llvm::dbgs()
               << "Upgrading ciphertext with bit width 1 to UInt2");
    width = 2;
  }

  switch (width) {
    case 2:
      return tfhe_rust::EncryptedUInt2Type::get(ctx);
    case 3:
      return tfhe_rust::EncryptedUInt3Type::get(ctx);
    case 4:
      return tfhe_rust::EncryptedUInt4Type::get(ctx);
    case 8:
      return tfhe_rust::EncryptedUInt8Type::get(ctx);
    case 10:
      return tfhe_rust::EncryptedUInt10Type::get(ctx);
    case 12:
      return tfhe_rust::EncryptedUInt12Type::get(ctx);
    case 14:
      return tfhe_rust::EncryptedUInt14Type::get(ctx);
    case 16:
      return tfhe_rust::EncryptedUInt16Type::get(ctx);
    case 32:
      return tfhe_rust::EncryptedUInt32Type::get(ctx);
    case 64:
      return tfhe_rust::EncryptedUInt64Type::get(ctx);
    case 128:
      return tfhe_rust::EncryptedUInt128Type::get(ctx);
    case 256:
      return tfhe_rust::EncryptedUInt256Type::get(ctx);
    default:
      llvm_unreachable("Unsupported bitwidth");
  }
}

class CGGIToTfheRustTypeConverter : public TypeConverter {
public:
CGGIToTfheRustTypeConverter(MLIRContext *ctx) {
Expand Down Expand Up @@ -532,6 +490,12 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {
hasServerKeyArg);
});

target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
bool hasServerKeyArg =
isa<tfhe_rust::ServerKeyType>(op.getOperand(0).getType());
return hasServerKeyArg;
});

target.addLegalOp<mlir::arith::ConstantOp>();

target.addDynamicallyLegalOp<
Expand All @@ -546,8 +510,8 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {
// FIXME: still need to update callers to insert the new server key arg, if
// needed and possible.
patterns.add<
AddServerKeyArg, ConvertEncodeOp, ConvertLut2Op, ConvertLut3Op,
ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp,
AddServerKeyArg, AddServerKeyArgCall, ConvertEncodeOp, ConvertLut2Op,
ConvertLut3Op, ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp,
ConvertCGGITRBinOp<cggi::AddOp, tfhe_rust::AddOp>,
ConvertCGGITRBinOp<cggi::MulOp, tfhe_rust::MulOp>,
ConvertCGGITRBinOp<cggi::SubOp, tfhe_rust::SubOp>, ConvertAndOp,
Expand Down
Loading

0 comments on commit f603b0f

Please sign in to comment.