#!/usr/bin/env python
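"""Analyze chunker coverage: count which chunk types (IOB labels) a chunker
finds in a corpus, and optionally score it against the corpus's own chunks.

Example invocation (a sketch: treebank_chunk is an NLTK corpus that provides
the chunked_sents() required by --score; adjust to the corpora/models you have):

    python analyze_chunker_coverage.py treebank_chunk --score --fraction 0.5
"""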
import argparse, collections, math
# nltk.chunk is imported explicitly for the default chunker pickle constant below
import nltk.chunk, nltk.tag
from nltk_trainer import load_corpus_reader, load_model, simplify_wsj_tag
from nltk_trainer.chunking import chunkers
from nltk_trainer.chunking.transforms import node_label
from nltk_trainer.tagging import taggers
########################################
## command options & argument parsing ##
########################################
parser = argparse.ArgumentParser(description='Analyze chunker coverage of a corpus',
    formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('corpus',
    help='''The name of a tagged corpus included with NLTK, such as treebank,
brown, cess_esp, floresta, or the root path to a corpus directory,
which can be either an absolute path or relative to an nltk_data directory.''')
parser.add_argument('--tagger', default=None,
    help='''pickled tagger filename/path relative to an nltk_data directory
default is NLTK's default tagger''')
parser.add_argument('--chunker', default=nltk.chunk._MULTICLASS_NE_CHUNKER,
    help='''pickled chunker filename/path relative to an nltk_data directory
default is NLTK's default multiclass chunker''')
parser.add_argument('--trace', default=1, type=int,
    help='How much trace output you want, defaults to 1. 0 is no trace output.')
parser.add_argument('--score', action='store_true', default=False,
    help='Evaluate chunk score of chunker using corpus.chunked_sents()')
corpus_group = parser.add_argument_group('Corpus Reader Options')
corpus_group.add_argument('--reader', default=None,
    help='''Full module path to a corpus reader class, such as
nltk.corpus.reader.chunked.ChunkedCorpusReader''')
corpus_group.add_argument('--fileids', default=None,
    help='Specify fileids to load from corpus')
corpus_group.add_argument('--fraction', default=1.0, type=float,
    help='The fraction of the corpus to use for testing coverage')

if simplify_wsj_tag:
    corpus_group.add_argument('--simplify_tags', action='store_true', default=False,
        help='Use simplified tags')

args = parser.parse_args()
###################
## corpus reader ##
###################
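# load_corpus_reader accepts either the name of a corpus bundled with NLTK
# or a filesystem path, optionally with a custom reader class from --reader.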
corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids)

if args.score and not hasattr(corpus, 'chunked_sents'):
    raise ValueError('%s does not support scoring' % args.corpus)
############
## tagger ##
############
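# Resolve the tagger: NLTK's default tagger unless 'pattern' or a pickle path is given.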
if args.trace:
    print('loading tagger %s' % args.tagger)

if not args.tagger:
    tagger = nltk.tag._get_tagger()
elif args.tagger == 'pattern':
    tagger = taggers.PatternTagger()
else:
    tagger = load_model(args.tagger)
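# Same idea for the chunker: 'pattern' selects the pattern-library chunker,
# anything else is treated as a pickle path relative to an nltk_data directory.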
if args.trace:
    print('loading chunker %s' % args.chunker)

if args.chunker == 'pattern':
    chunker = chunkers.PatternChunker()
else:
    chunker = load_model(args.chunker)
#######################
## coverage analysis ##
#######################
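# With --score, first evaluate against the corpus's own chunked_sents();
# chunker.evaluate() returns an NLTK ChunkScore (precision, recall, f-measure).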
if args.score:
    if args.trace:
        print('evaluating chunker score\n')

    chunked_sents = corpus.chunked_sents()

    if args.fraction != 1.0:
        cutoff = int(math.ceil(len(chunked_sents) * args.fraction))
        chunked_sents = chunked_sents[:cutoff]

    print(chunker.evaluate(chunked_sents))
    print('\n')
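# Coverage proper: tag and chunk every sentence, then tally each subtree label
# the chunker emits (everything but the top-level S node).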
if args.trace:
    print('analyzing chunker coverage of %s with %s\n' % (args.corpus, chunker.__class__.__name__))

iobs_found = collections.defaultdict(int)
sents = corpus.sents()

if args.fraction != 1.0:
    cutoff = int(math.ceil(len(sents) * args.fraction))
    sents = sents[:cutoff]

for sent in sents:
    tree = chunker.parse(tagger.tag(sent))

    for child in tree.subtrees(lambda t: node_label(t) != 'S'):
        iobs_found[node_label(child)] += 1
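# Print a two-column table: each IOB label found, padded so the label column
# lines up with the 9-character 'Found' count column.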
iobs = iobs_found.keys()
# max() over a possibly empty label list: keep a 7-character minimum column width
justify = max([7] + [len(iob) for iob in iobs])

print(' '.join(['IOB'.center(justify), 'Found'.center(9)]))
print('=' * justify + ' ' + '=' * 9)

for iob in sorted(iobs):
    print(' '.join([iob.ljust(justify), str(iobs_found[iob]).rjust(9)]))

print('=' * justify + ' ' + '=' * 9)