Skip to content

Commit

Permalink
bump to 0.3.2: docs, add example, fix typing bug (#13)
Browse files Browse the repository at this point in the history
* chore(.git-blame-ignore-revs): ignore formatting commits: adapting to isort, black, pyright

* docs(readme): add badges, installation, usage, gallery

* style(github-action): update CI names for better display in README badges

* fix(geometry): type hinting bugs

* fix(model): bug in merge specular maps

* feat(examples): simple_cube: render a simple blue cube

* build(pyproject): bump minimum jax, jaxlib version to 0.4.0

* docs(changelog): for 0.3.2

* ci(pyproject): bump to 0.3.2
  • Loading branch information
JoeyTeng authored Jul 21, 2023
1 parent f5fc117 commit 55999f6
Show file tree
Hide file tree
Showing 13 changed files with 190 additions and 22 deletions.
6 changes: 6 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Migrate code style with isort
ab8273d47c979b2a53d1e53908843f2055178686
# Migrate code style to Black
a0e157decc3a729bdec493c9c432ad3a9a314175
# To pass Pyright type checks
b77a15efc65f7750cefd89925feb855efe881fa0
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Checks
name: lint & test

on:
push:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package
name: build

on:
release:
Expand Down
112 changes: 109 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,112 @@
# JAX Renderer
# JAX Renderer: Differentiable Rendering in Batch on Accelerators

[![PyPI Version](https://img.shields.io/pypi/v/jaxrenderer?logo=pypi)](https://pypi.org/project/jaxrenderer)
[![Python Versions](https://img.shields.io/pypi/pyversions/jaxrenderer?logo=python)](https://pypi.org/project/jaxrenderer)
[![License](https://img.shields.io/github/license/JoeyTeng/jaxrenderer)](https://github.com/JoeyTeng/jaxrenderer/blob/master/LICENSE)
[![Build & Publish](https://github.com/JoeyTeng/jaxrenderer/actions/workflows/pypi.yml/badge.svg)](https://github.com/JoeyTeng/jaxrenderer/actions/workflows/pypi.yml)
[![Lint & Test](https://github.com/JoeyTeng/jaxrenderer/actions/workflows/checks.yml/badge.svg)](https://github.com/JoeyTeng/jaxrenderer/actions/workflows/checks.yml)
[![Checked with pyright](https://microsoft.github.io/pyright/img/pyright_badge.svg)](https://microsoft.github.io/pyright/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Poetry](https://img.shields.io/endpoint?url=https://python-poetry.org/badge/v0.json&label=packaging)](https://python-poetry.org/)
[![Open in Colab](https://img.shields.io/badge/%7F-Open_demo_in_Colab-blue.svg?logo=googlecolab)](https://colab.research.google.com/github/JoeyTeng/jaxrenderer/blob/master/notebooks/Demo.ipynb)

JaxRenderer is a differentiable renderer implemented in [JAX](https://github.com/google/jax), which supports differentiable rendering and batch rendering on accelerators (e.g. GPU, TPU) using simple function transformations provided by JAX. It is designed to replace by [erwincoumans/tinyrenderer](https://github.com/erwincoumans/tinyrenderer) in [BRAX](https://github.com/google/brax) to support visualising simulation results through fast rendering on accelerators with no external dependencies (other than JAX).

You may find the [slides](https://github.com/JoeyTeng/jaxrenderer/blob/master/docs/final%20presentation%20slides.pdf) of my final year project presentation useful, where I gave a brief introduction to the renderer and the implementation details, including the design of the pipeline and comparing it with the OpenGL's.

## Installation

This project is distributed in [PyPI](https://pypi.org/project/jaxrenderer), and can be installed simply using `pip`:

```bash
pip install jaxrenderer
```

The minimum Python version is `3.8`, and the minimum JAX version is `0.4.0`. You may need to install `jaxlib` separately if you are using GPU or TPU; by default, the CPU version of jaxlib is installed. Please refer to [JAX's installation guide](https://github.com/google/jax#installation) for more details.

## Usage

> Please note that the package is imported with name `renderer` rather than the PyPI package name `jaxrenderer`. This may change in the future though.
Some example scripts are provided in [examples](examples) folder. You may find the [demo notebook](notebooks/Demo.ipynb) useful as well. In the demo, there is batch rendering and differentiable rendering examples.

The following is a simple example of rendering a cube with a texture map:

```python
import jax.numpy as jnp
import renderer


ImageWidth: int = 640
ImageHeight: int = 480

# Create a cube with texture map of pure blue
cube = renderer.create_cube(
half_extents=jnp.ones(3, dtype=jnp.single),
texture_scaling=jnp.ones(2, dtype=jnp.single),
# pure blue texture map
diffuse_map=jnp.zeros((2, 2, 3), dtype=jnp.single).at[..., 2].set(1),
specular_map=jnp.ones((2, 2), dtype=jnp.single) * 2.0,
)

# Render the cube
image = renderer.Renderer.get_camera_image(
objects=[renderer.ModelObject(model=cube)],
# Simply use defaults
camera=renderer.CameraParameters(
viewWidth=ImageWidth,
viewHeight=ImageHeight,
position=jnp.array([2.0, 4.0, 1.0], dtype=jnp.single),
),
# Simply use defaults
light=renderer.LightParameters(),
width=ImageWidth,
height=ImageHeight,
)
```

You may refer to [demo](https://colab.research.google.com/github/JoeyTeng/jaxrenderer/blob/master/notebooks/Demo.ipynb) for more complex examples, including differentiable rendering and batch rendering.

### Supported Shaders

#### Built-in Shaders

See [`renderer/shaders`](renderer/shaders) for more details.

| Shader Name | Description | Light Direction |
| ----------- | ----------- | --------------- |
| depth | Depth Shader, outputs only screen-space depth value | N.A. |
| gouraud | Gouraud Shading, interpolates vertex colour and outputs it as fragment colour | In model space |
| gouraud_texture | Gouraud Shading with Texture, interpolates vertex colour and samples texture map in fragment shader | In model space |
| phong | Phong Shading, interpolates vertex normal and computes light direction in fragment shader | In eye space, like "head light" |
| phong_darboux | Phong Shading with Normal Map in Tangent Space, interpolates vertex normal and computes light direction in fragment shader, and samples normal map in tangent space | In eye space, like "head light" |
| phong_reflection | Phong Shading with Phong Reflection Approximation, interpolates vertex normal and computes light direction in fragment shader, and samples texture map and specular map in fragment shader | In eye space |
| phong_reflection_shadow | Phong Shading with Phong Reflection Approximation and Shadow, interpolates vertex normal and computes light direction in fragment shader, samples texture map and specular map in fragment shader, and tests shadow in fragment shader | In eye space |

#### Custom Shaders

You may implement your own shaders by inheriting from `Shader` and implement the following methods:

- `vertex`: this is like vertex shader in OpenGL; it must be overridden.
- `primitive_chooser`: at this stage the visibility at each pixel level is tested, it works like pre-z test in OpenGL, makes the pipeline works like a deferred shading pipeline. Noted that you may override and return more than one primitive for each pixel to support transparency. The default implementation simply chooses the primitive with minimum z value (depth).
- `interpolate`: this controls how attributes associated with a fragment is interpolated from the vertices; the default implementation interpolates everything using perspective interpolation.
- `fragment`: this is like fragment shader in OpenGL; a default implementation is provided if you do not need to process any data in the fragment shader.
- `mix`: this is like blending stage in OpenGL; the default implementation simple uses the data from the fragment with minimum screen-space z value (depth).

## Gallery

![Batch Rendering Example, 30 Ants inference on A100 GPU with 90 iterations, rendered onto 84x84 canvas in 5.26s](docs/assets/84x84%2030ants%2090f%2030fps.gif)
> Batch Rendering Example, 30 Ants inference on A100 GPU with 90 iterations, rendered onto 84x84 canvas in 5.26s.
![Phong Reflection Model + Hard Shadow, 30 frames 1920x1080, 2492 triangles in 9.25s](docs/assets/head.gif)
> Phong Reflection Model + Hard Shadow, 30 frames 1920x1080, 2492 triangles in 9.25s.
![Differentiable Rendering Toy Example, deduce light colour parameters](docs/assets/differentiable%20rendering.gif)
> Differentiable Rendering Toy Example, deduce light colour parameters.
## Key Difference from [erwincoumans/tinyrenderer](https://github.com/erwincoumans/tinyrenderer)

- Native JAX implementation, supports `jit`, `vmap`, etc.
- Native JAX implementation, supports `jit`, `vmap`, `grad`, etc.
- Lighting is computed in main camera's eye space; while in PyTinyrenderer it is computed in world space.
- Texture specification is different: in PyTinyrenderer, the texture is specified in a flattened array, while in JAX Renderer, the texture is specified in a shape of (width, height, colour channels). A simple way to transform old specification to new specification is to use the convenient method `build_texture_from_PyTinyrenderer`.
- Rendering pipeline is different. PyTinyrenderer renders one object at a time, and share zbuffer and framebuffer across one pass. This renderer first merges all objects into one big mesh in world space, then process all vertices together, then interpolates and rasterise and render. For fragment shading, this is done by sweeping each row in a for loop, and batch compute all pixels together. For computing a pixel, all fragments for that pixels are batch compute together, then mixed. This is more memory efficient and allows `vmap` batching as far as possible.
Expand All @@ -18,6 +122,8 @@

- [ ] Support double-sided objects
- [ ] Profile and accelerate implementation
- [ ] Differentiable rendering
- [ ] Build a ray tracer as well
- [ ] Differentiable rendering with respect to mesh
- [x] Differentiable rendering with respect to light parameters
- [x] Differentiable rendering with respect to camera parameters _(not tested)_
- [ ] <s>Correctly implement a proper clipping algorithm</s>
8 changes: 8 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,11 @@

1. Lower minimum Python version to 3.8
2. Introducing `type_extensions` package and improved typing annotations.

## 0.3.2

1. Bump minimum `jax` and `jaxlib` version to 0.4.0 as `jaxtyping` does not support `jax` 0.3.25.
2. Bug fix: add `static_argnames` for utility function `transpose_for_display`.
3. Change to [isort](https://github.com/PyCQA/isort) + [black](https://github.com/psf/black) code style.
4. Migrate full codebase to be type-checked with [pyright](https://github.com/microsoft/pyright).
5. Add smoke tests, and use GitHub Action as CI to run them.
Binary file added docs/assets/84x84 30ants 90f 30fps.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/differentiable rendering.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/head.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
52 changes: 52 additions & 0 deletions examples/simple_cube.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import jax.numpy as jnp

import renderer

ImageWidth: int = 640
ImageHeight: int = 480

# Create a cube with texture map of pure blue
cube = renderer.create_cube(
half_extents=jnp.ones( # pyright: ignore[reportUnknownMemberType]
3, dtype=jnp.single
),
texture_scaling=jnp.ones( # pyright: ignore[reportUnknownMemberType]
2, dtype=jnp.single
),
diffuse_map=jnp.zeros( # pyright: ignore[reportUnknownMemberType]
(2, 2, 3), dtype=jnp.single
)
.at[..., 2]
.set(1),
specular_map=jnp.ones( # pyright: ignore[reportUnknownMemberType]
(2, 2), dtype=jnp.single
)
* 2.0,
)

# Render the cube
image = renderer.Renderer.get_camera_image(
objects=[renderer.ModelObject(model=cube)],
# Simply use defaults
camera=renderer.CameraParameters(
viewWidth=ImageWidth,
viewHeight=ImageHeight,
position=jnp.array( # pyright: ignore[reportUnknownMemberType]
[2.0, 4.0, 1.0], dtype=jnp.single
),
),
# Simply use defaults
light=renderer.LightParameters(),
width=ImageWidth,
height=ImageHeight,
)

import matplotlib.pyplot as plt

fig, ax = plt.subplots() # pyright: ignore

ax.imshow( # pyright: ignore[reportUnknownMemberType]
renderer.transpose_for_display(image)
)

plt.show() # pyright: ignore[reportUnknownMemberType]
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxrenderer"
version = "0.3.1"
version = "0.3.2"
description = "Jax implementation of rasterizer renderer."
authors = ["Joey Teng <[email protected]>"]
license = "Apache-2.0"
Expand Down Expand Up @@ -33,9 +33,9 @@ include = [

[tool.poetry.dependencies]
python = "^3.8"
jax = ">=0.3.25,<5.0.0"
jax = "^0.4.0"
numpy = "^1.22.0"
jaxlib = {version = ">=0.3.25,<5.0.0", source = "jax"}
jaxlib = {version = "^0.4.0", source = "jax"}
jaxtyping = [
{version = ">=0.2.13,<0.2.20", python = ">=3.8,<3.9"},
{version = "^0.2.19", python = "^3.9"}
Expand Down
12 changes: 6 additions & 6 deletions renderer/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
@jaxtyped
@partial(jit, donate_argnums=(0,), inline=True)
@add_tracing_name
def normalise(vector: Float[Array, "dim"]) -> Float[Array, "dim"]:
def normalise(vector: Float[Array, "*a dim"]) -> Float[Array, "*a dim"]:
"""normalise vector in-place."""
result: Float[Array, "dim"] = cast(
Float[Array, "dim"],
result: Float[Array, "*a dim"] = cast(
Float[Array, "*a dim"],
vector / jnp.linalg.norm(vector),
)
assert isinstance(result, Float[Array, "dim"])
assert isinstance(result, Float[Array, "*a dim"])

return result

Expand Down Expand Up @@ -191,13 +191,13 @@ def to_cartesian(
When last component is 0, this function just discard the w-component
without division.
"""
result: Float[Array, "*batch dim-1"] = jnp.where( # pyright: ignore
result: Float[Array, "*batch dim-1"]
result = jnp.where( # pyright: ignore[reportUnknownMemberType]
# if w component is 0, just discard it and return.
coordinates[..., -1:] == 0.0,
coordinates[..., :-1],
normalise_homogeneous(coordinates)[..., :-1],
)
assert isinstance(result, Float[Array, "*batch dim-1"])

return result

Expand Down
10 changes: 3 additions & 7 deletions renderer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jax.numpy as jnp
from jax.tree_util import tree_map
from jaxtyping import Array, Bool, Float, Integer, Num, Shaped
from jaxtyping import PyTree # pyright: ignore[reportUnknownVariableType]
from jaxtyping import jaxtyped # pyright: ignore[reportUnknownVariableType]

from ._backport import List, NamedTuple, Sequence, Tuple, TypeAlias
Expand Down Expand Up @@ -268,11 +267,8 @@ def merge_maps(maps: MapsT) -> Tuple[MapT, Tuple[int, int]]:
# TODO: find a better way to merge maps
with jax.ensure_compile_time_eval():
dims: int = len(maps[0].shape)
shapes: PyTree[Tuple[int, ...], ...]
shapes = cast( # pyright: ignore[reportUnknownVariableType]
PyTree[Tuple[int, ...], ...],
tree_map(lambda m: m.shape, maps),
)
shapes: Sequence[Tuple[int, ...]]
shapes = tree_map(lambda m: m.shape, maps)
# pick the largest shape for each dimension
single_shape: Tuple[int, ...] = cast(
Tuple[int, ...],
Expand Down Expand Up @@ -485,7 +481,7 @@ def merge_objects(objects: Sequence[ModelObject]) -> MergedModel:
diffuse_map, single_map_shape = MergedModel.merge_maps(
[m.diffuse_map for m in models]
)
specular_map = cast(MapT, MergedModel.merge_maps([m.specular_map for m in models]))
specular_map = MergedModel.merge_maps([m.specular_map for m in models])[0]

@jaxtyped
@partial(jit, inline=True)
Expand Down

0 comments on commit 55999f6

Please sign in to comment.