Skip to content

Commit

Permalink
Merge pull request #1340 from WoutLegiest:mnist_hl
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721842675
  • Loading branch information
copybara-github committed Jan 31, 2025
2 parents 1098b98 + a2103de commit b063060
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 35 deletions.
165 changes: 133 additions & 32 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,52 @@ static Type convertArithLikeToCGGIType(ShapedType type, MLIRContext *ctx) {
return type;
}

// Function to check if an operation is allowed to remain in the Arith dialect
static bool allowedRemainArith(Operation *op) {
return llvm::TypeSwitch<Operation *, bool>(op)
.Case<mlir::arith::ConstantOp>([](auto op) {
// This lambda will be called for any of the matched operation types
return true;
})
// Allow memref LoadOp if it comes from a FuncArg or if it comes from
// an allowed alloc memref
// Other cases: Memref comes from function -> need to convert to LWE
.Case<memref::LoadOp>([](memref::LoadOp memrefLoad) {
return memrefLoad.getMemRef().getDefiningOp() != nullptr;
})
.Case<mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto *defOp = op.getIn().getDefiningOp()) {
return allowedRemainArith(defOp);
}
return false;
})
.Default([](Operation *) {
// Default case for operations that don't match any of the types
return false;
});
}

static bool hasLWEAnnotation(Operation *op) {
return static_cast<bool>(
op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
}

static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
Location loc) {
assert(inputs.size() == 1);
auto inputType = inputs[0].getType();
if (!isa<IntegerType>(inputType))
llvm_unreachable(
"Non-integer types should never be the input to a materializeTarget.");

auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>();
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());

return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
}

class ArithToCGGITypeConverter : public TypeConverter {
public:
ArithToCGGITypeConverter(MLIRContext *ctx) {
Expand All @@ -43,6 +89,10 @@ class ArithToCGGITypeConverter : public TypeConverter {
addConversion([ctx](ShapedType type) -> Type {
return convertArithLikeToCGGIType(type, ctx);
});

// Target materialization to convert integer constants to LWE ciphertexts
// by creating a trivial LWE ciphertext
addTargetMaterialization(materializeTarget);
}
};

Expand Down Expand Up @@ -167,7 +217,7 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(lhsDefOp)) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
Expand All @@ -176,7 +226,7 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(rhsDefOp)) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
Expand All @@ -191,6 +241,30 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
}
};

struct ConvertAllocOp : public OpConversionPattern<mlir::memref::AllocOp> {
ConvertAllocOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::memref::AllocOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::memref::AllocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

for (auto *userOp : op->getUsers()) {
userOp->setAttr("lwe_annotation",
mlir::StringAttr::get(userOp->getContext(), "LWE"));
}

auto lweType = getTypeConverter()->convertType(op.getType());
auto allocOp =
b.create<memref::AllocOp>(op.getLoc(), lweType, op.getOperands());
rewriter.replaceOp(op, allocOp);
return success();
}
};

struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -206,7 +280,7 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
return isa<mlir::arith::ConstantOp>(defOp);
return hasLWEAnnotation(defOp) || allowedRemainArith(defOp);
}
return false;
});
Expand All @@ -220,42 +294,46 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

target.addDynamicallyLegalOp<memref::AllocOp>([&](Operation *op) {
// Check if all Store ops are constants, if not store op, accepts
// Check if there is at least one Store op that is a constants
return (llvm::all_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return true;
}) &&
llvm::any_of(op->getUses(),
[&](OpOperand &op) {
auto defOp =
dyn_cast<memref::StoreOp>(op.getOwner());
if (defOp) {
return isa<mlir::arith::ConstantOp>(
defOp.getValue().getDefiningOp());
}
return false;
})) ||
// Check if all Store ops are constants or GetGlobals, if not store op,
// accepts Check if there is at least one Store op that is a constants
auto containsAnyStoreOp = llvm::any_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
}
return false;
});
auto allStoreOpsAreArith =
llvm::all_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
}
return true;
});

return (allStoreOpsAreArith && containsAnyStoreOp) ||
// The other case: Memref need to be in LWE format
(typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes()));
});

target.addDynamicallyLegalOp<memref::StoreOp>([&](Operation *op) {
if (typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes())) {
return true;
}

if (auto lweAttr =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation")) {
return false;
}

if (auto *defOp = cast<memref::StoreOp>(op).getValue().getDefiningOp()) {
if (isa<mlir::arith::ConstantOp>(defOp)) {
if (isa<mlir::arith::ConstantOp>(defOp) ||
isa<mlir::memref::GetGlobalOp>(defOp)) {
return true;
}
}

return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
return true;
});

// Convert LoadOp if memref comes from an argument
Expand All @@ -264,17 +342,40 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
typeConverter.isLegal(op->getResultTypes())) {
return true;
}
auto loadOp = dyn_cast<memref::LoadOp>(op);

return loadOp.getMemRef().getDefiningOp() != nullptr;
if (dyn_cast<memref::LoadOp>(op).getMemRef().getDefiningOp() == nullptr) {
return false;
}

if (auto lweAttr =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation")) {
return false;
}

return true;
});

// Convert Dealloc if memref comes from an argument
target.addDynamicallyLegalOp<memref::DeallocOp>([&](Operation *op) {
if (typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes())) {
return true;
}

if (auto lweAttr =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation")) {
return false;
}

return true;
});

patterns.add<
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
ConvertAny<memref::LoadOp>, ConvertAllocOp,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase<ArithToCGGIQuart> {
return signalPassFailure();
}

// Remove the unnecessary tensor ops between each converted arith operation.
// Remove the uncessary tensor ops between each converted arith operation.
OpPassManager pipeline("builtin.module");
pipeline.addPass(createCSEPass());
(void)runPipeline(pipeline, getOperation());
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/TfheRust/IR/TfheRustOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TfheRust_ScalarBinaryOp<string mnemonic>
AnyTypeOf<[Builtin_Integer, TfheRust_CiphertextType]>:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
let summary = "Arithmetic sub of two tfhe ciphertexts.";
}

def TfheRust_CreateTrivialOp : TfheRust_Op<"create_trivial", [Pure]> {
Expand All @@ -61,8 +62,8 @@ def TfheRust_SubOp : TfheRust_Op<"sub", [
]> {
let arguments = (ins
TfheRust_ServerKey:$serverKey,
TfheRust_CiphertextType:$lhs,
TfheRust_CiphertextType:$rhs
TfheRust_CiphertextLikeType:$lhs,
TfheRust_CiphertextLikeType:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
let summary = "Arithmetic sub of two tfhe ciphertexts.";
Expand Down
36 changes: 36 additions & 0 deletions tests/Dialect/Arith/Conversions/ArithToCGGI/loop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: heir-opt --arith-to-cggi %s | FileCheck %s

// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG:.*]]: memref<1x52x!lwe.lwe_ciphertext<encoding = #unspecified_bit_field_encoding>, strided<[?, ?], offset: ?>>) -> [[T:.*]] {
module attributes {tf_saved_model.semantics} {
// CHECK: return %[[ADD:.*]] : [[T]]
memref.global "private" constant @__constant_31x52xi4 : memref<31x52xi4> = dense<"0x00040503040204010206020407020203020301060200010201020006020701040501070402070702030602030100070200040305060207060300010600000602020201060201020103060700020300020307030406040506070202030002060000070504070307050402000702050201010404010706020107060705000703030203030203060503040607070107010504000707010204010304060107050300060202010501050206010306010001030206060204070700040105010202030306050207030101030204020202020000010206050203020000000602030405070300060105000406040103040706030402030501020403070603040101070305000002040301070106000500010303030305010200060206050403010100020502030400010604040405030307070102030300050303040302040507020102070107060301020101020401030101000004070303010303030102030205040504010106010301070306030402020301050707030406040703040706040107020007060707040407010104030506030300030701010401000105050202020604060207030302070406000103040705050002020102000207070106060702070004070001010202000004020002030301030605000101020402030304020602070206020307020603070101040401030107030001050102030603020402040201070603000006010100060003030005070102040104040507030101010704020200060401000401010404040204020006040305010402020102060504000400060503000705040603030505030601010101030404000200010203060303050204030200010705050603040006050304010106010507020201040507050203010503010207030304050300050407020602000606040607030203030701030304010101010701070301040407020307000002050101020604020101070202010507030404050505070007010106020701020405030304000502060304040103010107020203020604010202000102040307020207020002010302030103020100000601000203050506030406070006030202020204030102010402040102040106040001010506030002020200030103000001040207020102040100020202020204020201070003020402010304020201040303020507060204040304020702050001060002060303030707040207020605040002030102020501010403010105010304010602050104070003010403020103060704000204060004050304020401020602040702020302030106020001020102000602070104050107040207070203060203010007020004030506020706030001060000060202020106020102010306070002030002030703040604050607020203000206000007050407030705040200070205020101040401070602010706070500070303020303020306050304060707010701050400070701020401030406010705030006020201050105020601030601000103020606020407070004010501020203030605020703010103020402020202000001020605020302000000060203040507030006010500040604010304070603040203050102040307060304010107030500000204030107010600050001030303030501020006020605040301010002050203040001060404040503030707010203030005030304030204050702010207010706030102010102040103010100000407030301030303010203020504050401010601030107030603040202030105070703040604070304070604010702000706070704040701010403050603030003070101040100010505020202060406020703030207040600010304070505000202010200020707010606070207000407000101020200000402000203030103060500010102040203030402060207020602030702060307010104040103010703000105010203060302040204020107060300000601010006000303000507010204010404050703010101070402020006040100040101040404020402000604030501040202010206050400040006050300070504060303050503060101010103040400020001020306030305020403020001070505060304000605030401010601050702020104050705020301050301020703030405030005040702060200060604060703020303070103030401010101070107030104040702030700000205010102"> {alignment = 64 : i64}
func.func @main(%arg0: memref<1x52xi4, strided<[?, ?], offset: ?>>) -> memref<1x10xi32> {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%0 = memref.get_global @__constant_31x52xi4 : memref<31x52xi4>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x31xi32>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x10xi32>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x15xi32>
affine.parallel (%arg1) = (0) to (1) {
affine.parallel (%arg2) = (0) to (10) {
memref.store %c0_i32, %alloc[%c0, %arg2] : memref<1x31xi32>
affine.for %arg3 = 0 to 52 {
%1 = memref.load %0[%arg2, %arg3] : memref<31x52xi4>
%2 = memref.load %arg0[%c0, %arg3] : memref<1x52xi4, strided<[?, ?], offset: ?>>
%3 = memref.load %alloc[%c0, %arg2] : memref<1x31xi32>
memref.store %c0_i32, %alloc_1[%c0, %arg2] : memref<1x15xi32>
%4 = arith.extsi %2 : i4 to i8
%5 = arith.extsi %1 : i4 to i8
%6 = arith.muli %4, %5 : i8
%7 = arith.extsi %6 : i8 to i32
%8 = arith.addi %3, %7 : i32
memref.store %8, %alloc_0[%c0, %arg2] : memref<1x10xi32>
%33 = memref.load %alloc_1[%c0, %arg2] : memref<1x15xi32>
}
}
}
memref.dealloc %alloc : memref<1x31xi32>
return %alloc_0 : memref<1x10xi32>
}
}

0 comments on commit b063060

Please sign in to comment.