-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
73 lines (64 loc) · 2.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
__author__="Juliana Louback <[email protected]>"
import sys
import json
import fileinput
import logging
from collections import defaultdict
"""
Function to get unitary rule counts, binary rule counts and
nonterminal counts, organized in dictionaries.
Used in cyk.py
"""
def get_counts(filename):
    """Read unary rule, binary rule and nonterminal counts from a counts file.

    Each non-blank line is expected to have the form
    ``<count> <TYPE> <args...>``, where ``<TYPE>`` is classified by
    substring:

    * contains "TERMINAL": ``<count> <TYPE> X``      -> count(X)
    * contains "UNARY":    ``<count> <TYPE> X w``    -> count(X -> w)
    * contains "BINARY":   ``<count> <TYPE> X Y Z``  -> count(X -> Y Z)

    The concrete TYPE spellings come from whatever tool produced the file
    (presumably e.g. NONTERMINAL / UNARYRULE / BINARYRULE; confirm against
    the producer).

    Args:
        filename: path to the counts file.

    Returns:
        A tuple ``(nonterminal_count, unary_count, binary_count)``:

        * ``nonterminal_count`` maps X -> count.
        * ``unary_count`` maps X -> {w: count}.
        * ``binary_count`` maps X -> {"Y Z": count}.

        Counts are kept as the raw string tokens read from the file; callers
        convert them. If a key repeats, the last line read wins.
    """
    nonterminal_count = {}
    unary_count = {}
    binary_count = {}
    # Use `open` rather than the Python-2-only `file` builtin, inside `with`,
    # so the handle is closed even if a malformed line raises.
    with open(filename, "r") as train_count:
        for line in train_count:
            # split() with no argument also strips the newline and tolerates
            # repeated spaces or tabs between fields.
            parts = line.split()
            # A blank line (e.g. a trailing empty line) would otherwise raise
            # IndexError on parts[1].
            if not parts:
                continue
            count = parts[0]
            line_type = parts[1]
            # Nonterminal count: count(X)
            if "TERMINAL" in line_type:
                nonterminal_count[parts[2]] = count
            # Unary rule count: count(X -> w)
            if "UNARY" in line_type:
                unary_count.setdefault(parts[2], {})[parts[3]] = count
            # Binary rule count: count(X -> Y Z), keyed by "Y Z"
            if "BINARY" in line_type:
                key = parts[3] + " " + parts[4]
                binary_count.setdefault(parts[2], {})[key] = count
    return nonterminal_count, unary_count, binary_count
# Returns true if word is not seen in training data
# (ergo not in unary rules)
def is_rare(word, unary_count):
    """Return True if ``word`` is not the right-hand side of any unary rule.

    In other words, the word was never seen in the training data.

    Args:
        word: the word to look up.
        unary_count: mapping X -> {w: count}, as returned by ``get_counts``.

    Returns:
        True if no nonterminal's inner dict has ``word`` as a key,
        otherwise False.
    """
    # Each inner dict is keyed by word, so a hash lookup per nonterminal
    # replaces scanning every word of every nonterminal.
    return not any(word in words for words in unary_count.values())
# Replaces words unseen in training data with _RARE_
def replace_rare(line, unary_count):
    """Replace each word unseen in training data with "_RARE_".

    ``line`` is split on single spaces and each token is checked with
    ``is_rare``. Only whole tokens are replaced: a rare word that happens to
    be a substring of another word (e.g. "a" inside "cat") leaves that other
    word intact. Empty tokens produced by consecutive spaces are kept as-is.
    A trailing line terminator ("\\n", "\\r\\n") is not treated as part of
    the last word and is preserved in the result.

    Args:
        line: the sentence to process.
        unary_count: mapping X -> {w: count}, as returned by ``get_counts``.

    Returns:
        The line with rare tokens replaced, re-joined with single spaces.
    """
    # Keep the line terminator out of the last token, so "dog\n" is checked
    # as "dog", then re-append it unchanged.
    body = line.rstrip("\r\n")
    ending = line[len(body):]
    # Rebuild token by token: str.replace on the whole line would also hit
    # substrings of other words, and replacing "" would insert _RARE_
    # between every character.
    replaced = [
        "_RARE_" if token and is_rare(token, unary_count) else token
        for token in body.split(" ")
    ]
    return " ".join(replaced) + ending