Python Frontend: Use Numba Type Inference #1253

Draft · wants to merge 30 commits into base: main

Commits (30)
7d1d7d9
use compiler.run_frontend instead of manual interpreter call
AlexanderViand-Intel Jan 21, 2025
28cc244
debug: print input MLIR location before calling HEIR (in case of fail…
AlexanderViand-Intel Jan 21, 2025
ff05fdd
make decorator default to installed OpenFHE, and development HEIR config
AlexanderViand-Intel Jan 21, 2025
49a5a74
TMP? switch example to decorator
AlexanderViand-Intel Jan 21, 2025
6f9f3fe
TMP: add file showing numba type inference APIs
AlexanderViand-Intel Jan 21, 2025
238e2c7
add numba type inference to pipeline
AlexanderViand-Intel Jan 21, 2025
1653e01
fix for changing to compiler.run_frontend
AlexanderViand-Intel Jan 21, 2025
a6e4f63
use type inference in MLIR emitter
AlexanderViand-Intel Jan 21, 2025
80a728a
WIP: matmul and inference
AlexanderViand-Intel Jan 21, 2025
301d3b3
add "support" for numba array
AlexanderViand-Intel Jan 21, 2025
9a41e86
allow `foo.eval` in addition to `foo.foo`
AlexanderViand-Intel Jan 21, 2025
08e9292
split examples into basic and matmul
AlexanderViand-Intel Jan 21, 2025
5b3156b
WIP: @mlir decorator
AlexanderViand-Intel Jan 21, 2025
091be2c
add error for unranked tensors
AlexanderViand-Intel Jan 21, 2025
393c982
nevermind
AlexanderViand-Intel Jan 21, 2025
d4bc3e3
WIP
AlexanderViand-Intel Jan 21, 2025
8428e09
feedback/TODOs from meeting
AlexanderViand-Intel Jan 21, 2025
eb2d495
WIP
AlexanderViand-Intel Jan 21, 2025
f7a6e19
.
AlexanderViand-Intel Jan 21, 2025
1b164d9
emit arith suffix (addi/addf) based on type
AlexanderViand-Intel Jan 21, 2025
4768c08
use type annotations
AlexanderViand-Intel Jan 21, 2025
520071f
WIP: numba types test
AlexanderViand-Intel Jan 21, 2025
e4dbe7e
add ckks support in python frontend
AlexanderViand-Intel Jan 21, 2025
e412eb1
.
AlexanderViand-Intel Jan 21, 2025
8fb39b9
adds support for "Secret" annotation
AlexanderViand-Intel Jan 21, 2025
67b1cc6
update example file
AlexanderViand-Intel Jan 21, 2025
6277774
WIP: tensor support in annotations
AlexanderViand-Intel Jan 21, 2025
c64076c
.
AlexanderViand-Intel Jan 21, 2025
241eed8
.
AlexanderViand-Intel Jan 21, 2025
5f5127e
notes and todos from meeting
AlexanderViand-Intel Jan 21, 2025
175 changes: 170 additions & 5 deletions heir_py/decorator.py
@@ -1,12 +1,34 @@
"""The decorator entry point for the frontend."""

import os
import uuid
import weakref
import collections
import functools

import numba
from numba.core import types, errors, utils, config

from abc import ABC, abstractmethod
from typing import Optional

from heir_py import heir_config as _heir_config
from heir_py import openfhe_config
from heir_py.pipeline import CompilationResult, run_compiler

# Exported symbols
from numba.core.typing.typeof import typeof_impl # noqa: F401
from numba.core.typing.asnumbatype import as_numba_type # noqa: F401
from numba.core.typing.templates import infer, infer_getattr # noqa: F401
from numba.core.imputils import ( # noqa: F401
lower_builtin, lower_getattr, lower_getattr_generic, # noqa: F401
lower_setattr, lower_setattr_generic, lower_cast) # noqa: F401
from numba.core.datamodel import models # noqa: F401
from numba.core.datamodel import register_default as register_model # noqa: F401, E501
from numba.core.pythonapi import box, unbox, reflect, NativeValue # noqa: F401
from numba._helperlib import _import_cython_function # noqa: F401
from numba.core.serialize import ReduceMixin


class CompilationResultInterface(ABC):

@@ -77,7 +99,7 @@ def wrapper(arg, *, crypto_context=None, public_key=None):

return wrapper

if key == self.compilation_result.func_name:
if key == self.compilation_result.func_name or key == "eval":
fn = self.compilation_result.main_func

def wrapper(*args, crypto_context=None):
@@ -102,19 +124,21 @@ def __call__(self, *args, **kwargs):
)

args_encrypted = [
getattr(self, f"encrypt_{arg_name}")(arg)
for arg_name, arg in zip(arg_names, args)
getattr(self, f"encrypt_{arg_name}")(arg) if i in self.compilation_result.secret_args else arg
for i, (arg_name, arg) in enumerate(zip(arg_names, args))
]

result_encrypted = getattr(self, self.compilation_result.func_name)(
*args_encrypted
)
return self.decrypt_result(result_encrypted)


def compile(
scheme: str = "bgv",
backend: str = "openfhe",
backend_config: Optional[openfhe_config.OpenFHEConfig] = None,
heir_config: Optional[_heir_config.HEIRConfig] = None,
backend_config: Optional[openfhe_config.OpenFHEConfig] = openfhe_config.DEFAULT_INSTALLED_OPENFHE_CONFIG,
heir_config: Optional[_heir_config.HEIRConfig] = _heir_config.DEVELOPMENT_HEIR_CONFIG,
debug : Optional[bool] = False
):
"""Compile a function to its private equivalent in FHE.
@@ -134,6 +158,7 @@ def compile(
def decorator(func):
compilation_result = run_compiler(
func,
scheme,
openfhe_config=backend_config or openfhe_config.from_os_env(),
heir_config=heir_config or _heir_config.from_os_env(),
debug = debug
@@ -144,3 +169,143 @@ def decorator(func):
raise ValueError(f"Unknown backend: {backend}")

return decorator



class _MLIR(ReduceMixin):
"""
Dummy callable backing the @mlir decorator.
"""
_memo = weakref.WeakValueDictionary()
# hold refs to last N functions deserialized, retaining them in _memo
# regardless of whether there is another reference
_recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE)

__uuid = None

def __init__(self, name, defn, prefer_literal=False, **kwargs):
self._ctor_kwargs = kwargs
self._name = name
self._defn = defn
self._prefer_literal = prefer_literal
functools.update_wrapper(self, defn)

@property
def _uuid(self):
"""
An instance-specific UUID, to avoid multiple deserializations of
a given instance.

Note this is lazily-generated, for performance reasons.
"""
u = self.__uuid
if u is None:
u = str(uuid.uuid1())
self._set_uuid(u)
return u

def _set_uuid(self, u):
assert self.__uuid is None
self.__uuid = u
self._memo[u] = self
self._recent.append(self)

def _register(self):
# _ctor_kwargs
from numba.core.typing.templates import (make_intrinsic_template,
infer_global)

template = make_intrinsic_template(self, self._defn, self._name,
prefer_literal=self._prefer_literal,
kwargs=self._ctor_kwargs)
infer(template)
infer_global(self, types.Function(template))

def __call__(self, *args, **kwargs):
"""
Calls the plain Python implementation.
"""
_, impl = self._defn(None, *args, **kwargs)
return impl(*args, **kwargs)

def __repr__(self):
return "<intrinsic {0}>".format(self._name)

def __deepcopy__(self, memo):
# NOTE: Intrinsic objects are immutable and we don't need to copy.
# This is triggered from deepcopy of statements.
return self

def _reduce_states(self):
"""
NOTE: part of ReduceMixin protocol
"""
return dict(uuid=self._uuid, name=self._name, defn=self._defn)

@classmethod
def _rebuild(cls, uuid, name, defn):
"""
NOTE: part of ReduceMixin protocol
"""
try:
return cls._memo[uuid]
except KeyError:
llc = cls(name=name, defn=defn)
llc._register()
llc._set_uuid(uuid)
return llc


def mlir(*args, **kwargs):
"""
TODO (#1162): update this doc to reflect how we use this!
A decorator marking the decorated function as typing and implementing
*func* in nopython mode using the llvmlite IRBuilder API. This is an escape
hatch for expert users to build custom LLVM IR that will be inlined to
the caller.

The first argument to *func* is the typing context. The rest of the
arguments correspond to the types of the arguments of the decorated function.
These arguments are also used as the formal arguments of the decorated
function. If *func* has the signature ``foo(typing_context, arg0, arg1)``,
the decorated function will have the signature ``foo(arg0, arg1)``.

The return value of *func* should be a 2-tuple of the expected type signature
and a code-generation function that will be passed to ``lower_builtin``.
For unsupported operations, return None.

Here is an example implementing a ``cast_int_to_byte_ptr`` that casts
any integer to a byte pointer::

@intrinsic
def cast_int_to_byte_ptr(typingctx, src):
# check for accepted types
if isinstance(src, types.Integer):
# create the expected type signature
result_type = types.CPointer(types.uint8)
sig = result_type(types.uintp)
# defines the custom code generation
def codegen(context, builder, signature, args):
# llvm IRBuilder code here
[src] = args
rtype = signature.return_type
llrtype = context.get_value_type(rtype)
return builder.inttoptr(src, llrtype)
return sig, codegen
"""
# Make inner function for the actual work
def _mlir(func):
name = getattr(func, '__name__', str(func))
llc = _MLIR(name, func, **kwargs)
llc._register()
return llc

if not kwargs:
# No option is given
return _mlir(*args)
else:
# options are given, create a new callable to receive the
# definition function
def wrapper(func):
return _mlir(func)
return wrapper
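
A minimal sketch (not part of the diff) of the direct-call path added in `__call__` above: only arguments whose index appears in `compilation_result.secret_args` are encrypted before evaluation, so a function mixing secret and plaintext arguments can be invoked like an ordinary Python function. It borrows the `baz2` signature from `heir_py/example.py` below; the values are illustrative.

# Illustrative only: assumes baz2(x: Secret[I16], y: Secret[I16], z: I16)
# from heir_py/example.py, decorated with @compile(debug=True).
baz2.setup()  # keygen / crypto-context configuration

# Direct call through __call__: x and y are listed in secret_args, so they are
# encrypted via encrypt_x / encrypt_y; z is plaintext and passed through as-is.
result = baz2(7, 8, 9)

# Equivalent explicit flow using the generated accessors:
enc_x = baz2.encrypt_x(7)
enc_y = baz2.encrypt_y(8)
result_explicit = baz2.decrypt_result(baz2.eval(enc_x, enc_y, 9))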
112 changes: 98 additions & 14 deletions heir_py/example.py
@@ -1,22 +1,106 @@
"""Example of HEIR Python usage."""

from heir_py import pipeline
from heir_py import compile
from heir_py.mlir import *
from heir_py.types import *

# FIXME: Also add the tensorflow-to-tosa-to-HEIR example in here, even if it doesn't use Numba
# TODO (#1162): Allow manually specifying precision/scale and warn/error if not possible!
# TODO (#????): Add precision computation/check to Mgmt dialect/infrastructure
# TODO (#1119): expose ctxt serialization in python
# TODO (#1162) : Fix "ImportError: generic_type: type "PublicKey" is already registered!" when doing setup twice
# TODO (#????): In OpenFHE, is it more efficient to do mul(...) WITH relin than to do mul_no_relin and then relin? Write a simple peephole opt to rewrite this?
# TODO (HECO OPT): Do not touch loops that are already operating on tensors/SIMD values

def foo(a, b):
"""An example function."""
return a * a - b * b
# ### Simple Example
# # TODO (#????): It seems like our mgmt pass doesn't actively manage modulus differences for additions, relying on OpenFHE to handle that
# @compile() # defaults to scheme="bgv"", backend="openfhe", debug=False
# def foo(x : Secret[I16], y : Secret[I16]):
# sum = x + y
# diff = x - y
# mul = x * y
# expression = sum * diff + mul
# deadcode = expression * mul
# return expression

# foo.setup() # runs keygen/etc
# enc_x = foo.encrypt_x(7)
# enc_y = foo.encrypt_y(8)
# result_enc = foo.eval(enc_x, enc_y)
# result = foo.decrypt_result(result_enc)
# print(f"Expected result for `foo`: {foo(7,8)}, decrypted result: {result}")

# to replace with decorator
_heir_foo = pipeline.run_compiler(foo)

cc = _heir_foo.foo__generate_crypto_context()
kp = cc.KeyGen()
_heir_foo.foo__configure_crypto_context(cc, kp.secretKey)
arg0_enc = _heir_foo.foo__encrypt__arg0(cc, 7, kp.publicKey)
arg1_enc = _heir_foo.foo__encrypt__arg1(cc, 8, kp.publicKey)
res_enc = _heir_foo.foo(cc, arg0_enc, arg1_enc)
res = _heir_foo.foo__decrypt__result0(cc, res_enc, kp.secretKey)

print(res) # should be -15

# ### CKKS Example
# @compile(scheme="ckks") # other options default to backend="openfhe", debug=False
# def bar(x : Secret[F32], y : Secret[F32]):
# sum = x + y
# diff = x - y
# mul = x * y
# expression = sum * diff + mul
# deadcode = expression * mul
# return expression

# bar.setup() # runs keygen/etc
# enc_x = bar.encrypt_x(7)
# enc_y = bar.encrypt_y(8)
# result_enc = bar.eval(enc_x, enc_y)
# result = bar.decrypt_result(result_enc)
# print(f"Expected result for `bar`: {bar(7,8)}, decrypted result: {result}")




# ### Ciphertext-Plaintext Example
# @compile(debug=True)

# def baz(x: Secret[I16], y : I16):
# ptxt_mul = x * y
# ctxt_mul = x * x
# return ptxt_mul + ctxt_mul

# baz.setup() # runs keygen/etc
# enc_x = baz.encrypt_x(7)
# result_enc = baz.eval(enc_x, 8)
# result = baz.decrypt_result(result_enc)
# print(f"Expected result for `baz`: {baz(7,8)}, decrypted result: {result}")


### Ciphertext-Plaintext Example 2
@compile(debug=True)
def baz2(x: Secret[I16], y : Secret[I16], z : I16):
ptxt_mul = x * z
ctxt_mul = x * x
ctxt_mul2 = y * y
add = ctxt_mul + ctxt_mul2
return add + ptxt_mul
# TODO (#1284): Relin Opt works if ptxt_mul is RHS, but not if ptxt_mul is LHS?

baz2.setup() # runs keygen/etc
enc_x = baz2.encrypt_x(7)
enc_y = baz2.encrypt_y(8)
result_enc = baz2.eval(enc_x, enc_y, 9)
result = baz2.decrypt_result(result_enc)
print(f"Expected result for `baz2`: {baz2(7,8,9)}, decrypted result: {result}")


# ### Matmul Example
# # secret.secret<tensor<4x4xf32>>
# @compile(scheme='ckks', debug=True)
# def qux(a : Secret[Tensor[4,4,F32]], b : Secret[Tensor[4,4,F32]]):
# AB = matmul(a,b)
# AABB = matmul(a+a, b+b)
# return AB + AABB

# a = np.array([[1,2],[3,4]])
# b = np.array([[5,6],[7,8]])
# print(qux(a,b))

# qux.setup()
# enc_a = qux.encrypt_a(a)
# enc_b = qux.encrypt_b(b)
# result_enc = qux.eval(enc_a, enc_b)
# result = qux.decrypt_result(result_enc)
# print(f"Expected result: {np.matmul(a,b)}, decrypted result: {result}")
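
A small illustrative note (not part of the diff): the commented-out `qux` example annotates its arguments as `Secret[Tensor[4,4,F32]]` but builds 2x2 integer arrays. If the example is enabled, inputs matching the annotation would look like the following sketch (names and values hypothetical).

import numpy as np

# 4x4 float32 inputs consistent with Secret[Tensor[4,4,F32]]
a = np.arange(16, dtype=np.float32).reshape(4, 4)
b = np.ones((4, 4), dtype=np.float32)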
9 changes: 9 additions & 0 deletions heir_py/mlir.py
@@ -0,0 +1,9 @@
from .decorator import mlir
from numba.core import sigutils
import numpy as np

## Example of an overload: matmul
@mlir
def matmul(typingctx, X, Y):
# TODO (#1162): add a check if input types are valid!
return sigutils._parse_signature_string("float32[:,:](float32[:,:],float32[:,:])"), np.matmul
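
A short usage sketch (illustrative, not part of the diff): because `_MLIR.__call__` in `heir_py/decorator.py` evaluates the Python implementation returned by the overload, calling `matmul` outside of compilation simply runs `np.matmul`; inside an `@compile`-decorated function it is intended to be picked up by the Numba-based type inference, as in the commented-out `qux` example in `heir_py/example.py`.

import numpy as np
from heir_py.mlir import matmul

a = np.ones((4, 4), dtype=np.float32)
b = np.ones((4, 4), dtype=np.float32)
# Dispatches through _MLIR.__call__, which returns np.matmul(a, b) here.
print(matmul(a, b))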