-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathrun_eval.py
executable file
·102 lines (78 loc) · 3.55 KB
/
run_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import argparse
from collections import defaultdict
import time
import torch
from torchvision.transforms import Normalize
from torch.utils.data import DataLoader
from tqdm import tqdm
from arguments import eval_parser
from model import GraphSuperResolutionNet
from data import MiddleburyDataset, NYUv2Dataset, DIMLDataset
from utils import to_cuda
class Evaluator:
    """Evaluate a trained GraphSuperResolutionNet on a dataset's test split.

    The constructor builds the test DataLoader, instantiates the model, loads
    the weights from ``args.checkpoint``, moves the model to the GPU in eval
    mode and disables autograd process-wide.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.dataloader = self.get_dataloader(args)
        self.model = GraphSuperResolutionNet(args.scaling, args.crop_size, args.feature_extractor)
        self.resume(path=args.checkpoint)
        self.model.cuda().eval()
        # Evaluation never needs gradients. This global switch is kept for
        # callers that may rely on the side effect.
        torch.set_grad_enabled(False)

    def evaluate(self):
        """Run the model over the whole test loader and average its loss terms.

        For every batch, the entries of the ``loss_dict`` returned by
        ``self.model.get_loss`` are summed per key. Each sum is then divided by
        the number of batches.

        Returns:
            dict mapping each loss name to its per-batch average. The value type
            is whatever ``get_loss`` produces (float or tensor — not visible
            here).
        """
        test_stats = defaultdict(float)
        # Local no_grad keeps evaluation gradient-free even if autograd was
        # re-enabled after construction.
        with torch.no_grad():
            for sample in tqdm(self.dataloader, leave=False):
                sample = to_cuda(sample)
                output = self.model(sample)
                _, loss_dict = self.model.get_loss(output, sample)
                for key, value in loss_dict.items():
                    test_stats[key] += value
        num_batches = len(self.dataloader)
        # NOTE(review): this is a mean over batches. With drop_last=False the
        # final (possibly smaller) batch counts as much as a full one. Confirm
        # whether per-sample weighting is intended.
        return {k: v / num_batches for k, v in test_stats.items()}

    @staticmethod
    def get_dataloader(args: argparse.Namespace):
        """Build a deterministic, unshuffled DataLoader over ``args.dataset``'s test split.

        Augmentations are disabled (no rotation, no flips, deterministic valid
        crops). RGB images are normalized with the listed ImageNet-style
        mean/std, and depth with the per-dataset constants below.

        Raises:
            NotImplementedError: if ``args.dataset`` is not 'DIML',
                'Middlebury' or 'NYUv2'.
        """
        data_args = {
            'crop_size': (args.crop_size, args.crop_size),
            'in_memory': args.in_memory,
            'max_rotation_angle': 0,
            'do_horizontal_flip': False,
            'crop_valid': True,
            'crop_deterministic': True,
            'image_transform': Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            'scaling': args.scaling,
        }

        # Dataset name -> (dataset class, sub-directory of args.data_dir,
        # depth mean, depth std) used for depth standardization.
        dataset_specs = {
            'DIML': (DIMLDataset, 'DIML', 2749.64, 1154.29),
            'Middlebury': (MiddleburyDataset, 'Middlebury', 2296.78, 1122.7),
            'NYUv2': (NYUv2Dataset, 'NYU Depth v2', 2796.32, 1386.05),
        }
        try:
            dataset_cls, subdir, depth_mean, depth_std = dataset_specs[args.dataset]
        except KeyError:
            raise NotImplementedError(f'Dataset {args.dataset}') from None

        depth_transform = Normalize([depth_mean], [depth_std])
        dataset = dataset_cls(os.path.join(args.data_dir, subdir), **data_args, split='test',
                              depth_transform=depth_transform)

        return DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers,
                          shuffle=False, drop_last=False)

    def resume(self, path):
        """Load model weights from the checkpoint file at *path*.

        Accepts either a dict holding the state dict under ``'model'`` or a
        bare state dict.

        Raises:
            RuntimeError: if *path* is not an existing file.
        """
        if not os.path.isfile(path):
            raise RuntimeError(f'No checkpoint found at \'{path}\'')

        # Map tensors to the CPU. This avoids failures when the checkpoint was
        # saved on a CUDA device that does not exist here. load_state_dict
        # copies the values into the model's parameters wherever they live.
        # (In __init__ this runs before the model is moved to the GPU.)
        checkpoint = torch.load(path, map_location='cpu')
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        self.model.load_state_dict(state_dict)

        print(f'Checkpoint \'{path}\' loaded.')
if __name__ == '__main__':
    args = eval_parser.parse_args()
    print(eval_parser.format_values())

    evaluator = Evaluator(args)

    since = time.perf_counter()
    stats = evaluator.evaluate()
    # CUDA kernels are launched asynchronously. Wait for them so the measured
    # interval covers all evaluation work.
    torch.cuda.synchronize()
    time_elapsed = time.perf_counter() - since

    # de-standardize losses and convert to cm (cm^2, respectively)
    # NOTE(review): the 0.1 / 0.01 factors assume depth is stored in mm — confirm.
    std = evaluator.dataloader.dataset.depth_transform.std[0]
    stats['l1_loss'] = 0.1 * std * stats['l1_loss']
    stats['mse_loss'] = 0.01 * std**2 * stats['mse_loss']

    # Round once, then split, so the seconds field is always 0-59. Formatting
    # time_elapsed % 60 with '.0f' could print e.g. '1m 60s'.
    minutes, seconds = divmod(round(time_elapsed), 60)
    print('Evaluation completed in {:.0f}m {:.0f}s'.format(minutes, seconds))
    print(stats)