Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Finalizing the initial infrastructure of JaCe #18

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
from .api import grad, jacfwd, jacrev, jit
from .translated_jaxpr_sdfg import CompiledJaxprSDFG, TranslatedJaxprSDFG


__all__ = [
"CompiledJaxprSDFG",
"TranslatedJaxprSDFG",
"__author__",
"__copyright__",
"__license__",
Expand Down
59 changes: 42 additions & 17 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any, Literal, overload
import inspect
from typing import TYPE_CHECKING, Literal, ParamSpec, TypedDict, TypeVar, overload

from jax import grad, jacfwd, jacrev
from typing_extensions import Unpack

from jace import stages, translator

Expand All @@ -21,57 +23,80 @@
from collections.abc import Callable, Mapping


__all__ = ["grad", "jacfwd", "jacrev", "jit"]
__all__ = ["JitOptions", "grad", "jacfwd", "jacrev", "jit"]

# Used for type annotation, see the notes in `jace.stages` for more.
_P = ParamSpec("_P")
_RetrunType = TypeVar("_RetrunType")


class JitOptions(TypedDict, total=False):
"""
All known options to `jace.jit` that influence tracing.

Notes:
Currently there are no known options, but essentially it is a subset of some
of the options that are supported by `jax.jit` together with some additional
JaCe specific ones.
"""


@overload
def jit(
fun: Literal[None] = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> Callable[[Callable], stages.JaCeWrapped]: ...
**kwargs: Unpack[JitOptions],
) -> Callable[[Callable[_P, _RetrunType]], stages.JaCeWrapped[_P, _RetrunType]]: ...


@overload
def jit(
fun: Callable,
fun: Callable[_P, _RetrunType],
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped: ...
**kwargs: Unpack[JitOptions],
) -> stages.JaCeWrapped[_P, _RetrunType]: ...


def jit(
fun: Callable | None = None,
fun: Callable[_P, _RetrunType] | None = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]:
**kwargs: Unpack[JitOptions],
) -> (
Callable[[Callable[_P, _RetrunType]], stages.JaCeWrapped[_P, _RetrunType]]
| stages.JaCeWrapped[_P, _RetrunType]
):
"""
JaCe's replacement for `jax.jit` (just-in-time) wrapper.

It works the same way as `jax.jit` does, but instead of using XLA the
computation is lowered to DaCe. In addition it accepts some JaCe specific
arguments.
It works the same way as `jax.jit` does, but instead of lowering the
computation to XLA, it is lowered to DaCe.
The function supports a subset of the arguments that are accepted by `jax.jit()`,
currently none, and some JaCe specific ones.

Args:
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments.

Notes:
After constructions any change to `primitive_translators` has no effect.
Note:
This function is the only valid way to obtain a JaCe computation.
"""
if kwargs:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
)

def wrapper(f: Callable) -> stages.JaCeWrapped:
# TODO(egparedes): Improve typing.
def wrapper(f: Callable[_P, _RetrunType]) -> stages.JaCeWrapped[_P, _RetrunType]:
if any(
param.default is not param.empty for param in inspect.signature(f).parameters.values()
):
raise NotImplementedError("Default values are not yet supported.")

jace_wrapper = stages.JaCeWrapped(
fun=f,
primitive_translators=(
Expand Down
26 changes: 17 additions & 9 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
#
# SPDX-License-Identifier: BSD-3-Clause

"""
JaCe specific optimizations.

Currently just a dummy exists for the sake of providing a callable function.
"""
"""JaCe specific optimizations."""

from __future__ import annotations

Expand All @@ -19,7 +15,7 @@


if TYPE_CHECKING:
from jace import translator
import jace


class CompilerOptions(TypedDict, total=False):
Expand All @@ -35,15 +31,24 @@ class CompilerOptions(TypedDict, total=False):

auto_optimize: bool
simplify: bool
persistent: bool


# TODO(phimuell): Add a context manager to modify the default.
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True}
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": True,
"simplify": True,
"persistent": True,
}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False}
NO_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": False,
"simplify": False,
"persistent": False,
}


def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs
def jace_optimize(tsdfg: jace.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs
"""
Performs optimization of the translated SDFG _in place_.

Expand All @@ -55,6 +60,9 @@ def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[Compil
tsdfg: The translated SDFG that should be optimized.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
persistent: Make the memory allocation persistent, i.e. allocate the
transients only once at the beginning and then reuse the memory across
the lifetime of the SDFG.
"""
# Currently this function exists primarily for the same of existing.

Expand Down
Loading