-
Notifications
You must be signed in to change notification settings - Fork 2
/
cider.py
executable file
·92 lines (74 loc) · 2.82 KB
/
cider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import json
from collections import defaultdict
from argparse import ArgumentParser
from PIL import Image
import torch
import language_evaluation
import pickle
def readJSON(file_path):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The deserialized JSON value (dict, list, str, number, ...).

    Raises:
        Exception: If the file cannot be opened or parsed. The underlying
            error is chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(file_path, encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        # Keep the generic Exception type callers may already catch, but chain
        # the original error instead of throwing its traceback away.
        raise Exception(f"Error reading json file: {e}") from e
def readPickle(file_path):
    """Load a pickled object from *file_path*, or return ``None`` on failure.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only point this at trusted annotation files.

    Args:
        file_path: Path to the pickle file.

    Returns:
        The unpickled object, or ``None`` if the file could not be opened or
        unpickled.
    """
    try:
        with open(file_path, 'rb') as f:
            return pickle.load(f)
    except Exception:
        # Best-effort loader: ordinary failures still yield None, but unlike
        # the previous bare ``except:`` this no longer swallows
        # KeyboardInterrupt / SystemExit.
        return None
def getGTCaptions(annotations):
    """Build a lookup from each annotation's ``video_name`` to its ``revised_label``.

    When several annotations share a ``video_name``, the last one wins.

    Args:
        annotations: Iterable of mappings that each have ``'video_name'`` and
            ``'revised_label'`` keys.

    Returns:
        A ``defaultdict(list)`` keyed by video name, holding the raw
        ``revised_label`` values (presumably a string or a list of strings —
        confirm against the dataset).
    """
    captions_by_video = {
        entry['video_name']: entry['revised_label'] for entry in annotations
    }
    return defaultdict(list, captions_by_video)
class BLEUScore:
    """BLEU scorer built on ``language_evaluation.CocoEvaluator``.

    Calling an instance with ``predictions`` (key -> predicted caption) and
    ``gts`` (key -> reference caption, or a list of references) returns the
    evaluator's full result object.
    """

    def __init__(self):
        self.evaluator = language_evaluation.CocoEvaluator(coco_types=["BLEU"])

    def __call__(self, predictions, gts):
        """Score every entry of ``predictions`` against its reference in ``gts``.

        A ``KeyError`` propagates if a prediction key is absent from ``gts``.
        When a reference is a list, only its first element is used.
        NOTE(review): the evaluator may support several references per
        prediction — confirm that scoring only the first one is intended.
        """
        hypotheses, references = [], []
        for name, caption in predictions.items():
            reference = gts[name]
            if isinstance(reference, list):
                reference = reference[0]
            hypotheses.append(caption)
            references.append(reference)
        return self.evaluator.run_evaluation(hypotheses, references)
class CIDERScore:
    """CIDEr scorer built on ``language_evaluation.CocoEvaluator``.

    Calling an instance with ``predictions`` (key -> predicted caption) and
    ``gts`` (key -> reference caption, or a list of references) returns the
    ``'CIDEr'`` entry of the evaluator's result.
    """

    def __init__(self):
        self.evaluator = language_evaluation.CocoEvaluator(coco_types=["CIDEr"])

    def __call__(self, predictions, gts):
        """Score every entry of ``predictions`` and return only the CIDEr value.

        A ``KeyError`` propagates if a prediction key is absent from ``gts``.
        When a reference is a list, only its first element is used.
        NOTE(review): the evaluator may support several references per
        prediction — confirm that scoring only the first one is intended.
        """
        hypotheses, references = [], []
        for name, caption in predictions.items():
            reference = gts[name]
            if isinstance(reference, list):
                reference = reference[0]
            hypotheses.append(caption)
            references.append(reference)
        scores = self.evaluator.run_evaluation(hypotheses, references)
        return scores['CIDEr']
def main(args):
    """Score predicted captions against ground truth and print CIDEr and BLEU.

    Args:
        args: Parsed CLI namespace with ``pred_file`` (JSON object mapping
            video name -> predicted caption) and ``annotation_file`` (pickle
            of annotation records with ``video_name`` / ``revised_label``).

    Raises:
        ValueError: If the annotation file cannot be loaded, or the
            predictions are not a dict of strings whose keys exactly match
            the annotated video names.
    """
    # Read data
    predictions = readJSON(args.pred_file)
    annotations = readPickle(args.annotation_file)
    if annotations is None:
        # readPickle returns None on any load failure; report it here rather
        # than crashing later with an opaque TypeError in getGTCaptions.
        raise ValueError(f"Could not load annotation file: {args.annotation_file}")

    # Preprocess annotation file
    gts = getGTCaptions(annotations)

    # Validate predictions explicitly: ``assert`` is stripped under ``python -O``.
    if not isinstance(predictions, dict):
        raise ValueError(
            f"Predictions must be a JSON object, got {type(predictions).__name__}"
        )
    missing = set(gts) - set(predictions)
    unexpected = set(predictions) - set(gts)
    if missing or unexpected:
        raise ValueError(
            "Prediction keys do not match annotations: "
            f"{len(missing)} missing (e.g. {sorted(map(str, missing))[:5]}), "
            f"{len(unexpected)} unexpected (e.g. {sorted(map(str, unexpected))[:5]})"
        )
    if not all(isinstance(pred, str) for pred in predictions.values()):
        raise ValueError("Every prediction value must be a string")

    # CIDEr and BLEU
    cider_score = CIDERScore()(predictions, gts)
    bleu_score = BLEUScore()(predictions, gts)
    print(f"CIDEr: {cider_score}")
    print(f"BLEU: {bleu_score}")
if __name__ == "__main__":
    parser = ArgumentParser(description="Compute CIDEr and BLEU for predicted captions.")
    # NOTE: the defaults are absolute paths on the original author's machine;
    # pass both flags explicitly when running anywhere else.
    parser.add_argument(
        "--pred_file",
        default='/home/peihsin/projects/MotionExpert/results_epoch15.json',
        help="Prediction json file",
    )
    parser.add_argument(
        "--annotation_file",
        default="/home/peihsin/projects/humanML/dataset/rm_test.pkl",
        # The annotations are loaded with readPickle, not readJSON.
        help="Annotation pickle file",
    )
    args = parser.parse_args()
    main(args)