model.py
import torch
import torch.nn as nn


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)
    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))
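

# A minimal sanity check for quaternion_to_matrix (added for illustration,
# not part of the original file): the identity quaternion [1, 0, 0, 0]
# should map to the 3x3 identity matrix.
def _check_quaternion_to_matrix() -> None:
    q = torch.tensor([1.0, 0.0, 0.0, 0.0])
    assert torch.allclose(quaternion_to_matrix(q), torch.eye(3))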


class BlendShape(nn.Module):
    def __init__(self, base, indices, identities,
                 expressions, is_offset=False) -> None:
        super().__init__()
        # If the inputs are absolute shapes rather than per-vertex offsets,
        # convert them to offsets from the base mesh.
        if not is_offset:
            identities = identities - base.unsqueeze(0)
            expressions = expressions - base.unsqueeze(0)
        # Register tensors as buffers so nn.Module.to() and state_dict()
        # handle them automatically (no manual to() override needed).
        self.register_buffer("base", base)                  # (V, 3)
        self.register_buffer("indices", indices)            # (VI, 3)
        self.register_buffer("identities", identities)      # (I, V, 3) offsets
        self.register_buffer("expressions", expressions)    # (E, V, 3) offsets
        self.identity_coeffs = None
        self.expression_coeffs = None
        self.morphed = None
    def forward(self, identity_coeffs, expression_coeffs) -> torch.Tensor:
        # Broadcast the coefficients over vertices and coordinates.
        self.identity_coeffs = identity_coeffs.reshape(
            identity_coeffs.shape[0], 1, 1)      # (I, 1, 1)
        self.expression_coeffs = expression_coeffs.reshape(
            expression_coeffs.shape[0], 1, 1)    # (E, 1, 1)
        # Morphed mesh = base + weighted identity offsets + weighted
        # expression offsets, summed over the blendshape dimension.
        self.morphed = (
            self.base
            + torch.sum(self.identity_coeffs * self.identities, dim=0)
            + torch.sum(self.expression_coeffs * self.expressions, dim=0)
        )
        return self.morphed
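

# Hedged usage sketch (added for illustration, not part of the original file):
# build a BlendShape from random geometry and evaluate it with random
# coefficients. The sizes below (num_vertices, num_identities,
# num_expressions) are arbitrary assumptions.
def _demo_blendshape(num_vertices: int = 100,
                     num_identities: int = 4,
                     num_expressions: int = 8) -> torch.Tensor:
    base = torch.randn(num_vertices, 3)
    indices = torch.zeros((0, 3), dtype=torch.long)  # face indices unused here
    identities = torch.randn(num_identities, num_vertices, 3)
    expressions = torch.randn(num_expressions, num_vertices, 3)
    model = BlendShape(base, indices, identities, expressions, is_offset=False)
    identity_coeffs = torch.rand(num_identities)
    expression_coeffs = torch.rand(num_expressions)
    return model(identity_coeffs, expression_coeffs)  # (num_vertices, 3)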


class OrthoCamera(nn.Module):
    def __init__(self, device) -> None:
        super().__init__()
        self.scale = 1.0
        # World-to-camera rotation (as a quaternion, real part first) and
        # translation.
        self.w2c_q = torch.tensor([1.0, .0, .0, .0], device=device)
        self.w2c_R = quaternion_to_matrix(self.w2c_q)
        self.w2c_t = torch.zeros((3,), device=device)
        self.points2d = None

    def forward(self, points) -> torch.Tensor:
        # Ensure a unit quaternion before converting to a rotation matrix.
        self.w2c_q_norm = self.w2c_q / torch.linalg.norm(self.w2c_q)
        self.w2c_R = quaternion_to_matrix(self.w2c_q_norm)
        # Scale, rotate, and translate the points into camera space.
        self.points2d = torch.t(self.w2c_R @
                                torch.t(points * self.scale)) + self.w2c_t
        return self.points2d
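

# Hedged usage sketch (added for illustration, not part of the original file):
# morph a random blendshape model and project the result with the
# orthographic camera. Guarded so it only runs when executed as a script.
if __name__ == "__main__":
    device = torch.device("cpu")
    vertices = _demo_blendshape()            # (num_vertices, 3)
    camera = OrthoCamera(device)
    projected = camera(vertices.to(device))  # (num_vertices, 3) camera space
    print(projected.shape)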