Skip to content

Commit

Permalink
Working Small quarter emit
Browse files Browse the repository at this point in the history
  • Loading branch information
Wouter Legiest authored and WoutLegiest committed Jan 14, 2025
1 parent 5ba244d commit 7bcfbe8
Show file tree
Hide file tree
Showing 19 changed files with 1,059 additions and 22 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/run_rust_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@ bazel query "filter('.mlir.test$', //tests/Examples/tfhe_rust/...)" \
-c fastbuild \
--sandbox_writable_path=$HOME/.cargo \
"$@"

bazel query "filter('.mlir.test$', //tests/Examples/tfhe_rust_hl/cpu/...)" \
| xargs bazel test \
--noincompatible_strict_action_env \
--test_timeout=180 \
-c fastbuild \
--sandbox_writable_path=$HOME/.cargo \
"$@"
1 change: 1 addition & 0 deletions lib/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ add_subdirectory(Jaxite)
add_subdirectory(Metadata)
add_subdirectory(OpenFhePke)
add_subdirectory(TfheRust)
add_subdirectory(TfheRustHL)
add_subdirectory(TfheRustBool)
add_subdirectory(Verilog)
32 changes: 16 additions & 16 deletions lib/Target/TfheRust/TfheRustEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,15 @@ LogicalResult TfheRustEmitter::emitBlock(::mlir::Operation *op) {
for (size_t level = 0; level < levels.size(); ++level) {
os << llvm::formatv(
"run_level({1}, &mut temp_nodes, &mut luts, &LEVEL_{0});\n", level,
serverKeyArg_);
serverKeyArg);
}
}
// Continue to emit the block.
return emitBlock(nextOp);
}

LogicalResult TfheRustEmitter::translateBlock(Block &block) {
if (useLevels_) {
if (useLevels) {
Operation *op = &block.getOperations().front();
return emitBlock(op);
}
Expand Down Expand Up @@ -298,13 +298,13 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) {
}
os << ",\n";
if (isa<tfhe_rust::ServerKeyType>(arg.getType())) {
serverKeyArg_ = argName;
serverKeyArg = argName;
}
}
os.unindent();
os << ")";

if (serverKeyArg_.empty()) {
if (serverKeyArg.empty()) {
return funcOp.emitWarning() << "expected server key function argument to "
"create default ciphertexts";
}
Expand Down Expand Up @@ -336,7 +336,7 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) {
// Create a global temp_nodes hashmap for any created SSA values.
// TODO(#462): Insert block argument that are encrypted ints into
// temp_nodes.
if (useLevels_) {
if (useLevels) {
os << "let mut temp_nodes : HashMap<usize, Ciphertext> = "
"HashMap::new();\n";
os << "let mut luts : HashMap<&str, LookupTableOwned> = "
Expand Down Expand Up @@ -378,7 +378,7 @@ LogicalResult TfheRustEmitter::printOperation(func::ReturnOp op) {
res = llvm::formatv("core::array::from_fn(|i{0}| {1})", i--, res);
}
return res;
} else if (isLevelledOp(value.getDefiningOp()) && useLevels_) {
} else if (isLevelledOp(value.getDefiningOp()) && useLevels) {
// This is from a levelled op stored in temp nodes.
return std::string(
llvm::formatv("temp_nodes[&{0}]",
Expand Down Expand Up @@ -424,7 +424,7 @@ LogicalResult TfheRustEmitter::printSksMethod(
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto valueStr = variableNames->getNameForValue(value);
if (isa<LookupTableType>(value.getType()) && useLevels_) {
if (isa<LookupTableType>(value.getType()) && useLevels) {
valueStr = "luts[\"" + variableNames->getNameForValue(value) + "\"]";
}
std::string prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
Expand All @@ -436,7 +436,7 @@ LogicalResult TfheRustEmitter::printSksMethod(

// Insert ciphertext results into temp_nodes so that the levelled ops can
// reference them.
if (usedByLevelledOp(result) && useLevels_) {
if (usedByLevelledOp(result) && useLevels) {
os << llvm::formatv("temp_nodes.insert({0}, {1}.clone());\n",
variableNames->getIntForValue(result),
variableNames->getNameForValue(result));
Expand Down Expand Up @@ -470,15 +470,15 @@ LogicalResult TfheRustEmitter::printOperation(GenerateLookupTableOp op) {
uint64_t truthTable = op.getTruthTable().getUInt();
auto result = op.getResult();

if (useLevels_) {
if (useLevels) {
os << "luts.insert(\"" << variableNames->getNameForValue(result) << "\", ";
} else {
emitAssignPrefix(result);
}
os << variableNames->getNameForValue(sks) << ".generate_lookup_table(";
os << "|x| (" << std::to_string(truthTable) << " >> x) & 1)";

if (useLevels_) {
if (useLevels) {
os << ")";
}
os << ";\n";
Expand Down Expand Up @@ -684,7 +684,7 @@ void TfheRustEmitter::printStoreOp(memref::StoreOp op,
LogicalResult TfheRustEmitter::printOperation(memref::StoreOp op) {
auto valueToStore = variableNames->getNameForValue(op.getValueToStore());

if (isLevelledOp(op.getValueToStore().getDefiningOp()) && useLevels_) {
if (isLevelledOp(op.getValueToStore().getDefiningOp()) && useLevels) {
valueToStore =
llvm::formatv("temp_nodes[&{0}].clone()",
variableNames->getIntForValue(op.getValueToStore()));
Expand Down Expand Up @@ -727,15 +727,15 @@ void TfheRustEmitter::printLoadOp(memref::LoadOp op) {
LogicalResult TfheRustEmitter::printOperation(memref::LoadOp op) {
// If the load op result is used in a levelled op, insert it into the
// temp_nodes map.
if (usedByLevelledOp(op) && useLevels_) {
if (usedByLevelledOp(op) && useLevels) {
os << llvm::formatv("temp_nodes.insert({0}, ",
variableNames->getIntForValue(op.getResult()));
printLoadOp(op);
os << ".clone());\n";
}

// If any uses are outside the levelled op, also assign it it's SSA value.
if (usedByNonLevelledOp(op) || !useLevels_) {
if (usedByNonLevelledOp(op) || !useLevels) {
emitAssignPrefix(op.getResult());
bool isRef =
isa<tfhe_rust::TfheRustDialect>(op.getResult().getType().getDialect());
Expand Down Expand Up @@ -791,9 +791,9 @@ FailureOr<std::string> TfheRustEmitter::convertType(Type type) {

FailureOr<std::string> TfheRustEmitter::defaultValue(Type type) {
if (type.hasTrait<EncryptedInteger>()) {
if (serverKeyArg_.empty()) return failure();
if (serverKeyArg.empty()) return failure();
return std::string(
llvm::formatv("{0}.create_trivial(0 as u64)", serverKeyArg_));
llvm::formatv("{0}.create_trivial(0 as u64)", serverKeyArg));
};
return llvm::TypeSwitch<Type &, FailureOr<std::string>>(type)
.Case<IntegerType>([&](IntegerType type) { return std::string("0"); })
Expand All @@ -817,7 +817,7 @@ LogicalResult TfheRustEmitter::emitType(Type type) {
TfheRustEmitter::TfheRustEmitter(raw_ostream &os,
SelectVariableNames *variableNames,
bool useLevels)
: useLevels_(useLevels), os(os), variableNames(variableNames) {}
: useLevels(useLevels), os(os), variableNames(variableNames) {}
} // namespace tfhe_rust
} // namespace heir
} // namespace mlir
4 changes: 2 additions & 2 deletions lib/Target/TfheRust/TfheRustEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class TfheRustEmitter {

private:
// Whether to execute levelled operations in parallel.
bool useLevels_;
bool useLevels;

/// Output stream to emit to.
raw_indented_ostream os;
Expand All @@ -53,7 +53,7 @@ class TfheRustEmitter {
SelectVariableNames *variableNames;

// Server key arg to create default values when initializing arrays
std::string serverKeyArg_;
std::string serverKeyArg;

// Functions for printing individual ops
LogicalResult printOperation(::mlir::ModuleOp op);
Expand Down
57 changes: 54 additions & 3 deletions lib/Target/TfheRust/Utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "lib/Target/TfheRust/Utils.h"

#include "lib/Dialect/TfheRust/IR/TfheRustOps.h"
#include "lib/Dialect/TfheRust/IR/TfheRustTypes.h"
#include "lib/Dialect/TfheRustBool/IR/TfheRustBoolOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
Expand All @@ -16,18 +17,21 @@ namespace mlir {
namespace heir {
namespace tfhe_rust {

// TODO: Fix this function to match the list of implemented ops
LogicalResult canEmitFuncForTfheRust(func::FuncOp &funcOp) {
WalkResult failIfInterrupted = funcOp.walk([&](Operation *op) {
return TypeSwitch<Operation *, WalkResult>(op)
// This list should match the list of implemented overloads of
// `printOperation`.
.Case<ModuleOp, func::FuncOp, func::ReturnOp, affine::AffineForOp,
affine::AffineYieldOp, arith::ConstantOp, arith::IndexCastOp,
affine::AffineYieldOp, affine::AffineLoadOp,
affine::AffineStoreOp, arith::ConstantOp, arith::IndexCastOp,
arith::ShLIOp, arith::AndIOp, arith::ShRSIOp, arith::TruncIOp,
tensor::ExtractOp, tensor::FromElementsOp, memref::AllocOp,
memref::DeallocOp, memref::GetGlobalOp, memref::LoadOp,
memref::StoreOp, AddOp, BitAndOp, CreateTrivialOp,
memref::StoreOp, AddOp, SubOp, BitAndOp, CreateTrivialOp,
ApplyLookupTableOp, GenerateLookupTableOp, ScalarLeftShiftOp,
ScalarRightShiftOp, CastOp, MulOp,
::mlir::heir::tfhe_rust_bool::CreateTrivialOp,
::mlir::heir::tfhe_rust_bool::AndOp,
::mlir::heir::tfhe_rust_bool::PackedOp,
Expand All @@ -42,7 +46,8 @@ LogicalResult canEmitFuncForTfheRust(func::FuncOp &funcOp) {
llvm::errs()
<< "Skipping function " << funcOp.getName()
<< " which cannot be emitted because it has an unsupported op: "
<< *op << "\n";
<< *op << "\n"
<< "Origin: TfheRust/Utils.cpp:canEmitFuncForTfheRust\n";
return WalkResult::interrupt();
});
});
Expand All @@ -51,6 +56,52 @@ LogicalResult canEmitFuncForTfheRust(func::FuncOp &funcOp) {
return success();
}

int16_t getTfheRustBitWidth(Type type) {
if (isa<tfhe_rust::EncryptedUInt2Type>(type)) {
return 2;
}
if (isa<tfhe_rust::EncryptedUInt3Type>(type)) {
return 3;
}
if (isa<tfhe_rust::EncryptedUInt4Type>(type)) {
return 4;
}
if (isa<tfhe_rust::EncryptedUInt8Type>(type) ||
isa<tfhe_rust::EncryptedInt8Type>(type)) {
return 8;
}
if (isa<tfhe_rust::EncryptedUInt10Type>(type)) {
return 10;
}
if (isa<tfhe_rust::EncryptedUInt12Type>(type)) {
return 12;
}
if (isa<tfhe_rust::EncryptedUInt14Type>(type)) {
return 14;
}
if (isa<tfhe_rust::EncryptedUInt16Type>(type) ||
isa<tfhe_rust::EncryptedInt16Type>(type)) {
return 16;
}
if (isa<tfhe_rust::EncryptedUInt32Type>(type) ||
isa<tfhe_rust::EncryptedInt32Type>(type)) {
return 32;
}
if (isa<tfhe_rust::EncryptedUInt64Type>(type) ||
isa<tfhe_rust::EncryptedInt64Type>(type)) {
return 64;
}
if (isa<tfhe_rust::EncryptedUInt128Type>(type) ||
isa<tfhe_rust::EncryptedInt128Type>(type)) {
return 128;
}
if (isa<tfhe_rust::EncryptedUInt256Type>(type) ||
isa<tfhe_rust::EncryptedInt256Type>(type)) {
return 256;
}
return -1;
}

} // namespace tfhe_rust
} // namespace heir
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Target/TfheRust/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace tfhe_rust {
// warning and return success. This is because some functions are left
// over during compilation.
::mlir::LogicalResult canEmitFuncForTfheRust(::mlir::func::FuncOp &funcOp);
int16_t getTfheRustBitWidth(Type type);

} // namespace tfhe_rust
} // namespace heir
Expand Down
33 changes: 33 additions & 0 deletions lib/Target/TfheRustHL/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# TfheRustHL Emitter

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "TfheRustHLEmitter",
srcs = ["TfheRustHLEmitter.cpp"],
hdrs = [
"TfheRustHLEmitter.h",
"TfheRustHLTemplates.h",
],
deps = [
"@heir//lib/Analysis/SelectVariableNames",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Target/TfheRust:Utils",
"@heir//lib/Transforms/MemrefToArith:Utils",
"@heir//lib/Utils:TargetUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineAnalysis",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TranslateLib",
],
)
11 changes: 11 additions & 0 deletions lib/Target/TfheRustHL/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

add_mlir_library(HEIRTfheRusHLEmitter
TfheRusHLEmitter.cpp


LINK_LIBS PUBLIC
HEIRTfheRustBool
MLIRIR
MLIRInferTypeOpInterface
)
target_link_libraries(HEIRTarget INTERFACE HEIRTfheRusHLEmitter)
Loading

0 comments on commit 7bcfbe8

Please sign in to comment.