Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added Auto Opt, GPU and jax.Array #26

Merged
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
]
dependencies = [
"dace>=0.16",
"jax[cpu]>=0.4.24",
"jax[cpu]>=0.4.33",
"numpy>=1.26.0",
]
description = "JAX jit using DaCe (Data Centric Parallel Programming)"
Expand Down Expand Up @@ -103,6 +103,7 @@ module = [
"dace.*",
"jax.*",
"jaxlib.*",
"cupy.",
]

# -- pytest --
Expand Down
6 changes: 6 additions & 0 deletions src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@

from __future__ import annotations

import jax

import jace.translator.primitive_translators as _ # noqa: F401 [unused-import] # Needed to populate the internal translator registry.

from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
from .api import grad, jacfwd, jacrev, jit


if jax.version._version_as_tuple(jax.__version__) < (0, 4, 33):
raise ImportError(f"Require at least JAX version '0.4.33', but found '{jax.__version__}'.")


__all__ = [
"__author__",
"__copyright__",
Expand Down
34 changes: 19 additions & 15 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import functools
from collections.abc import Callable, Mapping
from typing import Literal, ParamSpec, TypedDict, TypeVar, overload
from typing import Any, Literal, ParamSpec, TypedDict, overload

from jax import grad, jacfwd, jacrev
from typing_extensions import Unpack
Expand All @@ -22,44 +22,47 @@
__all__ = ["JITOptions", "grad", "jacfwd", "jacrev", "jit"]

_P = ParamSpec("_P")
_R = TypeVar("_R")


class JITOptions(TypedDict, total=False):
"""
All known options to `jace.jit` that influence tracing.

Note:
Currently there are no known options, but essentially it is a subset of some
of the options that are supported by `jax.jit` together with some additional
JaCe specific ones.
Not all arguments that are supported by `jax-jit()` are also supported by
`jace.jit`. Furthermore, some additional ones might be supported.
The following arguments are supported:
- `backend`: For which platform DaCe should generate code for. It is a string,
where the following values are supported: `'cpu'` or `'gpu'`.
DaCe's `DeviceType` enum or FPGA are not supported.
"""

backend: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
backend: str
backend: Literal['cpu', 'gpu']

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think that enforcing this through the annotations and MyPy is the right way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not? What's the use case for the Literal type annotation then? Using Literal here also helps in documenting the supported values in a more prominent way.

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First a practical reason, if you read the backend string say from a file then you will get a str, so you will have to cast it.
It also does not really helps documenting since in general it does not explain what a particular value mean or does, you still have to read the doc anyway.



@overload
def jit(
fun: Literal[None] = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]]: ...
) -> Callable[[Callable[_P, Any]], stages.JaCeWrapped[_P]]: ...


@overload
def jit(
fun: Callable[_P, _R],
fun: Callable[_P, Any],
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Unpack[JITOptions],
) -> stages.JaCeWrapped[_P, _R]: ...
) -> stages.JaCeWrapped[_P]: ...


def jit(
fun: Callable[_P, _R] | None = None,
fun: Callable[_P, Any] | None = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]] | stages.JaCeWrapped[_P, _R]:
) -> Callable[[Callable[_P, Any]], stages.JaCeWrapped[_P]] | stages.JaCeWrapped[_P]:
"""
JaCe's replacement for `jax.jit` (just-in-time) wrapper.

Expand All @@ -72,18 +75,19 @@ def jit(
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments.
kwargs: Jit arguments, see `JITOptions` for more.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
kwargs: Jit arguments, see `JITOptions` for more.
kwargs: jit arguments, see `JITOptions` for more.

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be consistent and write JIT capital or Jit as it is the abbreviation of "Just in time" and we are at the beginning of a sentence.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather use JIT then but it's up to you...


Note:
This function is the only valid way to obtain a JaCe computation.
"""
if kwargs:
not_supported_jit_keys = kwargs.keys() - {"backend"}
if not_supported_jit_keys:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(not_supported_jit_keys)}."
)

def wrapper(f: Callable[_P, _R]) -> stages.JaCeWrapped[_P, _R]:
def wrapper(f: Callable[_P, Any]) -> stages.JaCeWrapped[_P]:
jace_wrapper = stages.JaCeWrapped(
fun=f,
primitive_translators=(
Expand Down
56 changes: 46 additions & 10 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from typing import TYPE_CHECKING, Final, TypedDict

import dace
from dace.transformation.auto import auto_optimize as dace_autoopt
from typing_extensions import Unpack


Expand All @@ -24,15 +26,19 @@


DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": True,
"auto_optimize": False,
"simplify": True,
"persistent_transients": True,
"validate": True,
"validate_all": False,
}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": False,
"simplify": False,
"persistent_transients": False,
"validate": True,
"validate_all": False,
}


Expand All @@ -50,9 +56,15 @@ class CompilerOptions(TypedDict, total=False):
auto_optimize: bool
simplify: bool
persistent_transients: bool
validate: bool
validate_all: bool


def jace_optimize(tsdfg: tjsdfg.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 [undocumented-param]
def jace_optimize( # noqa: D417 [undocumented-param] # `kwargs` is not documented.
tsdfg: tjsdfg.TranslatedJaxprSDFG,
device: dace.DeviceType,
**kwargs: Unpack[CompilerOptions],
) -> None: # [undocumented-param]
"""
Performs optimization of the translated SDFG _in place_.

Expand All @@ -62,22 +74,46 @@ def jace_optimize(tsdfg: tjsdfg.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOp

Args:
tsdfg: The translated SDFG that should be optimized.
device: The device on which the SDFG will run on.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
auto_optimize: Run the auto optimization pipeline.
persistent_transients: Set the allocation lifetime of (non register) transients
in the SDFG to `AllocationLifetime.Persistent`, i.e. keep them allocated
between different invocations.
"""
# TODO(phimuell): Implement the functionality.
# Currently this function exists primarily for the sake of existing.
validate: Perform validation at the end.
validate_all: Perform extensive validation.

Note:
Currently DaCe's auto optimization pipeline is used when auto optimize is
enabled. However, it might change in the future. Because DaCe's auto
optimizer is considered unstable it must be explicitly enabled.
"""
assert device in {dace.DeviceType.CPU, dace.DeviceType.GPU}
simplify = kwargs.get("simplify", False)
auto_optimize = kwargs.get("auto_optimize", False)
validate = kwargs.get("validate", DEFAULT_OPTIMIZATIONS["validate"])
validate_all = kwargs.get("validate_all", DEFAULT_OPTIMIZATIONS["validate_all"])

if simplify:
tsdfg.sdfg.simplify()
tsdfg.sdfg.simplify(
validate=validate,
validate_all=validate_all,
)

if device == dace.DeviceType.GPU:
tsdfg.sdfg.apply_gpu_transformations(
validate=validate,
validate_all=validate_all,
simplify=True,
)

if auto_optimize:
pass

tsdfg.validate()
dace_autoopt.auto_optimize(
sdfg=tsdfg.sdfg,
device=device,
validate=validate,
validate_all=validate_all,
)

if validate or validate_all:
tsdfg.validate()
Loading