Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Loop support to arith-to-cggi pass #1340

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 133 additions & 32 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,52 @@ static Type convertArithLikeToCGGIType(ShapedType type, MLIRContext *ctx) {
return type;
}

// Function to check if an operation is allowed to remain in the Arith dialect
static bool allowedRemainArith(Operation *op) {
return llvm::TypeSwitch<Operation *, bool>(op)
.Case<mlir::arith::ConstantOp>([](auto op) {
// This lambda will be called for any of the matched operation types
return true;
})
// Allow memref LoadOp if it comes from a FuncArg or if it comes from
// an allowed alloc memref
// Other cases: Memref comes from function -> need to convert to LWE
.Case<memref::LoadOp>([](memref::LoadOp memrefLoad) {
return memrefLoad.getMemRef().getDefiningOp() != nullptr;
})
.Case<mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto *defOp = op.getIn().getDefiningOp()) {
return allowedRemainArith(defOp);
}
return false;
})
.Default([](Operation *) {
// Default case for operations that don't match any of the types
return false;
});
}

static bool hasLWEAnnotation(Operation *op) {
return static_cast<bool>(
op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
}

static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
Location loc) {
assert(inputs.size() == 1);
auto inputType = inputs[0].getType();
if (!isa<IntegerType>(inputType))
llvm_unreachable(
"Non-integer types should never be the input to a materializeTarget.");

auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>();
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());

return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
}

class ArithToCGGITypeConverter : public TypeConverter {
public:
ArithToCGGITypeConverter(MLIRContext *ctx) {
Expand All @@ -43,6 +89,10 @@ class ArithToCGGITypeConverter : public TypeConverter {
addConversion([ctx](ShapedType type) -> Type {
return convertArithLikeToCGGIType(type, ctx);
});

// Target materialization to convert integer constants to LWE ciphertexts
// by creating a trivial LWE ciphertext
addTargetMaterialization(materializeTarget);
}
};

Expand Down Expand Up @@ -167,7 +217,7 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(lhsDefOp)) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
Expand All @@ -176,7 +226,7 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(rhsDefOp)) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
Expand All @@ -191,6 +241,30 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
}
};

struct ConvertAllocOp : public OpConversionPattern<mlir::memref::AllocOp> {
ConvertAllocOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::memref::AllocOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::memref::AllocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

for (auto *userOp : op->getUsers()) {
userOp->setAttr("lwe_annotation",
mlir::StringAttr::get(userOp->getContext(), "LWE"));
}

auto lweType = getTypeConverter()->convertType(op.getType());
auto allocOp =
b.create<memref::AllocOp>(op.getLoc(), lweType, op.getOperands());
rewriter.replaceOp(op, allocOp);
return success();
}
};

struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -206,7 +280,7 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
return isa<mlir::arith::ConstantOp>(defOp);
return hasLWEAnnotation(defOp) || allowedRemainArith(defOp);
}
return false;
});
Expand All @@ -220,42 +294,46 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

target.addDynamicallyLegalOp<memref::AllocOp>([&](Operation *op) {
// Check if all Store ops are constants, if not store op, accepts
// Check if there is at least one Store op that is a constants
return (llvm::all_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return true;
}) &&
llvm::any_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return false;
})) ||
// Check if all Store ops are constants or GetGlobals, if not store op,
// accepts Check if there is at least one Store op that is a constants
auto containsAnyStoreOp = llvm::any_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
}
return false;
});
auto allStoreOpsAreArith =
llvm::all_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
}
return true;
});

return (allStoreOpsAreArith && containsAnyStoreOp) ||
// The other case: Memref need to be in LWE format
(typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes()));
});

target.addDynamicallyLegalOp<memref::StoreOp>([&](Operation *op) {
if (typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes())) {
return true;
}

if (auto lweAttr =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation")) {
return false;
}

if (auto *defOp = cast<memref::StoreOp>(op).getValue().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(defOp)) {
if (isa<mlir::arith::ConstantOp>(defOp) ||
isa<mlir::memref::GetGlobalOp>(defOp)) {
return true;
}
}

return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
return true;
});

// Convert LoadOp if memref comes from an argument
Expand All @@ -264,17 +342,40 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
typeConverter.isLegal(op->getResultTypes())) {
return true;
}
auto loadOp = dyn_cast<memref::LoadOp>(op);

return loadOp.getMemRef().getDefiningOp() != nullptr;
if (dyn_cast<memref::LoadOp>(op).getMemRef().getDefiningOp() == nullptr) {
return false;
}

if (auto lweAttr =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation")) {
return false;
}

return true;
});

// Convert Dealloc if memref comes from an argument
target.addDynamicallyLegalOp<memref::DeallocOp>([&](Operation *op) {
if (typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes())) {
return true;
}

if (auto lweAttr =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation")) {
return false;
}

return true;
});

patterns.add<
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
ConvertAny<memref::LoadOp>, ConvertAllocOp,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase<ArithToCGGIQuart> {
return signalPassFailure();
}

// Remove the unnecessary tensor ops between each converted arith operation.
// Remove the uncessary tensor ops between each converted arith operation.
OpPassManager pipeline("builtin.module");
pipeline.addPass(createCSEPass());
(void)runPipeline(pipeline, getOperation());
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/TfheRust/IR/TfheRustOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TfheRust_ScalarBinaryOp<string mnemonic>
AnyTypeOf<[Builtin_Integer, TfheRust_CiphertextType]>:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
let summary = "Arithmetic sub of two tfhe ciphertexts.";
}

def TfheRust_CreateTrivialOp : TfheRust_Op<"create_trivial", [Pure]> {
Expand All @@ -61,8 +62,8 @@ def TfheRust_SubOp : TfheRust_Op<"sub", [
]> {
let arguments = (ins
TfheRust_ServerKey:$serverKey,
TfheRust_CiphertextType:$lhs,
TfheRust_CiphertextType:$rhs
TfheRust_CiphertextLikeType:$lhs,
TfheRust_CiphertextLikeType:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
let summary = "Arithmetic sub of two tfhe ciphertexts.";
Expand Down
36 changes: 36 additions & 0 deletions tests/Dialect/Arith/Conversions/ArithToCGGI/loop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: heir-opt --arith-to-cggi %s | FileCheck %s

// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG:.*]]: memref<1x52x!lwe.lwe_ciphertext<encoding = #unspecified_bit_field_encoding>, strided<[?, ?], offset: ?>>) -> [[T:.*]] {
module attributes {tf_saved_model.semantics} {
// CHECK: return %[[ADD:.*]] : [[T]]
memref.global "private" constant @__constant_31x52xi4 : memref<31x52xi4> = dense<"0x00040503040204010206020407020203020301060200010201020006020701040501070402070702030602030100070200040305060207060300010600000602020201060201020103060700020300020307030406040506070202030002060000070504070307050402000702050201010404010706020107060705000703030203030203060503040607070107010504000707010204010304060107050300060202010501050206010306010001030206060204070700040105010202030306050207030101030204020202020000010206050203020000000602030405070300060105000406040103040706030402030501020403070603040101070305000002040301070106000500010303030305010200060206050403010100020502030400010604040405030307070102030300050303040302040507020102070107060301020101020401030101000004070303010303030102030205040504010106010301070306030402020301050707030406040703040706040107020007060707040407010104030506030300030701010401000105050202020604060207030302070406000103040705050002020102000207070106060702070004070001010202000004020002030301030605000101020402030304020602070206020307020603070101040401030107030001050102030603020402040201070603000006010100060003030005070102040104040507030101010704020200060401000401010404040204020006040305010402020102060504000400060503000705040603030505030601010101030404000200010203060303050204030200010705050603040006050304010106010507020201040507050203010503010207030304050300050407020602000606040607030203030701030304010101010701070301040407020307000002050101020604020101070202010507030404050505070007010106020701020405030304000502060304040103010107020203020604010202000102040307020207020002010302030103020100000601000203050506030406070006030202020204030102010402040102040106040001010506030002020200030103000001040207020102040100020202020204020201070003020402010304020201040303020507060204040304020702050001060002060303030707040207020605040002030102020501010403010105010304010602050104070003010403020103060704000204060004050304020401020602040702020302030106020001020102000602070104050107040207070203060203010007020004030506020706030001060000060202020106020102010306070002030002030703040604050607020203000206000007050407030705040200070205020101040401070602010706070500070303020303020306050304060707010701050400070701020401030406010705030006020201050105020601030601000103020606020407070004010501020203030605020703010103020402020202000001020605020302000000060203040507030006010500040604010304070603040203050102040307060304010107030500000204030107010600050001030303030501020006020605040301010002050203040001060404040503030707010203030005030304030204050702010207010706030102010102040103010100000407030301030303010203020504050401010601030107030603040202030105070703040604070304070604010702000706070704040701010403050603030003070101040100010505020202060406020703030207040600010304070505000202010200020707010606070207000407000101020200000402000203030103060500010102040203030402060207020602030702060307010104040103010703000105010203060302040204020107060300000601010006000303000507010204010404050703010101070402020006040100040101040404020402000604030501040202010206050400040006050300070504060303050503060101010103040400020001020306030305020403020001070505060304000605030401010601050702020104050705020301050301020703030405030005040702060200060604060703020303070103030401010101070107030104040702030700000205010102"> {alignment = 64 : i64}
func.func @main(%arg0: memref<1x52xi4, strided<[?, ?], offset: ?>>) -> memref<1x10xi32> {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%0 = memref.get_global @__constant_31x52xi4 : memref<31x52xi4>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x31xi32>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x10xi32>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x15xi32>
affine.parallel (%arg1) = (0) to (1) {
affine.parallel (%arg2) = (0) to (10) {
memref.store %c0_i32, %alloc[%c0, %arg2] : memref<1x31xi32>
affine.for %arg3 = 0 to 52 {
%1 = memref.load %0[%arg2, %arg3] : memref<31x52xi4>
%2 = memref.load %arg0[%c0, %arg3] : memref<1x52xi4, strided<[?, ?], offset: ?>>
%3 = memref.load %alloc[%c0, %arg2] : memref<1x31xi32>
memref.store %c0_i32, %alloc_1[%c0, %arg2] : memref<1x15xi32>
%4 = arith.extsi %2 : i4 to i8
%5 = arith.extsi %1 : i4 to i8
%6 = arith.muli %4, %5 : i8
%7 = arith.extsi %6 : i8 to i32
%8 = arith.addi %3, %7 : i32
memref.store %8, %alloc_0[%c0, %arg2] : memref<1x10xi32>
%33 = memref.load %alloc_1[%c0, %arg2] : memref<1x15xi32>
}
}
}
memref.dealloc %alloc : memref<1x31xi32>
return %alloc_0 : memref<1x10xi32>
}
}
Loading