import chainer
import dill
import numpy as np
from tqdm import tqdm
from johnny.dep import UDepLoader
from johnny.metrics import Average, UAS, LAS, POSAccuracy
from johnny.misc import visualise_dict
from train import dataset_to_cols, data_to_rows, to_batches
from mlconf import ArgumentParser, Blueprint


def test_loop(bp, test_set):
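    """Evaluate the trained model described by blueprint bp on test_set.

    Loads the pickled vocabularies and trained weights, scores predictions
    with UAS/LAS (and POS accuracy when the model predicts POS tags), and
    writes the predicted heads and labels back into test_set in place.
    """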
model_path = bp.model_path
vocab_path = bp.vocab_path
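    # Restore the vocabularies that were pickled with dill during training.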
with open(vocab_path, 'rb') as pf:
vocabs = dill.load(pf)
visualise_dict(vocabs.text.index, num_items=20)
visualise_dict(vocabs.arcs.index, num_items=20)
test_data = dataset_to_cols(test_set, bp)
test_rows = data_to_rows(test_data, vocabs, bp)
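    # Debug output: inspect the first and last tokens before and after
    # the gold annotations are removed below.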
print(test_set)
print(test_set[0][0])
print(test_set[-1][-1])
    # Remove all the information we are going to predict, so that a bug
    # cannot silently leave the gold annotations in place and inflate
    # the scores.
test_set.unset_heads()
test_set.unset_labels()
test_set.unset_deps()
test_set.unset_misc()
print(test_set[0][0])
print(test_set[-1][-1])
print('test max seq len ', test_set.len_stats['max_sent_len'])
built_bp = bp.build()
model = built_bp.model
chainer.serializers.load_npz(model_path, model)
# test
if model.predict_pos:
tf_str = ('Eval - test : batch_size={0:d}, mean loss={1:.2f}, '
'mean UAS={2:.3f} mean LAS={3:.3f} mean POS={4:.3f}')
else:
tf_str = ('Eval - test : batch_size={0:d}, mean loss={1:.2f}, '
'mean UAS={2:.3f} mean LAS={3:.3f}')
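    # Evaluate with train=False (disables dropout etc.) and without
    # building the backprop graph, which saves memory.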
with tqdm(total=len(test_set)) as pbar, \
chainer.using_config('train', False), \
chainer.no_backprop_mode():
mean_loss = Average()
u_scorer = UAS()
l_scorer = LAS()
if model.predict_pos:
t_scorer = POSAccuracy()
index = 0
        # NOTE: IMPORTANT!!
        # The batch size matters here for reproducing the CNN results,
        # since changing it changes how much padding each word receives.
        # NOTE: test_mean_loss also depends on the batch size, because it
        # is averaged across batches.
        BATCH_SIZE = bp.batch_size  # use the same batch size as in training
for batch in to_batches(test_rows, BATCH_SIZE, sort=False):
batch_size = 0
seqs = list(zip(*batch))
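            # The last columns of each row are heads, labels and, when the
            # model predicts POS, the POS tags; pop them off in reverse order.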
pos_batch = seqs.pop() if model.predict_pos else None
label_batch = seqs.pop()
head_batch = seqs.pop()
if model.predict_pos:
arc_preds, lbl_preds, pos_preds = model(*seqs, heads=head_batch, labels=label_batch, pos_tags=pos_batch)
else:
arc_preds, lbl_preds = model(*seqs, heads=head_batch, labels=label_batch)
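                # The forward pass computes and stores the loss on the model.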
loss = model.loss
loss_value = float(loss.data)
                if not model.predict_pos:
                    # Without POS prediction, pad with dummy values so the
                    # tuple unpacking below stays uniform.
                    pos_preds = pos_batch = [None] * len(head_batch)
                for p_arcs, p_lbls, p_tags, t_arcs, t_lbls, t_tags in \
                        zip(arc_preds, lbl_preds, pos_preds,
                            head_batch, label_batch, pos_batch):
u_scorer(arcs=(p_arcs, t_arcs))
l_scorer(arcs=(p_arcs, t_arcs), labels=(p_lbls, t_lbls))
if model.predict_pos:
t_scorer(tags=(p_tags, t_tags))
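                # Write the predictions back into the cleared dataset so it
                # can later be saved in CoNLL format.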
test_set[index].set_heads(p_arcs)
str_labels = (vocabs.arcs.rev_index[l] for l in p_lbls)
test_set[index].set_labels(str_labels)
index += 1
batch_size += 1
mean_loss(loss_value)
if model.predict_pos:
out_str = tf_str.format(batch_size, mean_loss.score, u_scorer.score, l_scorer.score, t_scorer.score)
else:
out_str = tf_str.format(batch_size, mean_loss.score, u_scorer.score, l_scorer.score)
pbar.set_description(out_str)
pbar.update(batch_size)
        # Sanity check: every sentence in the test set received predictions.
        assert index == len(test_set)
stats = {'test_mean_loss': mean_loss.score,
'test_uas': u_scorer.score,
'test_las': l_scorer.score}
# TODO: save these
bp.test_results = stats
for key, val in stats.items():
print('%s: %s' % (key, val))


if __name__ == "__main__":
    # NOTE: the preprocessing helpers used above are imported from train.
parser = ArgumentParser(description='Dependency parser evaluator')
parser.add_argument('--blueprint', required=True, type=str,
                        help='Path to .bp blueprint file produced by training.')
parser.add_argument('--test_file', required=True, type=str,
                        help='CoNLL-U file to use for testing.')
parser.add_argument('--conll_out', action='store_true',
                        help='If specified, writes CoNLL-U output.')
parser.add_argument('--treeify', type=str, default='chu',
                        help='Algorithm used to post-process arcs into a tree. '
                             'Choose chu (Chu-Liu/Edmonds) to allow '
                             'non-projective trees, or eisner for projective ones.')
args = parser.parse_args()
CONLL_OUT = args.conll_out
TREEIFY = args.treeify
blueprint = Blueprint.from_file(args.blueprint)
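    # Select the algorithm used to turn the predicted arcs into a tree.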
blueprint.model.treeify = TREEIFY
test_data = UDepLoader.load_conllu(args.test_file)
test_data.lang = blueprint.dataset.lang
test_loop(blueprint, test_data)
if CONLL_OUT:
test_data.save(blueprint.model_path.replace('.model', '.conllu'))
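
# Example invocation (paths below are illustrative, not part of the repo):
#   python test.py --blueprint experiments/parser.bp \
#                  --test_file data/en-ud-test.conllu \
#                  --conll_out --treeify chu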