Skip to content

Commit

Permalink
Tfhe-rs Ops with the scalar definition
Browse files Browse the repository at this point in the history
  • Loading branch information
WoutLegiest committed Jan 27, 2025
1 parent 28a0023 commit 1195971
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 13 deletions.
29 changes: 20 additions & 9 deletions lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,12 @@ struct ConvertLut3Op : public OpConversionPattern<cggi::Lut3Op> {
serverKey, adaptor.getC(), b.getIndexAttr(2));
auto shiftedB = b.create<tfhe_rust::ScalarLeftShiftOp>(
serverKey, adaptor.getB(), b.getIndexAttr(1));
auto summedBC = b.create<tfhe_rust::AddOp>(serverKey, shiftedC, shiftedB);
auto summedABC =
b.create<tfhe_rust::AddOp>(serverKey, summedBC, adaptor.getA());
auto outputType =
getTypeConverter()->convertType(shiftedB.getResult().getType());
auto summedBC =
b.create<tfhe_rust::AddOp>(outputType, serverKey, shiftedC, shiftedB);
auto summedABC = b.create<tfhe_rust::AddOp>(outputType, serverKey, summedBC,
adaptor.getA());

rewriter.replaceOp(
op, b.create<tfhe_rust::ApplyLookupTableOp>(serverKey, summedABC, lut));
Expand Down Expand Up @@ -205,8 +208,10 @@ struct ConvertLut2Op : public OpConversionPattern<cggi::Lut2Op> {
// Construct input = b << 1 + a
auto shiftedB = b.create<tfhe_rust::ScalarLeftShiftOp>(
serverKey, adaptor.getB(), b.getIndexAttr(1));
auto summedBA =
b.create<tfhe_rust::AddOp>(serverKey, shiftedB, adaptor.getA());

auto summedBA = b.create<tfhe_rust::AddOp>(
getTypeConverter()->convertType(shiftedB.getResult().getType()),
serverKey, shiftedB, adaptor.getA());

rewriter.replaceOp(
op, b.create<tfhe_rust::ApplyLookupTableOp>(serverKey, summedBA, lut));
Expand All @@ -230,7 +235,11 @@ static LogicalResult replaceBinaryGate(Operation *op, Value lhs, Value rhs,
// Construct input = rhs << 1 + lhs
auto shiftedRhs =
b.create<tfhe_rust::ScalarLeftShiftOp>(serverKey, rhs, b.getIndexAttr(1));
auto input = b.create<tfhe_rust::AddOp>(serverKey, shiftedRhs, lhs);

CGGIToTfheRustTypeConverter typeConverter(op->getContext());
auto outputType = typeConverter.convertType(shiftedRhs.getResult().getType());
auto input =
b.create<tfhe_rust::AddOp>(outputType, serverKey, shiftedRhs, lhs);
rewriter.replaceOp(
op, b.create<tfhe_rust::ApplyLookupTableOp>(serverKey, input, lutOp));
return success();
Expand All @@ -248,9 +257,11 @@ struct ConvertCGGITRBinOp : public OpConversionPattern<BinOp> {
if (failed(result)) return result;

Value serverKey = result.value();

rewriter.replaceOp(op, b.create<TfheRustBinOp>(serverKey, adaptor.getLhs(),
adaptor.getRhs()));
CGGIToTfheRustTypeConverter typeConverter(op->getContext());
auto outputType = typeConverter.convertType(op.getResult().getType());
rewriter.replaceOp(
op, b.create<TfheRustBinOp>(outputType, serverKey, adaptor.getLhs(),
adaptor.getRhs()));
return success();
}
};
Expand Down
37 changes: 33 additions & 4 deletions lib/Dialect/TfheRust/IR/TfheRustOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include "TfheRustTypes.td"

include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

Expand All @@ -20,7 +21,8 @@ class TfheRust_Op<string mnemonic, list<Trait> traits = []> :
class TfheRust_BinaryOp<string mnemonic>
: TfheRust_Op<mnemonic, [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
Commutative,
ElementwiseMappable,
]> {
let arguments = (ins
TfheRust_ServerKey:$serverKey,
Expand All @@ -30,16 +32,43 @@ class TfheRust_BinaryOp<string mnemonic>
let results = (outs TfheRust_CiphertextLikeType:$output);
}

class TfheRust_ScalarBinaryOp<string mnemonic>
: TfheRust_Op<mnemonic, [
Pure,
Commutative,
ElementwiseMappable,
]> {
let arguments = (ins
TfheRust_ServerKey:$serverKey,
TfheRust_CiphertextType:$lhs,
AnyTypeOf<[Builtin_Integer, TfheRust_CiphertextType]>:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
}

def TfheRust_CreateTrivialOp : TfheRust_Op<"create_trivial", [Pure]> {
let arguments = (ins TfheRust_ServerKey:$serverKey, AnyInteger:$value);
let results = (outs TfheRust_CiphertextLikeType:$output);
let hasCanonicalizer = 1;
}

def TfheRust_BitAndOp : TfheRust_BinaryOp<"bitand"> { let summary = "Logical AND of two tfhe ciphertexts."; }
def TfheRust_AddOp : TfheRust_BinaryOp<"add"> { let summary = "Arithmetic add of two tfhe ciphertexts."; }
def TfheRust_SubOp : TfheRust_BinaryOp<"sub"> { let summary = "Arithmetic sub of two tfhe ciphertexts."; }
def TfheRust_MulOp : TfheRust_BinaryOp<"mul"> { let summary = "Arithmetic mul of two tfhe ciphertexts."; }
def TfheRust_AddOp : TfheRust_ScalarBinaryOp<"add"> { let summary = "Arithmetic add of two tfhe ciphertexts."; }
def TfheRust_MulOp : TfheRust_ScalarBinaryOp<"mul"> { let summary = "Arithmetic mul of two tfhe ciphertexts."; }

def TfheRust_SubOp : TfheRust_Op<"sub", [
Pure,
AllTypesMatch<["lhs","rhs","output"]>,
ElementwiseMappable
]> {
let arguments = (ins
TfheRust_ServerKey:$serverKey,
TfheRust_CiphertextType:$lhs,
TfheRust_CiphertextType:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
let summary = "Arithmetic sub of two tfhe ciphertexts.";
}


def TfheRust_ScalarLeftShiftOp : TfheRust_Op<"scalar_left_shift", [
Expand Down
1 change: 1 addition & 0 deletions lib/Target/TfheRustHL/TfheRustHLEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::FromElementsOp op) {
return success();
}

// Need to produce a
LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) {
emitAssignPrefix(op.getResult());
os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) {
Expand Down
2 changes: 2 additions & 0 deletions lib/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

#define DEBUG_TYPE "cggi-to-tfhe-rust"

namespace mlir {
namespace heir {

Expand Down

0 comments on commit 1195971

Please sign in to comment.