Skip to content

Commit

Permalink
Merge branch '0.1.1'
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyTeng committed May 31, 2023
2 parents 22d608f + b0d0411 commit 655b9b7
Show file tree
Hide file tree
Showing 14 changed files with 524 additions and 216 deletions.
14 changes: 14 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,17 @@
7. Fix issue of perspective-correction barycentric interpolation in `pipeline`.
8. Fix `shaders/phong*` so normals are correctly transformed into pre-projection eye coordinates, rather than being projected.
9. Rename `ModelView` => `View`, `model_view` => `view`, `model_view_matrix` => `view_matrix` as the matrix is actually view matrix that transforms from world to eye space, not model view matrix (model to eye space).

## 0.1.1

1. Change the default behaviour of `renderer/utils.py::transpose_for_display` which will flip vertically as well by default, so the origin of the resultant matrix will be (height, width, channels) and with the origin located at the top-left corner. The previous behaviour can be achieved by setting `flip_vertical=False`.
2. `Scene.add_cube` now accepts one number for `texture_scaling` to scale texture map equally in both x and y directions.
3. Fix some assert message issues (in `Scene.add_cube`).
4. `CameraParameters` now accepts `position`, `target` and `up` in Python's tuples of floats as well, along with `jnp.array`.
5. `Scene.set_object_orientation` and `Scene.set_object_local_scaling` supports tuple of floats as well as inputs, additional to `jnp.array`.
6. `Model` now has a convenient method `create` to create a Model with same face indices shared by `faces`, `faces_norm` and `faces_uv`, and a default `specular_map`. This is useful for creating a mesh where all vertices has its own normal and uv coordinate specified, under same order (thus same face indices).
7. Correctly support Python Sequence for `utils.build_texture_from_PyTinyrenderer` as texture.
8. `quaternion` function to create an orientation from axis and angle, and `quaternion_mul` to composite quaternion.
9. `rotation_matrix` function to create a rotation matrix from axis and angle. Also allows `Scene` to set object orientation directly using rotation matrix.
10. Move `Renderer.merge_objects` into `geometry.py`, and expose in `__init__.py`.
11. `batch_models` and `Renderer.create_buffers` convenient method to facilitate batch rendering of multiple models.
16 changes: 6 additions & 10 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import jax.numpy as jnp

from renderer import (CameraParameters, LightParameters, Renderer, Scene,
ShadowParameters, Texture, UpAxis, Vec3f,
transpose_for_display, build_texture_from_PyTinyrenderer)
ShadowParameters, Texture, UpAxis, transpose_for_display,
build_texture_from_PyTinyrenderer)

# PROCESS: Set up models and objects

Expand Down Expand Up @@ -66,8 +66,8 @@

width = 640
height = 480
eye: Vec3f = jnp.array([2., 4., 1.])
target: Vec3f = jnp.array([0., 0., 0.])
eye = [2., 4., 1.]
target = [0., 0., 0.]

light: LightParameters = LightParameters()
camera: CameraParameters = CameraParameters(
Expand Down Expand Up @@ -163,14 +163,10 @@
# each frame
ims = []
for i, img in enumerate(images):
im = ax.imshow(
transpose_for_display(img),
origin='lower',
animated=True,
)
im = ax.imshow(transpose_for_display(img), animated=True)
if i == 0:
# show an initial one first
ax.imshow(transpose_for_display(img), origin='lower')
ax.imshow(transpose_for_display(img))

ims.append([im])

Expand Down
18 changes: 7 additions & 11 deletions example2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

width = 640
height = 480
eye: Vec3f = jnp.array([2., 4., 1.])
target: Vec3f = jnp.array([0., 0., 0.])
eye: Vec3f = (2., 4., 1.)
target: Vec3f = (0., 0., 0.)

light: LightParameters = LightParameters()
camera: CameraParameters = CameraParameters(
Expand All @@ -22,7 +22,7 @@
)

texture: Texture = build_texture_from_PyTinyrenderer(
jnp.array((
(
255,
255,
255, # White
Expand All @@ -35,7 +35,7 @@
0,
0,
255 # Blue
)),
),
2,
2,
) / 255.0
Expand All @@ -62,23 +62,19 @@
])

indices = jnp.array([[0, 1, 2], [0, 2, 3]])
model: Model = Model(
model: Model = Model.create(
verts=vertices,
norms=normals,
uvs=uvs,
faces=indices,
faces_norm=indices,
faces_uv=indices,
diffuse_map=texture,
# reference: https://github.com/erwincoumans/tinyrenderer/blob/89e8adafb35ecf5134e7b17b71b0f825939dc6d9/model.cpp#L215
specular_map=lax.full(texture.shape[:2], 2.0),
)
scene, plane_model = scene.add_model(model)

scene, plane_instance_id = scene.add_object_instance(plane_model)
scene = scene.set_object_orientation(
plane_instance_id,
jnp.array([0, 0, 0, 1.]),
(0, 0, 0, 1.),
)

img = Renderer.get_camera_image(
Expand All @@ -95,6 +91,6 @@
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.imshow(transpose_for_display(rgb_array), origin='lower')
ax.imshow(transpose_for_display(rgb_array))

plt.show()
2 changes: 1 addition & 1 deletion example3.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.imshow(transpose_for_display(rgb_array), origin='lower')
ax.imshow(transpose_for_display(rgb_array))

plt.show()
98 changes: 98 additions & 0 deletions example4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Example: Batch rendering a 12-frame animation of a rotating capsule."""

import jax
import jax.numpy as jnp

from renderer import (CameraParameters, LightParameters, Renderer, Scene,
ShadowParameters, Texture, UpAxis, batch_models,
merge_objects, rotation_matrix, transpose_for_display)

# PROCESS: Set up models and objects

scene: Scene = Scene()
texture: Texture = jnp.array((
(255, 255, 255), # White
(255, 0, 0), # Red
(0, 255, 0), # Green
(0, 0, 255), # Blue
)).reshape((2, 2, 3))[:, ::-1, :] / 255.0

scene, capsule_id = scene.add_capsule(
radius=0.1,
half_height=0.4,
up_axis=UpAxis.Z,
diffuse_map=texture,
)

capsule_obj_ids = []
for i in range(12):
scene, capsule_obj_id = scene.add_object_instance(capsule_id)
capsule_obj_ids.append(capsule_obj_id)

scene = scene.set_object_orientation(
capsule_obj_id,
rotation_matrix=rotation_matrix((0., 1., 0.), 30 * (i - 6)),
)

# PROCESS: Set up camera and light

width = 640
height = 480

eye = [0., 4., 0.]
target = [0., 0., 0.]

light: LightParameters = LightParameters()
camera_params: CameraParameters = CameraParameters(
viewWidth=width,
viewHeight=height,
position=eye,
target=target,
)
shadow_param = ShadowParameters()

# PROCESS: Render

merged_models = [
merge_objects([scene.objects[obj_id]]) for obj_id in capsule_obj_ids
]
buffers = Renderer.create_buffers(width, height, len(capsule_obj_ids))
camera = Renderer.create_camera_from_parameters(camera_params)

_, (images, ) = jax.vmap(lambda model, buffer: Renderer.render(
model=model,
light=light,
camera=camera,
buffers=buffer,
shadow_param=shadow_param,
))(batch_models(merged_models), buffers)

# PROCESS: show

import matplotlib.animation as animation
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# ims is a list of lists, each row is a list of artists to draw in the
# current frame; here we are just animating one artist, the image, in
# each frame
ims = []
for i in range(images.shape[0]):
img = images[i]
im = ax.imshow(transpose_for_display(img), animated=True)
if i == 0:
# show an initial one first
ax.imshow(transpose_for_display(img))

ims.append([im])

ani = animation.ArtistAnimation(
fig,
ims,
interval=500,
blit=True,
repeat_delay=0,
)

plt.show()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxrenderer"
version = "0.1.0"
version = "0.1.1"
description = "Jax implementation of rasterizer renderer."
authors = ["Joey Teng <[email protected]>"]
license = "Apache-2.0"
Expand Down
5 changes: 3 additions & 2 deletions renderer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .geometry import Camera, normalise
from .model import Model, ModelObject
from .geometry import (Camera, normalise, quaternion, quaternion_mul,
rotation_matrix)
from .model import Model, ModelObject, batch_models, merge_objects
from .renderer import (CameraParameters, LightParameters, Renderer,
ShadowParameters)
from .scene import Scene, UpAxis
Expand Down
90 changes: 89 additions & 1 deletion renderer/geometry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
from functools import partial
from typing import Any, NamedTuple, Optional
from typing import Any, NamedTuple, Optional, Union

import jax
import jax.lax as lax
Expand Down Expand Up @@ -822,6 +822,94 @@ def compute_normals(batch_verts: Float[Array, "b 3 3"]) -> Float[Array, "b 3"]:
return jax.vmap(compute_normal)(batch_verts)


@jaxtyped
@jax.jit
def quaternion(
rotation_axis: Union[Vec3f, tuple[float, float, float]],
rotation_angle: Union[Float[Array, ""], float],
) -> Vec4f:
"""Generate a quaternion rotation from a rotation axis and angle.
The rotation axis is normalised internally. The angle is specified in
degrees (NOT radian). The rotation is clockwise.
"""
axis = normalise(jnp.asarray(rotation_axis))
angle = jnp.radians(jnp.asarray(rotation_angle))
assert isinstance(axis, Vec3f), f"{rotation_axis}"
assert isinstance(angle, Float[Array, ""]), f"{rotation_angle}"

quaternion: Vec4f = (
jnp.zeros(4) #
.at[:3].set(axis * jnp.sin(angle / 2)) #
.at[3].set(jnp.cos(angle / 2)) #
)

return quaternion


@jaxtyped
@jax.jit
def quaternion_mul(quatA: Vec4f, quatB: Vec4f) -> Vec4f:
"""Multiply two quaternion rotations, as to composite them.
Noticed that all quaternions here are in order of (x, y, z, w), thus
shuffles are used here to convert from (w, x, y, z) to (x, y, z, w).
References:
- [Quaternion multiplication](https://www.mathworks.com/help/nav/ref/quaternion.mtimes.html)
"""
assert isinstance(quatA, Vec4f)
assert isinstance(quatB, Vec4f)

with jax.ensure_compile_time_eval():
shuffle = jnp.array((3, 0, 1, 2))
idx103 = jnp.array((1, 0, 3))
idx230 = jnp.array((2, 3, 0))
idx013 = jnp.array((0, 1, 3))
idx320 = jnp.array((3, 2, 0))

quatA = quatA[shuffle]
quatB = quatB[shuffle]

return jnp.array((
quatA[0] * quatB[0] - quatA[1:] @ quatB[1:],
quatA[:3] @ quatB[idx103] - quatA[3] * quatB[2],
quatA[idx230] @ quatB[:3] - quatA[1] * quatB[3],
quatA[idx013] @ quatB[idx320] - quatA[2] * quatB[1],
))[shuffle]


@jaxtyped
@jax.jit
def rotation_matrix(
rotation_axis: Union[Vec3f, tuple[float, float, float]],
rotation_angle: Union[Float[Array, ""], float],
) -> Float[Array, "3 3"]:
"""Generate a rotation matrix from a rotation axis and angle.
The rotation axis is normalised internally. The angle is specified in
degrees (NOT radian). The rotation is clockwise.
References:
- [glRotated](https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/glRotate.xml)
"""
axis = normalise(jnp.asarray(rotation_axis))
angle = jnp.radians(jnp.asarray(rotation_angle))
assert isinstance(axis, Vec3f), f"{rotation_axis}"
assert isinstance(angle, Float[Array, ""]), f"{rotation_angle}"

s = jnp.sin(angle)
c = jnp.cos(angle)

rotation_matrix: Float[Array, "3 3"] = (
jnp.identity(3) * c # +c at main diagonal
- jnp.sin(angle) * jnp.cross(axis, jnp.identity(3)) # second term
+ (1 - c) * jnp.outer(axis, axis) # first term
)

return rotation_matrix


@jaxtyped
@jax.jit
def transform_matrix_from_rotation(rotation: Vec4f) -> Float[Array, "3 3"]:
Expand Down
Loading

0 comments on commit 655b9b7

Please sign in to comment.