-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* refactor(jit): make all jitted functions "inline=True" This should be more beneficial, as followed by the discussions in Jax repository, see jax-ml/jax#6584 jax-ml/jax#6681 jax-ml/jax#9298 jax-ml/jax#9342 * perf(pipeline): try to render 4/2/1 rows per batch using vmap to reduce fori_loop iterations * feat(_meta_utils): simple way to add multiple trace annotations together for functions add `@ad_tracing_name` to most functions to assist profiling also bump to Python 3.10 BREAKING CHANGE: Now requires Python 3.10 * perf(pipeline): big refactor to not updating per rows, but renders all rows then concat and merge * perf(pipeline): using scan + unroll (equiv map + unroll) This is very similar to map + vmap (minibatch processing) as the inner loop is too complex * build(pyproject): try to relax jax{lib}'s verion constraint to ">=0.3.25,<5.0.0" * test(pre-gen-brax): example inputs for profiling * perf: try to eliminate all `lax.cond` under `vmap`, `lax.cond` are lowered to `select_n` in HLO which leads to execution in both branches, thus fails to 1) save computation when possible; 2) prevent unexpected values to be produced/unexpected branches to be executed (defensive), thus let the non-dummy branch to be executed anyway and only rule-out garbage value at the final stage all together to try to improve performance. See google/brax#8409 for more details about unconditional executation of cond under vmap * fix(pipeline): gl_FrontFacing: fix its determination in pipeline `True` if NOT back-facing * perf: added extra stage in pipeline, aiming to interpolate and shade only one fragment per pixel * docs(changelog): expose option `loop_unroll`; dependency version change Bump minimum Python version from 3.9 to 3.10; lower minimum jax & jaxlib to 0.3.25. * build(pyproject): bump to 0.3.0
- Loading branch information
Showing
21 changed files
with
582 additions
and
241 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "jaxrenderer" | ||
version = "0.2.1" | ||
version = "0.3.0" | ||
description = "Jax implementation of rasterizer renderer." | ||
authors = ["Joey Teng <[email protected]>"] | ||
license = "Apache-2.0" | ||
|
@@ -32,10 +32,10 @@ include = [ | |
|
||
|
||
[tool.poetry.dependencies] | ||
python = "^3.9" | ||
jax = "^0.4.4" | ||
python = "^3.10" | ||
jax = ">=0.3.25,<5.0.0" | ||
numpy = "^1.22.0" | ||
jaxlib = {version = "^0.4.4", source = "jax"} | ||
jaxlib = {version = ">=0.3.25,<5.0.0", source = "jax"} | ||
jaxtyping = "^0.2.19" | ||
importlib-metadata = "^6.6.0" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import functools | ||
import inspect | ||
from typing import Callable, ParamSpec, TypeVar | ||
|
||
import jax | ||
|
||
ArgT = ParamSpec("ArgT") | ||
RetT = TypeVar("RetT") | ||
|
||
|
||
def add_tracing_name(func: Callable[ArgT, RetT]) -> Callable[ArgT, RetT]: | ||
"""Add tracing name to function.""" | ||
|
||
members: dict[str, str] | ||
members = dict(inspect.getmembers(func, lambda v: isinstance(v, str))) | ||
annotation: str = (f"{members.get('__module__', '')}" | ||
f":{members.get('__qualname__', '')}") | ||
|
||
@functools.wraps(func) | ||
def wrapper(*args: ArgT.args, **kwargs: ArgT.kwargs) -> RetT: | ||
with jax.named_scope(annotation): | ||
with jax.profiler.TraceAnnotation(annotation): | ||
return func(*args, **kwargs) | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.