diff --git a/heir_py/decorator.py b/heir_py/decorator.py index 7ff145eaf..0bb645967 100644 --- a/heir_py/decorator.py +++ b/heir_py/decorator.py @@ -1,5 +1,14 @@ """The decorator entry point for the frontend.""" +import os +import uuid +import weakref +import collections +import functools + +import numba +from numba.core import types, errors, utils, config + from abc import ABC, abstractmethod from typing import Optional @@ -7,6 +16,19 @@ from heir_py import openfhe_config from heir_py.pipeline import CompilationResult, run_compiler +# Exported symbols +from numba.core.typing.typeof import typeof_impl # noqa: F401 +from numba.core.typing.asnumbatype import as_numba_type # noqa: F401 +from numba.core.typing.templates import infer, infer_getattr # noqa: F401 +from numba.core.imputils import ( # noqa: F401 + lower_builtin, lower_getattr, lower_getattr_generic, # noqa: F401 + lower_setattr, lower_setattr_generic, lower_cast) # noqa: F401 +from numba.core.datamodel import models # noqa: F401 +from numba.core.datamodel import register_default as register_model # noqa: F401, E501 +from numba.core.pythonapi import box, unbox, reflect, NativeValue # noqa: F401 +from numba._helperlib import _import_cython_function # noqa: F401 +from numba.core.serialize import ReduceMixin + class CompilationResultInterface(ABC): @@ -77,7 +99,7 @@ def wrapper(arg, *, crypto_context=None, public_key=None): return wrapper - if key == self.compilation_result.func_name: + if key == self.compilation_result.func_name or key == "eval": fn = self.compilation_result.main_func def wrapper(*args, crypto_context=None): @@ -102,9 +124,10 @@ def __call__(self, *args, **kwargs): ) args_encrypted = [ - getattr(self, f"encrypt_{arg_name}")(arg) - for arg_name, arg in zip(arg_names, args) + getattr(self, f"encrypt_{arg_name}")(arg) if i in self.compilation_result.secret_args else arg + for i, (arg_name, arg) in enumerate(zip(arg_names, args)) ] + result_encrypted = getattr(self, self.compilation_result.func_name)( *args_encrypted ) @@ -112,9 +135,10 @@ def __call__(self, *args, **kwargs): def compile( + scheme: str = "bgv", backend: str = "openfhe", - backend_config: Optional[openfhe_config.OpenFHEConfig] = None, - heir_config: Optional[_heir_config.HEIRConfig] = None, + backend_config: Optional[openfhe_config.OpenFHEConfig] = openfhe_config.DEFAULT_INSTALLED_OPENFHE_CONFIG, + heir_config: Optional[_heir_config.HEIRConfig] = _heir_config.DEVELOPMENT_HEIR_CONFIG, debug : Optional[bool] = False ): """Compile a function to its private equivalent in FHE. @@ -134,6 +158,7 @@ def compile( def decorator(func): compilation_result = run_compiler( func, + scheme, openfhe_config=backend_config or openfhe_config.from_os_env(), heir_config=heir_config or _heir_config.from_os_env(), debug = debug @@ -144,3 +169,143 @@ def decorator(func): raise ValueError(f"Unknown backend: {backend}") return decorator + + + +class _MLIR(ReduceMixin): + """ + Dummy callable for _MLIR + """ + _memo = weakref.WeakValueDictionary() + # hold refs to last N functions deserialized, retaining them in _memo + # regardless of whether there is another reference + _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE) + + __uuid = None + + def __init__(self, name, defn, prefer_literal=False, **kwargs): + self._ctor_kwargs = kwargs + self._name = name + self._defn = defn + self._prefer_literal = prefer_literal + functools.update_wrapper(self, defn) + + @property + def _uuid(self): + """ + An instance-specific UUID, to avoid multiple deserializations of + a given instance. + + Note this is lazily-generated, for performance reasons. + """ + u = self.__uuid + if u is None: + u = str(uuid.uuid1()) + self._set_uuid(u) + return u + + def _set_uuid(self, u): + assert self.__uuid is None + self.__uuid = u + self._memo[u] = self + self._recent.append(self) + + def _register(self): + # _ctor_kwargs + from numba.core.typing.templates import (make_intrinsic_template, + infer_global) + + template = make_intrinsic_template(self, self._defn, self._name, + prefer_literal=self._prefer_literal, + kwargs=self._ctor_kwargs) + infer(template) + infer_global(self, types.Function(template)) + + def __call__(self, *args, **kwargs): + """ + Calls the Python Impl + """ + _, impl = self._defn(None, *args, **kwargs) + return impl(*args, **kwargs) + + def __repr__(self): + return "".format(self._name) + + def __deepcopy__(self, memo): + # NOTE: Intrinsic are immutable and we don't need to copy. + # This is triggered from deepcopy of statements. + return self + + def _reduce_states(self): + """ + NOTE: part of ReduceMixin protocol + """ + return dict(uuid=self._uuid, name=self._name, defn=self._defn) + + @classmethod + def _rebuild(cls, uuid, name, defn): + """ + NOTE: part of ReduceMixin protocol + """ + try: + return cls._memo[uuid] + except KeyError: + llc = cls(name=name, defn=defn) + llc._register() + llc._set_uuid(uuid) + return llc + + +def mlir(*args, **kwargs): + """ + TODO (#1162): update this doc to reflect how we use this! + A decorator marking the decorated function as typing and implementing + *func* in nopython mode using the llvmlite IRBuilder API. This is an escape + hatch for expert users to build custom LLVM IR that will be inlined to + the caller. + + The first argument to *func* is the typing context. The rest of the + arguments corresponds to the type of arguments of the decorated function. + These arguments are also used as the formal argument of the decorated + function. If *func* has the signature ``foo(typing_context, arg0, arg1)``, + the decorated function will have the signature ``foo(arg0, arg1)``. + + The return values of *func* should be a 2-tuple of expected type signature, + and a code-generation function that will passed to ``lower_builtin``. + For unsupported operation, return None. + + Here is an example implementing a ``cast_int_to_byte_ptr`` that cast + any integer to a byte pointer:: + + @intrinsic + def cast_int_to_byte_ptr(typingctx, src): + # check for accepted types + if isinstance(src, types.Integer): + # create the expected type signature + result_type = types.CPointer(types.uint8) + sig = result_type(types.uintp) + # defines the custom code generation + def codegen(context, builder, signature, args): + # llvm IRBuilder code here + [src] = args + rtype = signature.return_type + llrtype = context.get_value_type(rtype) + return builder.inttoptr(src, llrtype) + return sig, codegen + """ + # Make inner function for the actual work + def _mlir(func): + name = getattr(func, '__name__', str(func)) + llc = _MLIR(name, func, **kwargs) + llc._register() + return llc + + if not kwargs: + # No option is given + return _mlir(*args) + else: + # options are given, create a new callable to recv the + # definition function + def wrapper(func): + return _mlir(func) + return wrapper diff --git a/heir_py/example.py b/heir_py/example.py index 6f2ce797b..dc98f1261 100644 --- a/heir_py/example.py +++ b/heir_py/example.py @@ -1,22 +1,106 @@ """Example of HEIR Python usage.""" -from heir_py import pipeline +from heir_py import compile +from heir_py.mlir import * +from heir_py.types import * +# FIXME: Also add the tensorflow-to-tosa-to-HEIR example in here, even it doesn't use Numba +# TODO (#1162): Allow manually specifying precision/scale and warn/error if not possible! +# TODO (#????): Add precision computation/check to Mgmt dialect/infrastructure +# TODO (#1119): expose ctxt serialization in python +# TODO (#1162) : Fix "ImportError: generic_type: type "PublicKey" is already registered!" when doing setup twice +# TODO (#????): In OpenFHE is it more efficient to do mul(...) WITH relin than to do mul_no_relin and then relin? Write a simple peephhole opt to rewrite this? +# TODO (HECO OPT): Do not touch loops that are already operating on tensors/SIMD values -def foo(a, b): - """An example function.""" - return a * a - b * b +# ### Simple Example +# # TODO (#????): It seems like our mgmgt pass doesn't actively manage modulus differences for additions, relying on OpenFHE to handle that +# @compile() # defaults to scheme="bgv"", backend="openfhe", debug=False +# def foo(x : Secret[I16], y : Secret[I16]): +# sum = x + y +# diff = x - y +# mul = x * y +# expression = sum * diff + mul +# deadcode = expression * mul +# return expression +# foo.setup() # runs keygen/etc +# enc_x = foo.encrypt_x(7) +# enc_y = foo.encrypt_y(8) +# result_enc = foo.eval(enc_x, enc_y) +# result = foo.decrypt_result(result_enc) +# print(f"Expected result for `foo`: {foo(7,8)}, decrypted result: {result}") -# to replace with decorator -_heir_foo = pipeline.run_compiler(foo) -cc = _heir_foo.foo__generate_crypto_context() -kp = cc.KeyGen() -_heir_foo.foo__configure_crypto_context(cc, kp.secretKey) -arg0_enc = _heir_foo.foo__encrypt__arg0(cc, 7, kp.publicKey) -arg1_enc = _heir_foo.foo__encrypt__arg1(cc, 8, kp.publicKey) -res_enc = _heir_foo.foo(cc, arg0_enc, arg1_enc) -res = _heir_foo.foo__decrypt__result0(cc, res_enc, kp.secretKey) -print(res) # should be -15 + +# ### CKKS Example +# @compile(scheme="ckks") # other options default to backend="openfhe", debug=False +# def bar(x : Secret[F32], y : Secret[F32]): +# sum = x + y +# diff = x - y +# mul = x * y +# expression = sum * diff + mul +# deadcode = expression * mul +# return expression + +# bar.setup() # runs keygen/etc +# enc_x = bar.encrypt_x(7) +# enc_y = bar.encrypt_y(8) +# result_enc = bar.eval(enc_x, enc_y) +# result = bar.decrypt_result(result_enc) +# print(f"Expected result for `bar`: {bar(7,8)}, decrypted result: {result}") + + + + +# ### Ciphertext-Plaintext Example +# @compile(debug=True) + +# def baz(x: Secret[I16], y : I16): +# ptxt_mul = x * y +# ctxt_mul = x * x +# return ptxt_mul + ctxt_mul + +# baz.setup() # runs keygen/etc +# enc_x = baz.encrypt_x(7) +# result_enc = baz.eval(enc_x, 8) +# result = baz.decrypt_result(result_enc) +# print(f"Expected result for `baz`: {baz(7,8)}, decrypted result: {result}") + + +### Ciphertext-Plaintext Example 2 +@compile(debug=True) +def baz2(x: Secret[I16], y : Secret[I16], z : I16): + ptxt_mul = x * z + ctxt_mul = x * x + ctxt_mul2 = y * y + add = ctxt_mul + ctxt_mul2 + return add + ptxt_mul + # TODO (#1284): Relin Opt works if ptxt_mul is RHS, but not if ptxt_mul is LHS? + +baz2.setup() # runs keygen/etc +enc_x = baz2.encrypt_x(7) +enc_y = baz2.encrypt_y(8) +result_enc = baz2.eval(enc_x, enc_y, 9) +result = baz2.decrypt_result(result_enc) +print(f"Expected result for `baz2`: {baz(7,8,9)}, decrypted result: {result}") + + +# ### Matmul Exampl +# # secret.secret> +# @compile(scheme='ckks', debug=True) +# def qux(a : Secret[Tensor[4,4,F32]], b : Secret[Tensor[4,4,F32]]): +# AB = matmul(a,b) +# AABB = matmul(a+a, b+b) +# return AB + AABB + +# a = np.array([[1,2],[3,4]]) +# b = np.array([[5,6],[7,8]]) +# print(qux(a,b)) + +# qux.setup() +# enc_a = qux.encrypt_a(a) +# enc_b = qux.encrypt_b(b) +# result_enc = qux.eval(enc_a, enc_b) +# result = qux.decrypt_result(result_enc) +# print(f"Expected result: {np.matmul(a,b)}, decrypted result: {result}") diff --git a/heir_py/mlir.py b/heir_py/mlir.py new file mode 100644 index 000000000..3a1fab253 --- /dev/null +++ b/heir_py/mlir.py @@ -0,0 +1,9 @@ +from .decorator import mlir +from numba.core import sigutils +import numpy as np + +## Example of an overload: matmul +@mlir +def matmul(typingctx, X, Y): + # TODO (#1162): add a check if input types are valid! + return sigutils._parse_signature_string("float32[:,:](float32[:,:],float32[:,:])"), np.matmul diff --git a/heir_py/mlir_emitter.py b/heir_py/mlir_emitter.py index 03324e96e..92b17ff31 100644 --- a/heir_py/mlir_emitter.py +++ b/heir_py/mlir_emitter.py @@ -4,26 +4,61 @@ import textwrap from numba.core import ir +from numba.core import types +def mlirType(numba_type): + if isinstance(numba_type, types.Integer): + #TODO (#1162): fix handling of signedness + # Since `arith` only allows signless integers, we ignore signedness here. + return "i" + str(numba_type.bitwidth) + if isinstance(numba_type, types.Boolean): + return "i1" + if isinstance(numba_type, types.Float): + return "f" + str(numba_type.bitwidth) + if isinstance(numba_type, types.Complex): + return "complex<" + str(numba_type.bitwidth) + ">" + if isinstance(numba_type, types.Array): + #TODO (#1162): implement support for statically sized tensors + # this probably requires extending numba with a new type + # See https://numba.readthedocs.io/en/stable/extending/index.html + return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" + raise NotImplementedError("Unsupported type: " + str(numba_type)) + +def arithSuffix(numba_type): + if isinstance(numba_type, types.Integer): + return "i" + if isinstance(numba_type, types.Boolean): + return "i" + if isinstance(numba_type, types.Float): + return "f" + if isinstance(numba_type, types.Complex): + raise NotImplementedError("Complex numbers not supported in `arith` dialect") + if isinstance(numba_type, types.Array): + return arithSuffix(numba_type.dtype) + raise NotImplementedError("Unsupported type: " + str(numba_type)) -class TextualMlirEmitter: - def __init__(self, ssa_ir): +class TextualMlirEmitter: + def __init__(self, ssa_ir, secret_args, typemap, retty): self.ssa_ir = ssa_ir + self.secret_args = secret_args + self.typemap = typemap + self.retty = retty, self.temp_var_id = 0 self.numba_names_to_ssa_var_names = {} self.globals_map = {} def emit(self): func_name = self.ssa_ir.func_id.func_name + secret_flag = " {secret.secret}" # probably should use unique name... # func_name = ssa_ir.func_id.unique_name + args_str = ", ".join([f"%{name}: {mlirType(self.typemap.get(name))}{secret_flag if idx in self.secret_args else str()}" for idx, name in enumerate(self.ssa_ir.arg_names)]) - # TODO(#1162): use inferred or explicit types for args - args_str = ", ".join([f"%{name}: i64" for name in self.ssa_ir.arg_names]) - - # TODO(#1162): get inferred or explicit return types - return_types_str = "i64" + # TODO(#1162): support multiple return values! + if(len(self.retty) > 1): + raise NotImplementedError("Multiple return values not supported") + return_types_str = mlirType(self.retty[0]) body = self.emit_body() @@ -77,6 +112,11 @@ def get_or_create_name(self, var): self.temp_var_id += 1 return f"%{ssa_id}" + def get_next_name(self): + ssa_id = self.temp_var_id + self.temp_var_id += 1 + return f"%{ssa_id}" + def get_name(self, var): assert var.name in self.numba_names_to_ssa_var_names return self.get_or_create_name(var) @@ -100,8 +140,7 @@ def emit_assign(self, assign): case ir.Expr(op="binop"): name = self.get_or_create_name(assign.target) emitted_expr = self.emit_binop(assign.value) - # TODO(#1162): replace i64 with inferred type - return f"{name} = {emitted_expr} : i64" + return f"{name} = {emitted_expr} : " + mlirType(self.typemap.get(assign.target.name)) case ir.Expr(op="call"): func = assign.value.func # if assert fails, variable was undefined @@ -110,6 +149,36 @@ def emit_assign(self, assign): # nothing to do, forward the name to the arg of bool() self.forward_name(from_var=assign.target, to_var=assign.value.args[0]) return "" + if self.globals_map[func.name] == "matmul": + # emit `linalg.matmul` operation. + lhs = self.get_name(assign.value.args[0]) + lhs_ty = mlirType(self.typemap.get(assign.value.args[0].name)) + rhs = self.get_name(assign.value.args[1]) + rhs_ty = mlirType(self.typemap.get(assign.value.args[1].name)) + target_numba_type = self.typemap.get(assign.target.name) + out_ty = mlirType(target_numba_type) + name = self.get_or_create_name(assign.target) + if isinstance(target_numba_type, types.Array): + # We need to emit a tensor.empty() operation to create the output tensor + # but before that, we need to emit tensor.dim to get the sizes for the empty tensor + str = "" + dims = [] + for i in range(target_numba_type.ndim): + cst = self.get_next_name() + str += f"{cst} = arith.constant {i} : index\n" + dim = self.get_next_name() + dims.append(dim) + str += f"{dim} = tensor.dim {lhs}, {cst} : {lhs_ty}\n" + + empty = self.get_next_name() + str += f"{empty} = tensor.empty({','.join(dims)}) : {out_ty}\n" + str += f"{name} = linalg.matmul ins({lhs}, {rhs} : {lhs_ty}, {rhs_ty}) outs({empty} : {out_ty}) -> {out_ty}" + return str + else: + #TODO (#1162): implement support for statically sized tensors + # this probably requires extending numba with a new type + # See https://numba.readthedocs.io/en/stable/extending/index.html + raise NotImplementedError(f"Unsupported target type {target_numba_type} for {assign.target.name}.") else: raise NotImplementedError("Unknown global " + func.name) case ir.Expr(op="cast"): @@ -125,7 +194,7 @@ def emit_assign(self, assign): case ir.Global(): self.globals_map[assign.target.name] = assign.value.name return "" - raise NotImplementedError() + raise NotImplementedError(f"Unsupported IR Element: {assign}") def emit_expr(self, expr): if expr.op == "binop": @@ -174,16 +243,18 @@ def emit_expr(self, expr): def emit_binop(self, binop): lhs_ssa = self.get_name(binop.lhs) rhs_ssa = self.get_name(binop.rhs) + # This should be the same, otherwise MLIR will complain + suffix = arithSuffix(self.typemap.get(str(binop.lhs))) match binop.fn: case operator.lt: - return f"arith.cmpi slt, {lhs_ssa}, {rhs_ssa}" + return f"arith.cmp{suffix} slt, {lhs_ssa}, {rhs_ssa}" case operator.add: - return f"arith.addi {lhs_ssa}, {rhs_ssa}" + return f"arith.add{suffix} {lhs_ssa}, {rhs_ssa}" case operator.mul: - return f"arith.muli {lhs_ssa}, {rhs_ssa}" + return f"arith.mul{suffix} {lhs_ssa}, {rhs_ssa}" case operator.sub: - return f"arith.subi {lhs_ssa}, {rhs_ssa}" + return f"arith.sub{suffix} {lhs_ssa}, {rhs_ssa}" raise NotImplementedError("Unsupported binop: " + binop.fn.__name__) @@ -193,5 +264,4 @@ def emit_branch(self, branch): def emit_return(self, ret): var = self.get_name(ret.value) - # TODO(#1162): replace i64 with inferred or explicit return type - return f"func.return {var} : i64" + return f"func.return {var} : " + mlirType(self.typemap.get(str(ret.value))) diff --git a/heir_py/numba_type_inference_example.py b/heir_py/numba_type_inference_example.py new file mode 100644 index 000000000..e306dbe9e --- /dev/null +++ b/heir_py/numba_type_inference_example.py @@ -0,0 +1,25 @@ +from numba.core.registry import cpu_target +from numba.core import compiler, sigutils +from numba.core.typed_passes import type_inference_stage + +# Define a test function +def example_function(x, y): + z = x + y + return z + +sig_string = "int16(int16, int16)" + +test_ir = compiler.run_frontend(example_function) +typingctx = cpu_target.typing_context +targetctx = cpu_target.target_context +typingctx.refresh() +targetctx.refresh() + +fn_args, fn_retty = sigutils.normalize_signature(sig_string) +typing_res = type_inference_stage(typingctx, targetctx, test_ir, fn_args, + None) + +# Get inferred types +typemap = typing_res.typemap +for var, typ in typemap.items(): + print(f"Variable: {var}, Type: {typ}") diff --git a/heir_py/pipeline.py b/heir_py/pipeline.py index c9520ac48..f7ad697d2 100644 --- a/heir_py/pipeline.py +++ b/heir_py/pipeline.py @@ -13,8 +13,15 @@ from heir_py import mlir_emitter from heir_py import openfhe_config as openfhe_config_lib from heir_py import pybind_helpers -from numba.core import bytecode -from numba.core import interpreter +from heir_py.types import MLIRTypeAnnotation, Secret, Tensor + +# FIXME: Don't use implementation detail _GenericAlias!!! +from typing import get_args, get_origin, _GenericAlias + +from numba.core import compiler +from numba.core import sigutils +from numba.core.registry import cpu_target +from numba.core.typed_passes import type_inference_stage dataclass = dataclasses.dataclass Path = pathlib.Path @@ -35,6 +42,9 @@ class CompilationResult: # A list of arg names (in order) arg_names: list[str] + # A list of indices of secret args + secret_args: list[int] + # A mapping from argument name to the compiled encryption function arg_enc_funcs: dict[str, object] @@ -48,11 +58,36 @@ class CompilationResult: setup_funcs: dict[str, object] +def parse_signature(idx, arg_type): + if (isinstance(arg_type, _GenericAlias)): + wrapper = get_origin(arg_type) + # print(f"Found an object: {arg_type} with type {type(arg_type)}") + # print(f"Wrapper is {wrapper} with type {type(wrapper)}") + # print(f"Typing.get_args tell us: {get_args(arg_type)}") + if(issubclass(wrapper, Secret)): + signature += f"{get_args(arg_type)[0].numba_str()}," + secret_args.append(idx) + elif(issubclass(wrapper, Tensor)): + args = get_args(arg_type) + inner_type = args[len(args) - 1].numba_str() + #FIXME: Add support for static tensor sizes! + signature += f"{inner_type}[{','.join([':'] * (len(args) - 1))}]" + raise NotImplementedError("Static tensor sizes are not yet supported") + else: + raise ValueError(f"Unsupported type annotation {arg_type}") + + elif (not issubclass(arg_type, MLIRTypeAnnotation)): + raise ValueError(f"Unsupported type annotation {arg_type}") + else: + signature += f"{arg_type.numba_str()}," + return signature + def run_compiler( function, + scheme, openfhe_config: OpenFHEConfig = openfhe_config_lib.DEFAULT_INSTALLED_OPENFHE_CONFIG, heir_config: HEIRConfig = heir_config_lib.DEVELOPMENT_HEIR_CONFIG, - debug=False, + debug: bool = False, ): """Run the compiler.""" # The temporary workspace dir is so that heir-opt, heir-translate, and @@ -66,21 +101,71 @@ def run_compiler( # tempfile.mkdtemp()` and manually clean it up. workspace_dir = tempfile.mkdtemp() try: - func_id = bytecode.FunctionIdentity.from_function(function) - converted_bytecode = bytecode.ByteCode(func_id) - ssa_ir = interpreter.Interpreter(func_id).interpret(converted_bytecode) - mlir_textual = mlir_emitter.TextualMlirEmitter(ssa_ir).emit() - func_name = func_id.func_name + ssa_ir = compiler.run_frontend(function) + + ##### (Numba) Type Inference + # Fetch the function's type annotation + annotation = function.__annotations__ + # Convert those annotations back to a numba signature + signature = "" + secret_args = [] + for idx, (_, arg_type) in enumerate(annotation.items()): + if (isinstance(arg_type, _GenericAlias)): + wrapper = get_origin(arg_type) + # print(f"Found an object: {arg_type} with type {type(arg_type)}") + # print(f"Wrapper is {wrapper} with type {type(wrapper)}") + # print(f"Typing.get_args tell us: {get_args(arg_type)}") + if(issubclass(wrapper, Secret)): + inner_type = get_args(arg_type)[0] + if(isinstance(inner_type, _GenericAlias)): + # print("In the secret-of-another-wrapper case") + args = get_args(inner_type) + element_type = args[len(args) - 1].numba_str() + #FIXME: Add support for static tensor sizes! + signature += f"{element_type}[{','.join([':'] * (len(args) - 1))}]," + else: + signature += f"{get_args(arg_type)[0].numba_str()}," + secret_args.append(idx) + elif(issubclass(wrapper, Tensor)): + args = get_args(arg_type) + inner_type = args[len(args) - 1].numba_str() + #FIXME: Add support for static tensor sizes! + signature += f"{inner_type}[{','.join([':'] * (len(args) - 1))}]" + else: + raise ValueError(f"Unsupported type annotation {arg_type}") + + elif (not issubclass(arg_type, MLIRTypeAnnotation)): + raise ValueError(f"Unsupported type annotation {arg_type}") + else: + signature += f"{arg_type.numba_str()}," + + # Set up inference contexts + typingctx = cpu_target.typing_context + targetctx = cpu_target.target_context + typingctx.refresh() + targetctx.refresh() + fn_args, fn_retty = sigutils.normalize_signature(signature) + # Run actual inference. TODO(#1162): handle type inference errors + typemap, restype, calltypes, errs = type_inference_stage(typingctx, targetctx, ssa_ir, fn_args, + None) + + mlir_textual = mlir_emitter.TextualMlirEmitter(ssa_ir, secret_args, typemap, restype).emit() + func_name = ssa_ir.func_id.func_name module_name = f"_heir_{func_name}" + if(debug): + mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir" + print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}") + with open(mlir_in_filepath, "w") as f: + f.write(mlir_textual) + heir_opt = heir_backend.HeirOptBackend(heir_config.heir_opt_path) # TODO(#1162): construct heir-opt pipeline options from decorator heir_opt_options = [ - f"--secretize=function={func_name}", ( - "--mlir-to-openfhe-bgv=" + f"--mlir-to-openfhe-{scheme}=" f"entry-function={func_name} ciphertext-degree=32" - ), + ) ] heir_opt_output = heir_opt.run_binary( input=mlir_textual, @@ -88,10 +173,6 @@ def run_compiler( ) if(debug): - mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir" - print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}") - with open(mlir_in_filepath, "w") as f: - f.write(mlir_textual) mlir_out_filepath = Path(workspace_dir) / f"{func_name}.out.mlir" print(f"Debug mode enabled. Writing Output MLIR to {mlir_out_filepath}") with open(mlir_out_filepath, "w") as f: @@ -105,10 +186,12 @@ def run_compiler( pybind_filepath = Path(workspace_dir) / f"{func_name}_bindings.cpp" # TODO(#1162): construct heir-translate pipeline options from decorator include_type_flag = "--openfhe-include-type=" + openfhe_config.include_type + scheme_flag = f"--openfhe-scheme={scheme}" heir_translate.run_binary( input=heir_opt_output, options=[ "--emit-openfhe-pke-header", + scheme_flag, include_type_flag, "-o", h_filepath, @@ -116,7 +199,7 @@ def run_compiler( ) heir_translate.run_binary( input=heir_opt_output, - options=["--emit-openfhe-pke", include_type_flag, "-o", cpp_filepath], + options=["--emit-openfhe-pke", scheme_flag, include_type_flag, "-o", cpp_filepath], ) heir_translate.run_binary( input=heir_opt_output, @@ -174,10 +257,12 @@ def run_compiler( result = CompilationResult( module=bound_module, func_name=func_name, - arg_names=func_id.arg_names, + secret_args=secret_args, + arg_names=ssa_ir.func_id.arg_names, arg_enc_funcs={ arg_name: getattr(bound_module, f"{func_name}__encrypt__arg{i}") - for i, arg_name in enumerate(func_id.arg_names) + for i, arg_name in enumerate(ssa_ir.func_id.arg_names) + if i in secret_args }, result_dec_func=getattr(bound_module, f"{func_name}__decrypt__result0"), main_func=getattr(bound_module, func_name), diff --git a/heir_py/pipeline_test.py b/heir_py/pipeline_test.py index 569f04ced..5c311b266 100644 --- a/heir_py/pipeline_test.py +++ b/heir_py/pipeline_test.py @@ -13,6 +13,7 @@ def foo(a, b): heir_foo = pipeline.run_compiler( foo, + signature="int32(int32, int32)", openfhe_config=openfhe_config.from_os_env(), heir_config=heir_config.from_os_env(), ).module diff --git a/heir_py/types.py b/heir_py/types.py new file mode 100644 index 000000000..87c98a65f --- /dev/null +++ b/heir_py/types.py @@ -0,0 +1,67 @@ +from typing import TypeVar, TypeVarTuple, Generic +from numba import types +from numba.extending import typeof_impl, as_numba_type, type_callable + +# (Mostly Dummy) Classes for Type Annotations + +T = TypeVar('T') +Ts = TypeVarTuple("Ts") + +class MLIRTypeAnnotation: + def numba_str(): + raise NotImplementedError("No numba type exists for MLIRTypeAnnotation") + + +class Secret(Generic[T], MLIRTypeAnnotation): + def numba_str(): + raise NotImplementedError("No numba type exists for Secret") + +class Tensor(Generic[*Ts], MLIRTypeAnnotation): + def numba_str(): + raise NotImplementedError("No numba type exists for Tensor") + +class F32(MLIRTypeAnnotation): + def numba_str(): + return "float32" + +class F64(MLIRTypeAnnotation): + def numba_str(): + return "float64" + +class I1(MLIRTypeAnnotation): + def numba_str(): + return "bool" + +class I4(MLIRTypeAnnotation): + def numba_str(): + return "int4" + +class I8(MLIRTypeAnnotation): + def numba_str(): + return "int8" + + +class I16(MLIRTypeAnnotation): + def numba_str(): + return "int16" + +class I32(MLIRTypeAnnotation): + def numba_str(): + return "int32" +class I64(MLIRTypeAnnotation): + def numba_str(): + return "int64" + + +# # Numba Types +# class SecretType(types.Type): +# def __init__(self): +# super(SecretType, self).__init__(name="Secret") + +# secret_type = SecretType() + +# @typeof_impl.register(Secret) +# def typeof_index(val, c): +# return secret_type + +# as_numba_type.register(Secret, SecretType) diff --git a/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp b/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp index 408173e9b..85e30f9b3 100644 --- a/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp +++ b/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp @@ -170,6 +170,7 @@ struct SecretToBGV : public impl::SecretToBGVBase { if (failed(rlweRing)) { return signalPassFailure(); } + // Ensure that all secret types are uniform and matching the ring // parameter size. Operation *foundOp = walkAndDetect(module, [&](Operation *op) {