-
Notifications
You must be signed in to change notification settings - Fork 55
/
eval.py
69 lines (51 loc) · 2.1 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from torch.autograd import Variable
import torch.nn as nn
import skimage
import skimage.io as io
from torchvision import transforms
import numpy as np
import scipy.io as scio
from modelNetM import EncoderNet, DecoderNet, ClassNet, EPELoss
# Preprocessing applied to every test image: convert to a CHW float tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the three networks. The encoder and decoder receive the same list
# argument; its meaning is defined by the classes in modelNetM.
model_en = EncoderNet([1, 1, 1, 1, 2])
model_de = DecoderNet([1, 1, 1, 1, 2])
model_class = ClassNet()

# With more than one visible GPU, wrap every network in DataParallel.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model_en, model_de, model_class = (
        nn.DataParallel(model_en),
        nn.DataParallel(model_de),
        nn.DataParallel(model_class),
    )

# Move the networks onto the GPU whenever CUDA is usable.
if torch.cuda.is_available():
    model_en, model_de, model_class = (
        model_en.cuda(),
        model_de.cuda(),
        model_class.cuda(),
    )
# Load the trained weights from the current working directory.
#
# map_location='cpu' makes deserialisation independent of the device the
# checkpoint was saved on: without it, a checkpoint written from a CUDA device
# cannot be loaded on a CPU-only machine (or on a machine lacking that GPU
# index). load_state_dict() then copies each tensor onto whatever device the
# model parameters already live on, so GPU runs are unaffected.
#
# NOTE(review): when several GPUs are visible the models are wrapped in
# nn.DataParallel, whose state-dict keys carry a "module." prefix. The .pkl
# files must have been saved with matching key names - confirm against the
# training script.
model_en.load_state_dict(torch.load('model_en.pkl', map_location='cpu'))
model_de.load_state_dict(torch.load('model_de.pkl', map_location='cpu'))
model_class.load_state_dict(torch.load('model_class.pkl', map_location='cpu'))

# Switch all three networks to inference mode.
model_en.eval()
model_de.eval()
model_class.eval()
# Evaluate on the test set: for every distortion type and image id, predict
# the flow field and the distortion class, save the flow as a .mat file and
# tally how often the predicted class matches the image's type.
testImgPath = '/home/xliea/Dataset256/Dataset256/test/distorted'
saveFlowPath = '/home/xliea/test/flow_256/flow_cla'

# The CUDA check does not change between iterations, so evaluate it once.
use_GPU = torch.cuda.is_available()

correct = 0
total = 0

# Inference only: disabling autograd avoids building a graph for each of the
# forward passes. The legacy Variable wrapper and .data accesses are no longer
# needed because nothing here tracks gradients.
with torch.no_grad():
    # The position of a type in this list is the class label that the
    # classifier's prediction is compared against.
    for index, types in enumerate(['barrel', 'pincushion', 'rotation',
                                   'shear', 'projective', 'wave']):
        for k in range(50000, 55000):
            # Files are named <type>_<6-digit zero-padded id>.jpg
            imgPath = '%s/%s_%06d.jpg' % (testImgPath, types, k)
            disimgs = transform(io.imread(imgPath))

            if use_GPU:
                disimgs = disimgs.cuda()

            # Add the batch dimension; view() also enforces that the image
            # is exactly 3 x 256 x 256.
            disimgs = disimgs.view(1, 3, 256, 256)

            # The encoder output feeds both the flow decoder and the
            # distortion-type classifier.
            middle = model_en(disimgs)
            flow_output = model_de(middle)
            clas = model_class(middle)

            # Predicted class = index of the largest score along dim 1.
            _, predicted = torch.max(clas, 1)
            if predicted.item() == index:
                correct += 1
            total += 1

            # Copy the flow to host memory once, then split the first two
            # channels of the single batch element into u and v.
            flow = flow_output[0].cpu().numpy()
            u = flow[0]
            v = flow[1]

            saveMatPath = '%s/%s_%06d.mat' % (saveFlowPath, types, k)
            scio.savemat(saveMatPath, {'u': u, 'v': v})

# Report the tally; previously it was accumulated but never shown.
print('Classification accuracy: %d/%d (%.2f%%)'
      % (correct, total, 100.0 * correct / total))