-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathutil.py
96 lines (68 loc) · 2.97 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author: Yue Wang
@Contact: [email protected]
@File: util
@Time: 4/5/19 3:47 PM
"""
import numpy as np
import torch
import torch.nn.functional as F
from config import config as cfg
def cal_loss(pred, gold, smoothing=True, eps=0.2):
    """Calculate the mean cross-entropy loss, with optional label smoothing.

    Args:
        pred: Logits of shape (N, C); the class dimension is dim 1.
        gold: Integer (int64) class labels; flattened to shape (N,) before use.
        smoothing: If True, train against a smoothed target distribution that puts
            ``1 - eps`` on the true class and ``eps / (C - 1)`` on every other
            class. If False, use plain ``F.cross_entropy``.
        eps: Probability mass moved off the true class when ``smoothing`` is True.
            Defaults to 0.2, the value this function always used.

    Returns:
        Scalar tensor: the loss averaged over the N samples.
    """
    gold = gold.reshape(-1)
    if not smoothing:
        return F.cross_entropy(pred, gold, reduction='mean')

    n_class = pred.size(1)
    log_prb = F.log_softmax(pred, dim=1)
    one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    if n_class > 1:
        # Redistribute eps uniformly over the C-1 wrong classes.
        target = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    else:
        # A single class has nowhere to move mass to; the original divided by
        # zero here and produced NaN.
        target = one_hot
    return -(target * log_prb).sum(dim=1).mean()
class IOStream:
    """Echo text to stdout and append it to a log file.

    The file is opened in append mode and flushed after every write. Use as a
    context manager (``with IOStream(path) as io: ...``) or call ``close()``
    explicitly so the file handle is released.
    """

    def __init__(self, path):
        # Explicit UTF-8 so non-ASCII log text does not depend on the locale.
        self.f = open(path, 'a', encoding='utf-8')

    def cprint(self, text):
        """Print ``text`` and append it, newline-terminated, to the log file."""
        # print() already accepts any object; convert so the file write does too.
        text = str(text)
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
from typing import Dict, List, Optional, Tuple, Callable
import torch
import torch.nn as nn
from pytorch3d.transforms import *
class Tooth_Assembler(nn.Module):
    """Apply a per-tooth rigid transform (rotate about a centre, then translate)."""

    def __init__(self):
        super().__init__()

    def forward(self, pred: torch.Tensor, cenp: torch.Tensor, dofs: torch.Tensor, ptrans: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Rotate each tooth's points about its centre and translate them.

        For every batch item ``b`` and tooth ``t`` this computes
        ``((pred - cenp) @ R) + ptrans + cenp``, where ``R`` is the rotation
        matrix of the quaternion ``dofs[b, t]`` and points are row vectors.

        Args:
            pred: 4-D point tensor indexed (batch, tooth, point, xyz). Presumably
                (B, 28, 512, 3); TODO confirm against callers.
            cenp: 4-D rotation centres that broadcast against ``pred``, e.g.
                (B, T, 1, 3). This is an assumption; confirm.
            dofs: Quaternions with the last dim of size 4, assumed (B, T, 4).
            ptrans: Per-tooth translations, (B, T, 3).
            device: Device of the returned tensor.

        Returns:
            Tensor shaped like ``pred``, on ``device``, in torch's default dtype.
            This matches the ``torch.zeros`` buffer the original loop filled.
        """
        # quaternion_to_matrix maps (..., 4) -> (..., 3, 3), so one call covers
        # the whole batch instead of one call per batch item.
        rotations = quaternion_to_matrix(dofs)
        # Broadcasting matmul over (batch, tooth) is equivalent to the former
        # per-item torch.bmm(points, R); same operation order as before.
        rotated = torch.matmul(pred - cenp, rotations)
        assembled = rotated + ptrans.unsqueeze(2) + cenp
        return assembled.to(device=device, dtype=torch.get_default_dtype())