Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Finalizing the initial infrastructure of JaCe #18

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion CODING_GUIDELINES.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ We deviate from the [Google Python Style Guide][google-style-guide] only in the

- According to subsection [_3.19.12 Imports For Typing_](https://google.github.io/styleguide/pyguide.html#31912-imports-for-typing), symbols from `typing` and `collections.abc` modules used in type annotations _"can be imported directly to keep common annotations concise and match standard typing practices"_. Following the same spirit, we allow symbols to be imported directly from third-party or internal modules when they only contain a collection of frequently used typying definitions.

### Aliasing of Modules

According to subsection [2.2](https://google.github.io/styleguide/pyguide.html#22-imports) in certain cases it is allowed to introduce an alias for an import.
Inside JaCe the following convention is applied:

- If the module has a standard abbreviation use that, e.g. `import numpy as np`.
- For a JaCe module use:
- If the module name is only a single word use it directly, e.g. `from jace import translator`.
- If the module name consists of multiple words use the last word prefixed with the first letters of the others, e.g. `from jace.translator import post_translator as ptranslator` or `from jace import translated_jaxpr_sdfg as tjsdfg`.
- In case of a clash use your best judgment.
- For an external module use the rule above, but prefix the name with the main package's name, e.g. `from dace.codegen import compiled_sdfg as dace_csdfg`.

### Python usage recommendations

- `pass` vs `...` (`Ellipsis`)
Expand Down Expand Up @@ -104,7 +116,7 @@ We generate the API documentation automatically from the docstrings using [Sphin
Sphinx supports the [reStructuredText][sphinx-rest] (reST) markup language for defining additional formatting options in the generated documentation, however section [_3.8 Comments and Docstrings_](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) of the Google Python Style Guide does not specify how to use markups in docstrings. As a result, we decided to forbid reST markup in docstrings, except for the following cases:

- Cross-referencing other objects using Sphinx text roles for the [Python domain](https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#the-python-domain) (as explained [here](https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#python-roles)).
- Very basic formatting markup to improve _readability_ of the generated documentation without obscuring the source docstring (e.g. ``` ``literal`` ``` strings, bulleted lists).
- Very basic formatting markup to improve _readability_ of the generated documentation without obscuring the source docstring (e.g. `` `literal` `` strings, bulleted lists).

We highly encourage the [doctest] format for code examples in docstrings. In fact, doctest runs code examples and makes sure they are in sync with the codebase.

Expand Down
2 changes: 1 addition & 1 deletion src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

import jace.translator.primitive_translators as _ # noqa: F401 # Populate the internal registry.
import jace.translator.primitive_translators as _ # noqa: F401 [unused-import] # Needed to populate the internal translator registry.

from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
from .api import grad, jacfwd, jacrev, jit
Expand Down
50 changes: 31 additions & 19 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,68 +10,80 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any, Literal, overload
from collections.abc import Callable, Mapping
from typing import Literal, ParamSpec, TypedDict, TypeVar, overload

from jax import grad, jacfwd, jacrev
from typing_extensions import Unpack

from jace import stages, translator


if TYPE_CHECKING:
from collections.abc import Callable, Mapping
__all__ = ["JITOptions", "grad", "jacfwd", "jacrev", "jit"]

_P = ParamSpec("_P")
_R = TypeVar("_R")

__all__ = ["grad", "jacfwd", "jacrev", "jit"]

class JITOptions(TypedDict, total=False):
"""
All known options to `jace.jit` that influence tracing.

Note:
Currently there are no known options, but essentially it is a subset of some
of the options that are supported by `jax.jit` together with some additional
JaCe specific ones.
"""


@overload
def jit(
fun: Literal[None] = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> Callable[[Callable], stages.JaCeWrapped]: ...
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]]: ...


@overload
def jit(
fun: Callable,
fun: Callable[_P, _R],
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped: ...
**kwargs: Unpack[JITOptions],
) -> stages.JaCeWrapped[_P, _R]: ...


def jit(
fun: Callable | None = None,
fun: Callable[_P, _R] | None = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]:
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]] | stages.JaCeWrapped[_P, _R]:
"""
JaCe's replacement for `jax.jit` (just-in-time) wrapper.

It works the same way as `jax.jit` does, but instead of using XLA the
computation is lowered to DaCe. In addition it accepts some JaCe specific
arguments.
It works the same way as `jax.jit` does, but instead of lowering the
computation to XLA, it is lowered to DaCe.
The function supports a subset of the arguments that are accepted by `jax.jit()`,
currently none, and some JaCe specific ones.

Args:
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments.

Notes:
After constructions any change to `primitive_translators` has no effect.
Note:
This function is the only valid way to obtain a JaCe computation.
"""
if kwargs:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
)

def wrapper(f: Callable) -> stages.JaCeWrapped:
# TODO(egparedes): Improve typing.
def wrapper(f: Callable[_P, _R]) -> stages.JaCeWrapped[_P, _R]:
jace_wrapper = stages.JaCeWrapped(
fun=f,
primitive_translators=(
Expand Down
33 changes: 23 additions & 10 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"""
JaCe specific optimizations.

Currently just a dummy exists for the sake of providing a callable function.
Todo:
Organize this module once it is a package.
"""

from __future__ import annotations
Expand All @@ -19,7 +20,20 @@


if TYPE_CHECKING:
from jace import translator
from jace import translated_jaxpr_sdfg as tjsdfg


DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": True,
"simplify": True,
"persistent_transients": True,
}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": False,
"simplify": False,
"persistent_transients": False,
}


class CompilerOptions(TypedDict, total=False):
Expand All @@ -35,15 +49,10 @@ class CompilerOptions(TypedDict, total=False):

auto_optimize: bool
simplify: bool
persistent_transients: bool


# TODO(phimuell): Add a context manager to modify the default.
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False}


def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs
def jace_optimize(tsdfg: tjsdfg.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 [undocumented-param]
"""
Performs optimization of the translated SDFG _in place_.

Expand All @@ -55,8 +64,12 @@ def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[Compil
tsdfg: The translated SDFG that should be optimized.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
persistent_transients: Set the allocation lifetime of (non register) transients
in the SDFG to `AllocationLifetime.Persistent`, i.e. keep them allocated
between different invocations.
"""
# Currently this function exists primarily for the same of existing.
# TODO(phimuell): Implement the functionality.
# Currently this function exists primarily for the sake of existing.

simplify = kwargs.get("simplify", False)
auto_optimize = kwargs.get("auto_optimize", False)
Expand Down
Loading