diff --git a/heir_py/decorator.py b/heir_py/decorator.py
index 0e4ff0434..8515a1f7d 100644
--- a/heir_py/decorator.py
+++ b/heir_py/decorator.py
@@ -134,7 +134,6 @@ def __call__(self, *args, **kwargs):
 
 
 def compile(
-    signature : str,
     backend: str = "openfhe",
     backend_config: Optional[openfhe_config.OpenFHEConfig] = openfhe_config.DEFAULT_INSTALLED_OPENFHE_CONFIG,
     heir_config: Optional[_heir_config.HEIRConfig] = _heir_config.DEVELOPMENT_HEIR_CONFIG,
@@ -143,8 +142,6 @@ def compile(
   """Compile a function to its private equivalent in FHE.
 
   Args:
-      signature: a Numba signature string (See
-        https://numba.readthedocs.io/en/stable/reference/types.html#signatures).
       backend: a string indicating the backend to use. Options: 'openfhe'
         (default).
       backend_config: a config object to control system-specific paths for the
@@ -159,7 +156,6 @@ def compile(
   def decorator(func):
     compilation_result = run_compiler(
         func,
-        signature,
         openfhe_config=backend_config or openfhe_config.from_os_env(),
         heir_config=heir_config or _heir_config.from_os_env(),
         debug = debug
diff --git a/heir_py/example.py b/heir_py/example.py
index ef105a428..7595e22d6 100644
--- a/heir_py/example.py
+++ b/heir_py/example.py
@@ -2,6 +2,7 @@
 
 from heir_py import compile
 from heir_py.mlir import *
+from heir_py.types import *
 
 # FIXME: Also add the tensorflow-to-tosa-to-HEIR example in here, even it doesn't use Numba
 
@@ -9,8 +10,8 @@
 # TODO (#1162): Add `scheme` kwarg and add CKKS pipeline (note: addi -> addf)
 # TODO (#1162): Allow manually specifying precision/scale and warn/error if not possible!
 # TODO (#????): Add precision computation/check to Mgmt dialect/infrastructure
-@compile("int16(int16,int16)", backend="openfhe", debug=True)
-def foo(x : Secret[I64], y : Secret[I64]):
+@compile(backend="openfhe", debug=True)
+def foo(x : I16, y : I16):
   sum = x + y
   diff = x - y
   mul = x * y
@@ -27,9 +28,11 @@ def foo(x : Secret[I64], y : Secret[I64]):
 
 # # Matmul Example
-# @compile("float32[:,:](float32[:,:],float32[:,:])", backend="openfhe", debug=False)
-# def goo(a, b):
-#   return matmul(a,b)
+# @compile("float32[:,:](float32[:,:],float32[:,:])", backend="openfhe", debug=True)
+# def goo(a : Secret[Tensor[4,4,F32]], b : Secret[Tensor[4,4,F32]]):
+#   AB = matmul(a,b)
+#   AABB = matmul(a+a, b+b)
+#   return AB + AABB
 
 # a = np.array([[1,2],[3,4]])
 # b = np.array([[5,6],[7,8]])
 
diff --git a/heir_py/mlir.py b/heir_py/mlir.py
index 4ee745293..3a1fab253 100644
--- a/heir_py/mlir.py
+++ b/heir_py/mlir.py
@@ -1,4 +1,3 @@
-from typing import TypeVar, Generic
 from .decorator import mlir
 from numba.core import sigutils
 import numpy as np
@@ -8,37 +7,3 @@
 def matmul(typingctx, X, Y):
   # TODO (#1162): add a check if input types are valid!
   return sigutils._parse_signature_string("float32[:,:](float32[:,:],float32[:,:])"), np.matmul
-
-
-# For Type Annotations
-
-T = TypeVar('T')
-
-class Secret(Generic[T]):
-  pass
-
-class Integer:
-  pass
-
-class F32:
-  pass
-
-class F64:
-  pass
-
-class I1:
-  pass
-
-class I4:
-  pass
-
-class I8:
-  pass
-
-class I16:
-  pass
-
-class I32:
-  pass
-class I64:
-  pass
diff --git a/heir_py/pipeline.py b/heir_py/pipeline.py
index d89ef5e92..3fb4a076f 100644
--- a/heir_py/pipeline.py
+++ b/heir_py/pipeline.py
@@ -13,6 +13,8 @@
 from heir_py import mlir_emitter
 from heir_py import openfhe_config as openfhe_config_lib
 from heir_py import pybind_helpers
+from heir_py.types import MLIRTypeAnnotation
+
 from numba.core import compiler
 from numba.core import sigutils
 from numba.core.registry import cpu_target
@@ -52,7 +54,6 @@ class CompilationResult:
 
 def run_compiler(
     function,
-    signature: str,
    openfhe_config: OpenFHEConfig = openfhe_config_lib.DEFAULT_INSTALLED_OPENFHE_CONFIG,
    heir_config: HEIRConfig = heir_config_lib.DEVELOPMENT_HEIR_CONFIG,
    debug: bool = False,
@@ -71,14 +72,22 @@ def run_compiler(
   try:
     ssa_ir = compiler.run_frontend(function)
 
-    # Numba Type Inference
+    ##### (Numba) Type Inference
+    # Fetch the function's type annotation
+    annotation = function.__annotations__
+    # Convert those annotations back to a numba signature
+    signature = ""
+    for _, arg_type in annotation.items():
+      if (not issubclass(arg_type, MLIRTypeAnnotation)):
+        raise ValueError(f"Unsupported type annotation {arg_type}")
+      signature += f"{arg_type.numba_str()},"
+    # Set up inference contexts
     typingctx = cpu_target.typing_context
     targetctx = cpu_target.target_context
     typingctx.refresh()
     targetctx.refresh()
-    #TODO(#1162): make use of return type in signature?
     fn_args, fn_retty = sigutils.normalize_signature(signature)
-    #TODO(#1162): handle type inference errors
+    # Run actual inference. TODO(#1162): handle type inference errors
     typemap, restype, calltypes, errs = type_inference_stage(
         typingctx, targetctx, ssa_ir, fn_args, None
     )
diff --git a/heir_py/types.py b/heir_py/types.py
new file mode 100644
index 000000000..fbd3f5b82
--- /dev/null
+++ b/heir_py/types.py
@@ -0,0 +1,49 @@
+from typing import TypeVar, Generic
+# Type Annotations
+
+T = TypeVar('T')
+
+class MLIRTypeAnnotation:
+  def numba_str():
+    raise NotImplementedError("No numba type exists for MLIRTypeAnnotation")
+
+
+class Secret(Generic[T], MLIRTypeAnnotation):
+  def numba_str():
+    raise NotImplementedError("No numba type exists for Secret")
+
+class Tensor(Generic[T], MLIRTypeAnnotation):
+  def numba_str():
+    raise NotImplementedError("No numba type exists for Tensor")
+
+class F32(MLIRTypeAnnotation):
+  def numba_str():
+    return "float32"
+
+class F64(MLIRTypeAnnotation):
+  def numba_str():
+    return "float64"
+
+class I1(MLIRTypeAnnotation):
+  def numba_str():
+    return "bool"
+
+class I4(MLIRTypeAnnotation):
+  def numba_str():
+    return "int4"
+
+class I8(MLIRTypeAnnotation):
+  def numba_str():
+    return "int8"
+
+
+class I16(MLIRTypeAnnotation):
+  def numba_str():
+    return "int16"
+
+class I32(MLIRTypeAnnotation):
+  def numba_str():
+    return "int32"
+class I64(MLIRTypeAnnotation):
+  def numba_str():
+    return "int64"
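
Usage sketch (not part of the patch): with this change the Numba signature string is dropped from @compile, and run_compiler rebuilds the signature from function.__annotations__ via the numba_str() helpers in heir_py.types. A minimal caller mirroring the updated example.py; the function body below is illustrative only, not a copy of the repo's example:

    from heir_py import compile
    from heir_py.types import I16

    @compile(backend="openfhe", debug=True)  # no signature string needed anymore
    def foo(x: I16, y: I16):
      sum = x + y
      diff = x - y
      mul = x * y
      return sum * diff + mul  # illustrative; example.py's body continues past the hunk shown above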