forked from ohhhyeahhh/PointAttN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_c3d.py
80 lines (64 loc) · 2.81 KB
/
test_c3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
import os
import sys
import importlib
import argparse
import munch
import yaml
from utils.train_utils import *
from dataset import C3D_h5
import h5py
def save_h5(data, path):
    """Write a tensor to an HDF5 file as a single dataset named ``data``.

    Args:
        data: tensor to save; its ``.data`` is moved to the CPU and converted
            to a NumPy array before writing.
        path: destination file path. An existing file is overwritten
            (the file is opened in ``'w'`` mode).
    """
    # Convert before opening the file so a conversion failure does not
    # truncate/clobber an existing file at ``path``.
    array = data.data.cpu().numpy()
    print(array.shape)
    # The context manager releases the HDF5 handle even if dataset creation
    # raises; the previous bare open/close pair leaked it in that case.
    with h5py.File(path, 'w') as f:
        f.create_dataset('data', data=array)
def save_obj(point, path):
    """Write a point cloud as Wavefront OBJ vertex lines.

    Each row ``p`` of *point* produces one line ``v p[0] p[1] p[2]``; any
    columns beyond the third are ignored.

    Args:
        point: iterable of rows indexable as ``p[0]``, ``p[1]``, ``p[2]``
            (e.g. an (N, >=3) NumPy array or tensor, or nested lists).
        path: destination file path; an existing file is overwritten.
    """
    # ``with`` closes the file on exit, so no explicit close() is needed.
    # Iterating rows directly avoids requiring a ``.shape`` attribute.
    with open(path, 'w') as f:
        f.writelines("v {0} {1} {2}\n".format(p[0], p[1], p[2]) for p in point)
def test():
    """Run the configured model over the C3D test split and optionally dump outputs.

    Reads its settings from the module-level ``args`` (populated from the YAML
    config in ``__main__``): ``c3dpath``, ``batch_size``, ``model_name``,
    ``load_model``, ``step_interval_to_print`` and ``save_vis``. When
    ``save_vis`` is set, each sample's ``result_dict['out2']`` is written with
    ``save_h5`` to ``<dirname(load_model)>/all/<label>``.
    """
    dataset_test = C3D_h5(args.c3dpath, prefix="test")
    dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size,
                                                  shuffle=False, num_workers=1)
    logging.info('Length of test dataset:%d', len(dataset_test))
    # len(DataLoader) is the true number of batches, including a final
    # partial batch (the old len/batch_size under %d truncated it).
    num_batches = len(dataloader_test)

    # load model
    model_module = importlib.import_module('.%s' % args.model_name, 'models')
    net = torch.nn.DataParallel(model_module.Model(args))
    net.cuda()
    net.module.load_state_dict(torch.load(args.load_model)['net_state_dict'])
    logging.info("%s's previous weights loaded." % args.model_name)
    net.eval()

    # Create the output directory once up front instead of re-checking it
    # on every batch; exist_ok avoids a check-then-create race.
    save_dir = os.path.join(os.path.dirname(args.load_model), 'all')
    if args.save_vis:
        os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for i, data in enumerate(dataloader_test):
            label, inputs_cpu, gt_cpu = data
            inputs = inputs_cpu.float().cuda()
            gt = gt_cpu.float().cuda()
            # Swap dims 1 and 2 — presumably (B, N, C) -> (B, C, N) for the
            # model; TODO confirm against the dataset/model.
            inputs = inputs.transpose(2, 1).contiguous()
            result_dict = net(inputs, gt, is_training=False)
            if i % args.step_interval_to_print == 0:
                logging.info('test [%d/%d]' % (i, num_batches))
            if args.save_vis:
                # Use the actual batch size: the last batch can be smaller than
                # args.batch_size, and indexing past it raised IndexError.
                for j in range(inputs_cpu.shape[0]):
                    path = os.path.join(save_dir, str(label[j]))
                    save_h5(result_dict['out2'][j], path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test config file')
    parser.add_argument('-c', '--config', help='path to config file', required=True)
    arg = parser.parse_args()
    # Config names are resolved relative to ./cfgs.
    config_path = os.path.join('./cfgs', arg.config)
    # Close the config file promptly; safe_load avoids constructing arbitrary
    # Python objects from the YAML.
    with open(config_path) as config_file:
        args = munch.munchify(yaml.safe_load(config_file))

    # Environment values must be str; a YAML entry like `device: 0` parses
    # as an int, which os.environ rejects with TypeError.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)

    # .get() so a missing key reaches the intended ValueError instead of
    # Munch raising AttributeError.
    if not args.get('load_model'):
        raise ValueError('Model path must be provided to load model!')
    log_dir = os.path.dirname(args.load_model)
    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(log_dir, 'test.log')),
                                                      logging.StreamHandler(sys.stdout)])
    test()