Skip to content

Commit

Permalink
0.3.0: Performance Improvement (#2)
Browse files Browse the repository at this point in the history
* refactor(jit): make all jitted functions "inline=True"

This should be more beneficial, as followed by the discussions in Jax repository, see
jax-ml/jax#6584 jax-ml/jax#6681 jax-ml/jax#9298 jax-ml/jax#9342

* perf(pipeline): try to render 4/2/1 rows per batch using vmap to reduce fori_loop iterations

* feat(_meta_utils): simple way to add multiple trace annotations together for functions

add `@ad_tracing_name` to most functions to assist profiling
also bump to Python 3.10

BREAKING CHANGE: Now requires Python 3.10

* perf(pipeline): big refactor to not updating per rows, but renders all rows then concat and merge

* perf(pipeline): using scan + unroll (equiv map + unroll)

This is very similar to map + vmap (minibatch processing) as the inner
loop is too complex

* build(pyproject): try to relax jax{lib}'s verion constraint to ">=0.3.25,<5.0.0"

* test(pre-gen-brax): example inputs for profiling

* perf: try to eliminate all `lax.cond`

under `vmap`, `lax.cond` are lowered to `select_n` in HLO which leads to execution in both branches,
thus fails to 1) save computation when possible; 2) prevent unexpected values to be
produced/unexpected branches to be executed (defensive), thus let the non-dummy branch to be
executed anyway and only rule-out garbage value at the final stage all together to try to improve
performance. See google/brax#8409 for more details about unconditional executation of cond under
vmap

* fix(pipeline): gl_FrontFacing: fix its determination in pipeline

`True` if NOT back-facing

* perf: added extra stage in pipeline, aiming to interpolate and shade only one fragment per pixel

* docs(changelog): expose option `loop_unroll`; dependency version change

Bump minimum Python version from 3.9 to 3.10;
lower minimum jax & jaxlib to 0.3.25.

* build(pyproject): bump to 0.3.0
  • Loading branch information
JoeyTeng authored Jun 12, 2023
1 parent 69cd51a commit 788788a
Show file tree
Hide file tree
Showing 21 changed files with 582 additions and 241 deletions.
8 changes: 8 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@

1. Refactor `Scene.set_object_*` methods to be a simple wrapper of `self._replace` and `ModelObject.replace_with_*`, to expose APIs of `ModelObject`s and allows manipulation and rendering without `Scene`.
2. Expose `create_capsule` and `create_cube` APIs.

## 0.3.0

1. Fix `gl_FrontFacing` computation in pipeline so it is consistent to comment: `True` if not backfacing (i.e. frontfacing & side facing).
2. Add an extra stage `Shader.primitive_chooser` to choose which primitive to be rendered for each fragment. The default implementation is provided, which assumes that the depth is just the interpolated `z` value in the eye space. It just picks the values of the single primitive that is closest to the camera and is not discarded in the previous pipeline.
3. Expose `loop_unroll` static option to allow unrolling several operations (row rendering) within a single iteration of the outmost loop (iterating along first axis of the canvas). This may be useful in some cases for performance improvement, but careful benchmarking is needed to determine the optimal value. The default value is `1` (no unrolling) as it is the most general case in larger canvases (benchmarked on `960x540` using [GPU T4 in Colab](https://colab.research.google.com/drive/1xhkYNz5WjvUCjQWpp72CLf9SIy3i5PnN)).
4. Bump the minimum Python version to Python 3.10
5. Lower the minimum jax & jaxlib version to 0.3.25.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxrenderer"
version = "0.2.1"
version = "0.3.0"
description = "Jax implementation of rasterizer renderer."
authors = ["Joey Teng <[email protected]>"]
license = "Apache-2.0"
Expand Down Expand Up @@ -32,10 +32,10 @@ include = [


[tool.poetry.dependencies]
python = "^3.9"
jax = "^0.4.4"
python = "^3.10"
jax = ">=0.3.25,<5.0.0"
numpy = "^1.22.0"
jaxlib = {version = "^0.4.4", source = "jax"}
jaxlib = {version = ">=0.3.25,<5.0.0", source = "jax"}
jaxtyping = "^0.2.19"
importlib-metadata = "^6.6.0"

Expand Down
12 changes: 10 additions & 2 deletions renderer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,18 @@ The fragments are generated for each position in the zbuffer, and if the screen

The processing is done by iterating along the first axis of the buffer, and `vmap` along the second axis, and `vmap` along the primitives. Thus, all fragments at the same position are generated together, then mixed, then written to the buffer.

Interpolation of all attributes for each fragment is defined by `Shader.interpolate`. The default implementation is provided, which simply linearly interpolates the attributes in the clip space according to the barycentric coordinates of the fragment. This behaviour is wrapped as a convenient function `interpolate` and mode `Interpolation.SMOOTH`. The interpolated values are then passed to the fragment shader.

Currently no anti-aliasing strategy is supported.

#### Optional Early Depth Test

This is an additional stage in this pipeline which may be analogical to the early depth test in OpenGL. It is implemented in `Shader.primitive_chooser`. The default implementation is provided, which assumes that the depth is just the interpolated `z` value in the eye space. It just picks the values of the single primitive that is closest to the camera and is not discarded in the previous pipeline.

Custom implementations can overload to change this behaviour to achieve special effects like occlusion, transparency, etc. Note that the number of returned primitives must be static, i.e., same for all fragments, as required by `jax.jit`.

#### Interpolation of Attributes

Interpolation of all attributes for each fragment is defined by `Shader.interpolate`. The default implementation is provided, which simply linearly interpolates the attributes in the clip space according to the barycentric coordinates of the fragment. This behaviour is wrapped as a convenient function `interpolate` and mode `Interpolation.SMOOTH`. The interpolated values are then passed to the fragment shader. Currently only `Interpolation.SMOOTH` and `Interpolation.FLAT` are supported.

Reference:

- For interpolation modes, see [Interpolation Qualifiers](https://www.khronos.org/opengl/wiki/Type_Qualifier_(GLSL)#Interpolation_qualifiers)
Expand Down
25 changes: 25 additions & 0 deletions renderer/_meta_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import functools
import inspect
from typing import Callable, ParamSpec, TypeVar

import jax

ArgT = ParamSpec("ArgT")
RetT = TypeVar("RetT")


def add_tracing_name(func: Callable[ArgT, RetT]) -> Callable[ArgT, RetT]:
"""Add tracing name to function."""

members: dict[str, str]
members = dict(inspect.getmembers(func, lambda v: isinstance(v, str)))
annotation: str = (f"{members.get('__module__', '')}"
f":{members.get('__qualname__', '')}")

@functools.wraps(func)
def wrapper(*args: ArgT.args, **kwargs: ArgT.kwargs) -> RetT:
with jax.named_scope(annotation):
with jax.profiler.TraceAnnotation(annotation):
return func(*args, **kwargs)

return wrapper
92 changes: 62 additions & 30 deletions renderer/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax.numpy as jnp
from jaxtyping import Array, Float, Integer, Num, jaxtyped

from ._meta_utils import add_tracing_name
from .types import Triangle2Df, Vec2f, Vec3f, Vec4f

# Transform matrix that takes a batch of homogeneous 3D vertices and transform
Expand All @@ -28,7 +29,8 @@


@jaxtyped
@partial(jax.jit, donate_argnums=(0, ))
@partial(jax.jit, donate_argnums=(0, ), inline=True)
@add_tracing_name
def normalise(vector: Float[Array, "dim"]) -> Float[Array, "dim"]:
"""normalise vector in-place."""
return vector / jnp.linalg.norm(vector)
Expand All @@ -48,7 +50,8 @@ class Interpolation(enum.Enum):
"""Perspective correction: linear interpolation in clip space"""

@jaxtyped
@partial(jax.jit, static_argnames=("self", ))
@partial(jax.jit, static_argnames=("self", ), inline=True)
@add_tracing_name
def __call__(
self,
values: Num[Array, "3 *valueDimensions"],
Expand Down Expand Up @@ -92,7 +95,8 @@ def __call__(


@jaxtyped
@partial(jax.jit, static_argnames=("mode", ))
@partial(jax.jit, static_argnames=("mode", ), inline=True)
@add_tracing_name
def interpolate(
values: Num[Array, "3 *valueDimensions"],
barycentric_screen: Vec3f,
Expand All @@ -115,7 +119,8 @@ def interpolate(


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def to_homogeneous(
coordinates: Float[Array, "*batch dim"],
value: Float[Array, "*batch"] = jnp.array(1.),
Expand All @@ -138,7 +143,8 @@ def to_homogeneous(


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def normalise_homogeneous(
coordinates: Float[Array, "*batch dim"], ) -> Float[Array, "*batch dim"]:
"""Transform the homogenous coordinates to make the scale factor equals to
Expand All @@ -152,7 +158,8 @@ def normalise_homogeneous(


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def to_cartesian(
coordinates: Float[Array, "*batch dim"], ) -> Float[Array, "*batch dim-1"]:
"""Transform the homogenous coordinates to cartesian coordinates by divide
Expand Down Expand Up @@ -192,7 +199,8 @@ class Camera(NamedTuple):

@classmethod
@jaxtyped
@partial(jax.jit, static_argnames=("cls", ))
@partial(jax.jit, static_argnames=("cls", ), inline=True)
@add_tracing_name
def create(
cls,
view: View,
Expand Down Expand Up @@ -240,7 +248,8 @@ def create(

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def apply(
points: Num[Array, "*N 4"],
matrix: Num[Array, "4 4"],
Expand Down Expand Up @@ -275,7 +284,8 @@ def apply(

@classmethod
@jaxtyped
@partial(jax.jit, static_argnames=("cls", ))
@partial(jax.jit, static_argnames=("cls", ), inline=True)
@add_tracing_name
def apply_pos(
cls,
points: Num[Array, "*N 3"],
Expand Down Expand Up @@ -305,7 +315,8 @@ def apply_pos(

@classmethod
@jaxtyped
@partial(jax.jit, static_argnames=("cls", ))
@partial(jax.jit, static_argnames=("cls", ), inline=True)
@add_tracing_name
def apply_vec(
cls,
vectors: Num[Array, "*N 3"],
Expand Down Expand Up @@ -343,7 +354,8 @@ def apply_vec(
return transformed_normalised

@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def to_screen(
self,
points: Num[Array, "*N 4"],
Expand All @@ -368,7 +380,8 @@ def to_screen(
return normalised

@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def to_clip(
self,
points: Num[Array, "*N 4"],
Expand All @@ -390,7 +403,8 @@ def to_clip(
return clip_space

@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def to_screen_inv(
self,
screen: Float[Array, "*N 4"],
Expand Down Expand Up @@ -442,7 +456,8 @@ def to_screen_inv(

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def inv_scale_translation_matrix(
scale_translation_mat: Float[Array, "4 4"]) -> Float[Array, "4 4"]:
"""Compute the inverse matrix of a (4, 4) matrix representing a scale and translation, in a form of:
Expand Down Expand Up @@ -483,7 +498,8 @@ def inv_scale_translation_matrix(

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def view_matrix(
eye: Vec3f,
centre: Vec3f,
Expand Down Expand Up @@ -524,7 +540,8 @@ def view_matrix(

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def view_matrix_inv(
eye: Vec3f,
centre: Vec3f,
Expand Down Expand Up @@ -577,7 +594,8 @@ def view_matrix_inv(

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def perspective_projection_matrix(
fovy: jnp.floating[Any],
aspect: jnp.floating[Any],
Expand Down Expand Up @@ -621,7 +639,8 @@ def perspective_projection_matrix(

@classmethod
@jaxtyped
@partial(jax.jit, static_argnames=("cls", ))
@partial(jax.jit, static_argnames=("cls", ), inline=True)
@add_tracing_name
def perspective_projection_matrix_inv(cls, mat: Projection) -> Projection:
"""Create the inverse of a perspective projection matrix as defined in
`perspective_projection_matrix`.
Expand Down Expand Up @@ -650,7 +669,8 @@ def perspective_projection_matrix_inv(cls, mat: Projection) -> Projection:

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def orthographic_projection_matrix(
left: jnp.floating[Any],
right: jnp.floating[Any],
Expand Down Expand Up @@ -694,6 +714,7 @@ def orthographic_projection_matrix(
@classmethod
@jaxtyped
@partial(jax.jit, static_argnames=("cls", ))
@add_tracing_name
def orthographic_projection_matrix_inv(cls, mat: Projection) -> Projection:
"""Create the inverse of a orthographic projection matrix as defined in
`orthographic_projection_matrix`. Since orthographic projection
Expand All @@ -707,7 +728,8 @@ def orthographic_projection_matrix_inv(cls, mat: Projection) -> Projection:

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def perspective_projection_matrix_tinyrenderer(
eye: Vec3f,
centre: Vec3f,
Expand All @@ -734,7 +756,8 @@ def perspective_projection_matrix_tinyrenderer(

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def viewport_matrix(
lowerbound: Num[Array, "2"],
dimension: Integer[Array, "2"],
Expand Down Expand Up @@ -766,7 +789,8 @@ def viewport_matrix(

@classmethod
@jaxtyped
@partial(jax.jit, static_argnames=("cls", ))
@partial(jax.jit, static_argnames=("cls", ), inline=True)
@add_tracing_name
def viewport_matrix_inv(cls, viewport: Viewport) -> Viewport:
"""Create the inverse of a viewport matrix as defined in `viewport_matrix`.
Expand All @@ -782,7 +806,8 @@ def viewport_matrix_inv(cls, viewport: Viewport) -> Viewport:

@staticmethod
@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def world_to_screen_matrix(width: int, height: int) -> World2Screen:
"""Generate the projection matrix to convert model coordinates to
screen/canvas coordinates.
Expand All @@ -804,7 +829,8 @@ def world_to_screen_matrix(width: int, height: int) -> World2Screen:


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def compute_normal(triangle_verts: Float[Array, "3 3"]) -> Float[Array, "3"]:
normal: Float[Array, "3"] = jnp.cross(
triangle_verts[2, :] - triangle_verts[0, :],
Expand All @@ -817,13 +843,15 @@ def compute_normal(triangle_verts: Float[Array, "3 3"]) -> Float[Array, "3"]:


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def compute_normals(batch_verts: Float[Array, "b 3 3"]) -> Float[Array, "b 3"]:
return jax.vmap(compute_normal)(batch_verts)


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def quaternion(
rotation_axis: Union[Vec3f, tuple[float, float, float]],
rotation_angle: Union[Float[Array, ""], float],
Expand All @@ -850,7 +878,8 @@ def quaternion(


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def quaternion_mul(quatA: Vec4f, quatB: Vec4f) -> Vec4f:
"""Multiply two quaternion rotations, as to composite them.
Expand All @@ -877,7 +906,8 @@ def quaternion_mul(quatA: Vec4f, quatB: Vec4f) -> Vec4f:


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def rotation_matrix(
rotation_axis: Union[Vec3f, tuple[float, float, float]],
rotation_angle: Union[Float[Array, ""], float],
Expand Down Expand Up @@ -908,7 +938,8 @@ def rotation_matrix(


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def transform_matrix_from_rotation(rotation: Vec4f) -> Float[Array, "3 3"]:
"""Generate a transform matrix from a quaternion rotation.
Expand Down Expand Up @@ -936,7 +967,8 @@ def transform_matrix_from_rotation(rotation: Vec4f) -> Float[Array, "3 3"]:


@jaxtyped
@jax.jit
@partial(jax.jit, inline=True)
@add_tracing_name
def barycentric(pts: Triangle2Df, p: Vec2f) -> Vec3f:
"""Compute the barycentric coordinate of `p`.
Returns u[-1] < 0 if `p` is outside of the triangle.
Expand Down
Loading

0 comments on commit 788788a

Please sign in to comment.