Skip to content

Commit

Permalink
Merge branch '0.2.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyTeng committed Jun 3, 2023
2 parents 890c9b8 + 7a71323 commit 2a7cf39
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 82 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
- Rendering pipeline is different. PyTinyrenderer renders one object at a time, and share zbuffer and framebuffer across one pass. This renderer first merges all objects into one big mesh in world space, then process all vertices together, then interpolates and rasterise and render. For fragment shading, this is done by sweeping each row in a for loop, and batch compute all pixels together. For computing a pixel, all fragments for that pixels are batch compute together, then mixed. This is more memory efficient and allows `vmap` batching as far as possible.
- Shadowing within the same object / mesh is allowed. This is not possible in PyTinyrenderer, as it deliberately checks if the shadow comes from the same object; if so, it will not consider to draw a shadow there.
- Quaternion (for specifying rotation/orientation) is in the form of `(w, x, y, z)` instead of `(x, y, z, w)` in PyTinyrenderer. This is for consistency with `BRAX`.
- No clipping is performed. To ensure correct rendering of objects with vertices at or behind camera plane, homogeneous interpolation (Olano and Greer, 1997)[^1] is used to avoid the need of homogeneous division.
- Fix bugs
- Specular lighting was wrong, where it forgets to reverse the light direction vector.

[^1]: Marc Olano and Trey Greer. 1997. Triangle Scan Conversion Using 2D Homogeneous Coordinates. In _Proceedings of the ACM SIGGRAPH/EUROGRAPHICS Workshop on Graphics Hardware (HWWS ’97)_. ACM, New York, NY, USA, 89–95.

## Roadmap

- [ ] Correctly implement a proper clipping algorithm
- [ ] Support double-sided objects
- [ ] Profile and accelerate implementation
- [ ] Differentiable rendering
- [ ] Build a ray tracer as well
- [ ] <s>Correctly implement a proper clipping algorithm</s>
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@
1. Correctly force convert `LightParameters` to JAX arrays in `Renderer.get_camera_image` to avoid downstream errors.
2. Fix `geometry.py::transform_matrix_from_rotation`. Also, change the order of quaternion to `(w, x, y, z)` instead of `(x, y, z, w)` for consistency.
3. Force convert `ShadowParameters` to JAX arrays in `Renderer.get_camera_image` to avoid downstream errors.

## 0.2.0

1. Instead of clipping (planned to be implemented), now the rasteriser interpolates in homogeneous space directly. `Shader.interpolate` will not receive valid `barycentric_screen` values for now. Setting `Interpolation.SMOOTH` and `Interpolation.NOPERSPECTIVE` will result in same results, perspective-correct interpolations.
2. Reorganise example files and rename them.
File renamed without changes.
File renamed without changes.
108 changes: 108 additions & 0 deletions examples/behind_camera.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import jax
import jax.lax as lax
import jax.numpy as jnp

from renderer import (CameraParameters, LightParameters, Renderer, Scene,
ShadowParameters, Texture, UpAxis, transpose_for_display,
build_texture_from_PyTinyrenderer)

# PROCESS: Set up models and objects

scene: Scene = Scene()
texture: Texture = jnp.array([
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 1., 0.],
]).reshape((2, 2, 3))

scene, cube_model_1 = scene.add_cube(
half_extents=(1., 1., 0.03),
diffuse_map=texture,
texture_scaling=(16., 16.),
)
scene, cube_1 = scene.add_object_instance(cube_model_1)
scene = scene.set_object_position(cube_1, (0., 0., 0.))
scene = scene.set_object_orientation(cube_1, (1., 0., 0., 0.))

scene, cube_model_2 = scene.add_cube(
half_extents=(10., 10., 0.03),
diffuse_map=texture,
texture_scaling=(160., 160.),
)
scene, cube_2 = scene.add_object_instance(cube_model_2)
scene = scene.set_object_position(cube_2, (0., 0., 0.))
scene = scene.set_object_orientation(cube_2, (1., 0., 0., 0.))

# PROCESS: Set up camera and light

width = 640
height = 480
eye = jnp.asarray([2.5894797, -2.5876467, 1.9174135])
target = [0., 0., 0.]

light: LightParameters = LightParameters()
camera: CameraParameters = CameraParameters(
viewWidth=width,
viewHeight=height,
position=eye,
target=target,
hfov=58.0,
vfov=32.625,
)
shadow_param = ShadowParameters()

# PROCESS: Render

images = []

img = Renderer.get_camera_image(
objects=[scene.objects[cube_1]],
light=light,
camera=camera,
width=width,
height=height,
shadow_param=shadow_param,
)
rgb_array = lax.clamp(0., img * 255, 255.).astype(jnp.uint8)
images.append(rgb_array)

img = Renderer.get_camera_image(
objects=[scene.objects[cube_2]],
light=light,
camera=camera,
width=width,
height=height,
shadow_param=shadow_param,
)
rgb_array = lax.clamp(0., img * 255, 255.).astype(jnp.uint8)
images.append(rgb_array)

# PROCESS: show

import matplotlib.animation as animation
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# ims is a list of lists, each row is a list of artists to draw in the
# current frame; here we are just animating one artist, the image, in
# each frame
ims = []
for i, img in enumerate(images):
im = ax.imshow(transpose_for_display(img), animated=True)
if i == 0:
# show an initial one first
ax.imshow(transpose_for_display(img))

ims.append([im])

ani = animation.ArtistAnimation(
fig,
ims,
interval=500,
blit=True,
repeat_delay=0,
)

plt.show()
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxrenderer"
version = "0.1.3"
version = "0.2.0"
description = "Jax implementation of rasterizer renderer."
authors = ["Joey Teng <[email protected]>"]
license = "Apache-2.0"
Expand Down
Loading

0 comments on commit 2a7cf39

Please sign in to comment.