Skip to content

Commit

Permalink
use type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderViand-Intel committed Jan 17, 2025
1 parent b3bf2d0 commit 704f176
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 48 deletions.
4 changes: 0 additions & 4 deletions heir_py/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def __call__(self, *args, **kwargs):


def compile(
signature : str,
backend: str = "openfhe",
backend_config: Optional[openfhe_config.OpenFHEConfig] = openfhe_config.DEFAULT_INSTALLED_OPENFHE_CONFIG,
heir_config: Optional[_heir_config.HEIRConfig] = _heir_config.DEVELOPMENT_HEIR_CONFIG,
Expand All @@ -143,8 +142,6 @@ def compile(
"""Compile a function to its private equivalent in FHE.
Args:
signature: a Numba signature string (See
https://numba.readthedocs.io/en/stable/reference/types.html#signatures).
backend: a string indicating the backend to use. Options: 'openfhe'
(default).
backend_config: a config object to control system-specific paths for the
Expand All @@ -159,7 +156,6 @@ def compile(
def decorator(func):
compilation_result = run_compiler(
func,
signature,
openfhe_config=backend_config or openfhe_config.from_os_env(),
heir_config=heir_config or _heir_config.from_os_env(),
debug = debug
Expand Down
13 changes: 8 additions & 5 deletions heir_py/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

from heir_py import compile
from heir_py.mlir import *
from heir_py.types import *

# FIXME: Also add the tensorflow-to-tosa-to-HEIR example in here, even it doesn't use Numba

### Simple Example
# TODO (#1162): Add `scheme` kwarg and add CKKS pipeline (note: addi -> addf)
# TODO (#1162): Allow manually specifying precision/scale and warn/error if not possible!
# TODO (#????): Add precision computation/check to Mgmt dialect/infrastructure
@compile("int16(int16,int16)", backend="openfhe", debug=True)
def foo(x : Secret[I64], y : Secret[I64]):
@compile(backend="openfhe", debug=True)
def foo(x : I16, y : I16):
sum = x + y
diff = x - y
mul = x * y
Expand All @@ -27,9 +28,11 @@ def foo(x : Secret[I64], y : Secret[I64]):


# # Matmul Example
# @compile("float32[:,:](float32[:,:],float32[:,:])", backend="openfhe", debug=False)
# def goo(a, b):
# return matmul(a,b)
# @compile("float32[:,:](float32[:,:],float32[:,:])", backend="openfhe", debug=True)
# def goo(a : Secret[Tensor[4,4,F32]], b : Secret[Tensor[4,4,F32]]):
# AB = matmul(a,b)
# AABB = matmul(a+a, b+b)
# return AB + AABB

# a = np.array([[1,2],[3,4]])
# b = np.array([[5,6],[7,8]])
Expand Down
35 changes: 0 additions & 35 deletions heir_py/mlir.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import TypeVar, Generic
from .decorator import mlir
from numba.core import sigutils
import numpy as np
Expand All @@ -8,37 +7,3 @@
def matmul(typingctx, X, Y):
# TODO (#1162): add a check if input types are valid!
return sigutils._parse_signature_string("float32[:,:](float32[:,:],float32[:,:])"), np.matmul


# For Type Annotations

T = TypeVar('T')

class Secret(Generic[T]):
pass

class Integer:
pass

class F32:
pass

class F64:
pass

class I1:
pass

class I4:
pass

class I8:
pass

class I16:
pass

class I32:
pass
class I64:
pass
17 changes: 13 additions & 4 deletions heir_py/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from heir_py import mlir_emitter
from heir_py import openfhe_config as openfhe_config_lib
from heir_py import pybind_helpers
from heir_py.types import MLIRTypeAnnotation

from numba.core import compiler
from numba.core import sigutils
from numba.core.registry import cpu_target
Expand Down Expand Up @@ -52,7 +54,6 @@ class CompilationResult:

def run_compiler(
function,
signature: str,
openfhe_config: OpenFHEConfig = openfhe_config_lib.DEFAULT_INSTALLED_OPENFHE_CONFIG,
heir_config: HEIRConfig = heir_config_lib.DEVELOPMENT_HEIR_CONFIG,
debug: bool = False,
Expand All @@ -71,14 +72,22 @@ def run_compiler(
try:
ssa_ir = compiler.run_frontend(function)

# Numba Type Inference
##### (Numba) Type Inference
# Fetch the function's type annotation
annotation = function.__annotations__
# Convert those annotations back to a numba signature
signature = ""
for _, arg_type in annotation.items():
if (not issubclass(arg_type, MLIRTypeAnnotation)):
raise ValueError(f"Unsupported type annotation {arg_type}")
signature += f"{arg_type.numba_str()},"
# Set up inference contexts
typingctx = cpu_target.typing_context
targetctx = cpu_target.target_context
typingctx.refresh()
targetctx.refresh()
#TODO(#1162): make use of return type in signature?
fn_args, fn_retty = sigutils.normalize_signature(signature)
#TODO(#1162): handle type inference errors
# Run actual inference. TODO(#1162): handle type inference errors
typemap, restype, calltypes, errs = type_inference_stage(typingctx, targetctx, ssa_ir, fn_args,
None)

Expand Down
49 changes: 49 additions & 0 deletions heir_py/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import TypeVar, Generic
# Type Annotations

T = TypeVar('T')

class MLIRTypeAnnotation:
def numba_str():
raise NotImplementedError("No numba type exists for MLIRTypeAnnotation")


class Secret(Generic[T], MLIRTypeAnnotation):
def numba_str():
raise NotImplementedError("No numba type exists for Secret")

class Tensor(Generic[T], MLIRTypeAnnotation):
def numba_str():
raise NotImplementedError("No numba type exists for Tensor")

class F32(MLIRTypeAnnotation):
def numba_str():
return "float32"

class F64(MLIRTypeAnnotation):
def numba_str():
return "float64"

class I1(MLIRTypeAnnotation):
def numba_str():
return "bool"

class I4(MLIRTypeAnnotation):
def numba_str():
return "int4"

class I8(MLIRTypeAnnotation):
def numba_str():
return "int8"


class I16(MLIRTypeAnnotation):
def numba_str():
return "int16"

class I32(MLIRTypeAnnotation):
def numba_str():
return "int32"
class I64(MLIRTypeAnnotation):
def numba_str():
return "int64"

0 comments on commit 704f176

Please sign in to comment.