Skip to content

Commit

Permalink
Merge branch '0.1.3'
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyTeng committed Jun 2, 2023
2 parents 1f33253 + d5ca051 commit 890c9b8
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Texture specification is different: in PyTinyrenderer, the texture is specified in a flattened array, while in JAX Renderer, the texture is specified in a shape of (width, height, colour channels). A simple way to transform old specification to new specification is to use the convenient method `build_texture_from_PyTinyrenderer`.
- Rendering pipeline is different. PyTinyrenderer renders one object at a time, and share zbuffer and framebuffer across one pass. This renderer first merges all objects into one big mesh in world space, then process all vertices together, then interpolates and rasterise and render. For fragment shading, this is done by sweeping each row in a for loop, and batch compute all pixels together. For computing a pixel, all fragments for that pixels are batch compute together, then mixed. This is more memory efficient and allows `vmap` batching as far as possible.
- Shadowing within the same object / mesh is allowed. This is not possible in PyTinyrenderer, as it deliberately checks if the shadow comes from the same object; if so, it will not consider to draw a shadow there.
- Quaternion (for specifying rotation/orientation) is in the form of `(w, x, y, z)` instead of `(x, y, z, w)` in PyTinyrenderer. This is for consistency with `BRAX`.
- Fix bugs
- Specular lighting was wrong, where it forgets to reverse the light direction vector.

Expand Down
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@
3. Changed the way that `Camera` is created in `Renderer.create_camera_from_parameters` to force convert parameters into `float` weak type.
4. Force convert `LightParameters` to JAX arrays in `Renderer.get_camera_image` to avoid downstream errors.
5. Downgrade minimum Python version to `3.9`, `numpy` version to `1.22.0`, `jax` and `jaxlib` version to `0.4.4`.

## 0.1.3

1. Correctly force convert `LightParameters` to JAX arrays in `Renderer.get_camera_image` to avoid downstream errors.
2. Fix `geometry.py::transform_matrix_from_rotation`. Also, change the order of quaternion to `(w, x, y, z)` instead of `(x, y, z, w)` for consistency.
3. Force convert `ShadowParameters` to JAX arrays in `Renderer.get_camera_image` to avoid downstream errors.
2 changes: 1 addition & 1 deletion example2.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
scene, plane_instance_id = scene.add_object_instance(plane_model)
scene = scene.set_object_orientation(
plane_instance_id,
(0, 0, 0, 1.),
(1., 0, 0, 0),
)

img = Renderer.get_camera_image(
Expand Down
19 changes: 14 additions & 5 deletions example4.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from renderer import (CameraParameters, LightParameters, Renderer, Scene,
ShadowParameters, Texture, UpAxis, batch_models,
merge_objects, rotation_matrix, transpose_for_display)
merge_objects, quaternion, rotation_matrix,
transpose_for_display)

# PROCESS: Set up models and objects

Expand All @@ -29,10 +30,18 @@
scene, capsule_obj_id = scene.add_object_instance(capsule_id)
capsule_obj_ids.append(capsule_obj_id)

scene = scene.set_object_orientation(
capsule_obj_id,
rotation_matrix=rotation_matrix((0., 1., 0.), 30 * (i - 6)),
)
# to try both ways of setting orientation
if i < 6:
scene = scene.set_object_orientation(
capsule_obj_id,
# rotation_matrix=rotation_matrix((0., 1., 0.), 30 * (i - 6)),
orientation=quaternion((0., 1., 0.), 30 * (i - 6)),
)
else:
scene = scene.set_object_orientation(
capsule_obj_id,
rotation_matrix=rotation_matrix((0., 1., 0.), 30 * (i - 6)),
)

# PROCESS: Set up camera and light

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxrenderer"
version = "0.1.2"
version = "0.1.3"
description = "Jax implementation of rasterizer renderer."
authors = ["Joey Teng <[email protected]>"]
license = "Apache-2.0"
Expand Down
10 changes: 4 additions & 6 deletions renderer/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def rotation_matrix(
def transform_matrix_from_rotation(rotation: Vec4f) -> Float[Array, "3 3"]:
"""Generate a transform matrix from a quaternion rotation.
Supports non-unit rotation.
Quaternion is specified in (w, x, y, z) order. Supports non-unit rotation.
References:
- [Quaternions and spatial rotation#Quaternion-derived rotation matrix](https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix)
Expand All @@ -921,11 +921,9 @@ def transform_matrix_from_rotation(rotation: Vec4f) -> Float[Array, "3 3"]:
d = rotation @ rotation
s = 2.0 / d # here s is $2\times s$ in Wikipedia.

rs: Vec3f = rotation[:3] * s
((wx, wy, wz), (xx, xy, xz), (yy, yz, zz)) = jnp.outer(
rotation[jnp.array((3, 0, 1))],
rs,
)
rs: Vec3f = rotation[1:] * s # x y z
((wx, wy, wz), (xx, xy, xz), (_, yy, yz)) = jnp.outer(rotation[:3], rs)
zz = rotation[3] * rs[2]

mat: Float[Array, "3 3"] = jnp.array((
(1. - (yy + zz), xy - wz, xz + wy),
Expand Down
15 changes: 14 additions & 1 deletion renderer/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,12 @@ def get_camera_image(

assert isinstance(_camera, Camera), f"{_camera}"

light = tree_map(jnp.asarray, light)
light = tree_map(
jnp.asarray,
light,
# only flatten one layer
is_leaf=lambda x: not isinstance(x, LightParameters),
)
assert isinstance(light, LightParameters), f"{light}"

buffers: Buffers = cls.create_buffers(
Expand All @@ -343,6 +348,14 @@ def get_camera_image(
model: MergedModel = merge_objects(objects)
assert isinstance(model, MergedModel), f"{model}"

if shadow_param is not None:
shadow_param = tree_map(
jnp.asarray,
shadow_param,
# only flatten one layer
is_leaf=lambda x: not isinstance(x, ShadowParameters),
)

canvas: Canvas
_, (canvas, ) = cls.render(
model=model,
Expand Down

0 comments on commit 890c9b8

Please sign in to comment.