Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Initial implementation #3

Merged

Conversation

philip-paul-mueller
Copy link
Contributor

@philip-paul-mueller philip-paul-mueller commented Apr 21, 2024

Initial JaCe code.

This PR introduces a series of very basic functionalities, that will be extended in subsequent PRs.
Most importantly it introduces the basic infrastructure that allows to translate a Jaxpr object into an SDFG.
Furthermore, it also adds the jace.jit decorator, which can be used as a replacement for jax.jit.
However, a function that was decorated with jace.jit remains fully composable with Jax transformation, such as jax.grad or jax.jacfwd.
But, the functionality is still very basic and jace.jit, essentially does not accepts any arguments and only works on CPU.
Nevertheless, there is a cache for caching tracing, translation and compilation of wrapped functions.

Although, this PR introduces the components for the translation, the actual primitive translators are not yet implemented (this commit adds the ALUTranslator, but this one was back ported from the prototype for simple tests).

As a last point the tests, located in tests are only a first version, that were not included in the review of this PR.
They will be reviewed at a later stage

For more information see the ROADMAP.md file.

…rom extern, without the need of having a Jax variable.
…ables.

This is basically a simple substitute class that can be used instead of a full jax variable.
It is mostly usefull for creating arrays during testing.
However, it should also be used to cretae variables for which we do not have anything.

This essentially replaces the flags that allowed to explicitly specify shape and dtype in the `add_array()` function.
But we should split that thing and make it nicer.
The output memlet for the scalar case was not handled correctly.
…one go.

This function is highly internal and should not be used.
The function is now able to handle more recent Jax functions, however, it now produces very starnge names.
I think we should have some kind of context, similar to the one used by Jax itself.
…n replacement.

I think that we will need it in the future, for example it is, in my view, currently the best place to put `jace.jit`.
Another idea that could be worthwhile to consider is to mimick the `jace` package after `jax`.
However, it does not cache the compiled code yet.
And it is not tested yet.
There is still the problem of scalars as return values.
However, even DaCe turns them into arrays.
@codecov-commenter
Copy link

Welcome to Codecov 🎉

Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests.

Thanks for integrating Codecov - We've got you covered ☂️

Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good for now. Let's merge it and keep working in adding features and improving the code.

@philip-paul-mueller philip-paul-mueller merged commit 1ebbcf3 into GridTools:main Jun 18, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants