-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathopticalflow_evaluate.py
executable file
·92 lines (78 loc) · 3.23 KB
/
opticalflow_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# ---------------------------------------------------------------------
# Copyright (c) 2018 TU Berlin, Communication Systems Group
# Written by Tobias Senst <[email protected]>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# ---------------------------------------------------------------------
import file_parser as fp
import cv2
import sys
import pickle
import copy
import util as ut
import progressbar
def run_parameter(config_item):
    """Evaluate one estimated optical flow against its ground truth.

    :param config_item: configuration entry produced by ``ut.create_config``.
        Its ``config_item["files"]`` mapping must provide the paths
        ``"gtflow"`` (ground-truth flow), ``"mask"`` (foreground/background
        mask image) and ``"estflow"`` (estimated flow).
    :return: tuple ``(config_item, result)`` where ``result`` is whatever
        ``ut.compute_error`` returns for this item.
    """
    files = config_item["files"]
    # load ground truth optical flow
    flow_gt = ut.readFlowFiles(files["gtflow"])
    # load ground truth mask indicating foreground and background flow vectors
    # NOTE(review): cv2.imread returns None (no exception) if the file cannot
    # be read; confirm how ut.compute_error reacts to a None mask.
    mask_rgb = cv2.imread(files["mask"])
    # load estimated optical flow
    est_flow = ut.readFlowFiles(files["estflow"])
    # compute short term errors
    # NOTE(review): presumably a dict of error statistics (e.g. "ee",
    # "R1"-"R3", "no_points") - confirm against util.compute_error.
    result = ut.compute_error(est_flow, flow_gt, mask_rgb)
    return (config_item, result)
def _create_config_list(basepath, parameter):
    """Build the evaluation configuration list for one flow method.

    :param basepath: root directory of the CrowdFlow dataset.
    :param parameter: dict whose ``"flow_method"`` entry (already wrapped in
        slashes, e.g. ``"/mymethod/"``) names the sub-directory of
        ``<basepath>/estimate/`` holding the method's estimated flows.
    :return: the value of ``ut.create_config``; the caller uses ``len()`` on
        it and iterates it to obtain items for ``run_parameter``.
    """
    basepath_dict = {
        "basepath": basepath,
        "images": basepath + "/images/",
        "groundtruth": basepath + "/gt_flow/",
        "estimate": basepath + "/estimate/" + parameter["flow_method"] + "/",
        "masks": basepath + "/masks/",
    }
    filenames = fp.create_filename_list(basepath_dict)
    return ut.create_config(parameter, filenames)


def main():
    """Run the short-term optical flow evaluation from the command line.

    Usage: ``opticalflow_evaluate.py <crowdflow_root> [method ...]``

    Each ``method`` names a sub-directory of ``<crowdflow_root>/estimate/``.
    For every method, each configuration from ``ut.create_config`` is
    evaluated with ``run_parameter``. All ``(config_item, result)`` tuples are
    pickled as ``{"result": [...]}`` to ``short_term_results.pb`` in the
    current working directory. That file is then turned into a LaTeX table by
    ``ut.getLatexTable``, printed, and written to ``short_term_results.tex``.
    """
    if len(sys.argv) < 2:
        print("Please provide the first argument that is the root path of the CrowdFlow dataset.")
        return
    basepath = sys.argv[1]
    # The slashes around the method name are part of the value handed to
    # ut.create_config, so they are kept verbatim.
    parameter_list = [{"flow_method": "/" + method + "/"} for method in sys.argv[2:]]

    latex_filename = "short_term_results.tex"
    result_filename = "short_term_results.pb"

    # Resolve every method's configuration list first so the progress bar is
    # started exactly once with the true total, even when methods cover a
    # different number of items.
    config_lists = [_create_config_list(basepath, parameter) for parameter in parameter_list]
    total = sum(len(config_list) for config_list in config_lists)

    result_list = []
    bar = progressbar.ProgressBar()
    bar.start(max_value=total)
    done = 0
    for config_list in config_lists:
        for config_item in config_list:
            result = run_parameter(config_item)
            result_list.append(copy.deepcopy(result))
            done += 1
            bar.update(done)
    bar.finish()
    print("\n")
    print("\n")

    print("Save short term evaluation file ", result_filename)
    out_dict = {"result": result_list}
    # Close (and thereby flush) the pickle before ut.getLatexTable re-reads it.
    with open(result_filename, "wb") as f:
        pickle.dump(out_dict, f)

    result_str = ut.getLatexTable(result_filename)
    print(result_str)
    if latex_filename:
        with open(latex_filename, "w") as f:
            f.write(result_str)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()