Skip to content

Commit c4f8dfa

Browse files
committed
Modify the tools to support an extra constant, D4, for discounting parts of counts larger than 3.
1 parent ec5f881 commit c4f8dfa

12 files changed

+154
-133
lines changed

egs/swbd/run.sh

+4-6
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ get_initial_metaparameters.py \
4747
optimize_metaparameters.py --gradient-tolerance=0.005 \
4848
data/counts_20k_3 data/optimize_20k_3
4949

50-
# optimize_metaparameters.py: log-prob on dev data increased from -4.42783439035
51-
# to -4.41743853837 over 9 passes of derivative estimation (perplexity:
52-
# 83.7498508954->82.8837097801
50+
# optimize_metaparameters.py: log-prob on dev data increased from -4.42278966972 to -4.41127767165 over 6 passes of derivative estimation (perplexity: 83.3284201887->82.3746440442)
5351

5452

5553
get_counts.sh data/int_20k 4 data/counts_20k_4
@@ -63,6 +61,6 @@ get_initial_metaparameters.py \
6361
optimize_metaparameters.py --gradient-tolerance=0.005 \
6462
data/counts_20k_4 data/optimize_20k_4
6563

66-
# optimize_metaparameters.py: log-prob on dev data increased from -4.42864701686
67-
# to -4.38964142483 over 13 passes of derivative estimation (perplexity:
68-
# 83.8179359045->80.6115085121
64+
# optimize_metaparameters.py: log-prob on dev data increased from -4.4224930661
65+
# to -4.38089709795 over 13 passes of derivative estimation (perplexity:
66+
# 83.3037083426->79.909688077)

scripts/get_initial_metaparameters.py

+1
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,5 @@ def ReadWeights(weights_file):
145145
print("order{0}_D1 0.8".format(o))
146146
print("order{0}_D2 0.4".format(o))
147147
print("order{0}_D3 0.2".format(o))
148+
print("order{0}_D4 0.1".format(o))
148149

scripts/get_objf_and_derivs.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,17 @@
6161
train_set_scale = {}
6262
for n in range(1, num_train_sets + 1):
6363
train_set_scale[n] = float(f.readline().split()[1])
64-
# the discounting constants will be stored as maps d1,d2,d3 from integer order
64+
# the discounting constants will be stored as maps d1,d2,d3,d4 from integer order
6565
# to discounting constant.
6666
d1 = {}
6767
d2 = {}
6868
d3 = {}
69+
d4 = {}
6970
for o in range(2, ngram_order + 1):
7071
d1[o] = float(f.readline().split()[1])
7172
d2[o] = float(f.readline().split()[1])
7273
d3[o] = float(f.readline().split()[1])
74+
d4[o] = float(f.readline().split()[1])
7375
f.close()
7476

7577

@@ -149,27 +151,28 @@ def MergeCountsBackward(order):
149151
def DiscountCounts(order):
150152
# discount counts of the specified order > 1.
151153
assert order > 1
152-
command = "discount-counts {d1} {d2} {d3} {work}/merged.{order} {work}/float.{order} {work}/discounted.{orderm1} ".format(
153-
d1 = d1[order], d2 = d2[order], d3 = d3[order], work = args.work_dir,
154-
order = order, orderm1 = order - 1)
154+
command = "discount-counts {d1} {d2} {d3} {d4} {work}/merged.{order} {work}/float.{order} {work}/discounted.{orderm1} ".format(
155+
d1 = d1[order], d2 = d2[order], d3 = d3[order], d4 = d4[order],
156+
work = args.work_dir, order = order, orderm1 = order - 1)
155157
RunCommand(command)
156158

157159
def DiscountCountsBackward(order):
158160
# discount counts of the specified order > 1; backprop version.
159161
assert order > 1
160-
command = ("discount-counts-backward {d1} {d2} {d3} {work}/merged.{order} {work}/float.{order} "
162+
command = ("discount-counts-backward {d1} {d2} {d3} {d4} {work}/merged.{order} {work}/float.{order} "
161163
"{work}/float_derivs.{order} {work}/discounted.{orderm1} {work}/discounted_derivs.{orderm1} "
162164
"{work}/merged_derivs.{order}".format(
163-
d1 = d1[order], d2 = d2[order], d3 = d3[order], work = args.work_dir,
164-
order = order, orderm1 = order - 1))
165+
d1 = d1[order], d2 = d2[order], d3 = d3[order], d4 = d4[order],
166+
work = args.work_dir, order = order, orderm1 = order - 1))
165167
output = GetCommandStdout(command);
166168
try:
167-
[ deriv1, deriv2, deriv3 ] = output.split()
169+
[ deriv1, deriv2, deriv3, deriv4 ] = output.split()
168170
except:
169171
sys.exit("get_objf_and_derivs.py: could not parse output of command: " + output)
170172
d1_deriv[order] = float(deriv1) / num_dev_set_words
171173
d2_deriv[order] = float(deriv2) / num_dev_set_words
172174
d3_deriv[order] = float(deriv3) / num_dev_set_words
175+
d4_deriv[order] = float(deriv4) / num_dev_set_words
173176

174177

175178
def DiscountCountsOrder1():
@@ -226,6 +229,7 @@ def WriteDerivs():
226229
print("order{0}_D1 {1}".format(o, d1_deriv[o]), file=f)
227230
print("order{0}_D2 {1}".format(o, d2_deriv[o]), file=f)
228231
print("order{0}_D3 {1}".format(o, d3_deriv[o]), file=f)
232+
print("order{0}_D4 {1}".format(o, d4_deriv[o]), file=f)
229233
f.close()
230234

231235
# for n-gram orders down to 2, do the merging and discounting.
@@ -247,6 +251,7 @@ def WriteDerivs():
247251
d1_deriv = {}
248252
d2_deriv = {}
249253
d3_deriv = {}
254+
d4_deriv = {}
250255

251256
# Now comes the backprop code.
252257

scripts/optimize_metaparameters.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,17 @@ def WriteMetaparameters(file, array):
9999
# d3 > d4 > 0. Otherwise it returns false.
100100
def MetaparametersAreAllowed(x):
101101
global num_train_sets, ngram_order
102-
assert len(x) == num_train_sets + 3 * (ngram_order - 1)
102+
assert len(x) == num_train_sets + 4 * (ngram_order - 1)
103103
for i in range(num_train_sets):
104104
if x[i] <= 0.0 or x[i] >= 1.0:
105105
return False
106106
for o in range(2, ngram_order + 1):
107-
dim_offset = num_train_sets + 3 * (o-2)
107+
dim_offset = num_train_sets + 4 * (o-2)
108108
d1 = x[dim_offset]
109109
d2 = x[dim_offset + 1]
110110
d3 = x[dim_offset + 2]
111-
if not (1.0 > d1 and d1 > d2 and d2 > d3 and d3 > 0.0):
111+
d4 = x[dim_offset + 3]
112+
if not (1.0 > d1 and d1 > d2 and d2 > d3 and d3 > d4 and d4 > 0.0):
112113
return False
113114
return True
114115

@@ -124,30 +125,35 @@ def ModifyWithBarrierFunction(x, objf, derivs):
124125
epsilon = args.barrier_epsilon
125126
derivs = derivs.copy() # don't overwrite the object.
126127
global num_train_sets, ngram_order
127-
assert len(x) == num_train_sets + 3 * (ngram_order - 1)
128+
assert len(x) == num_train_sets + 4 * (ngram_order - 1)
128129
for i in range(num_train_sets):
129130
xi = x[i]
130131
# the constraints are: xi > 0.0, and 1.0 - xi > 0.0
131132
objf += epsilon * (log(xi - 0.0) + log(1.0 - xi))
132133
derivs[i] += epsilon * ((1.0 / xi) + (-1.0 / (1.0 - xi)))
133134

134135
for o in range(2, ngram_order + 1):
135-
dim_offset = num_train_sets + 3 * (o-2)
136+
dim_offset = num_train_sets + 4 * (o-2)
136137
d1 = x[dim_offset]
137138
d2 = x[dim_offset + 1]
138139
d3 = x[dim_offset + 2]
140+
d4 = x[dim_offset + 3]
139141
# the constraints are:
140142
# 1.0 - d1 > 0.0
141143
# d1 - d2 > 0.0
142144
# d2 - d3 > 0.0
143-
# d3 > 0.0
144-
objf += epsilon * (log(1.0 - d1) + log(d1 - d2) + log(d2 - d3) + log(d3))
145+
# d3 - d4 > 0.0
146+
# d4 > 0.0
147+
objf += epsilon * (log(1.0 - d1) + log(d1 - d2) + log(d2 - d3) +
148+
log(d3 - d4) + log(d4))
145149
# deriv for d1
146150
derivs[dim_offset] += epsilon * (-1.0 / (1.0 - d1) + 1.0 / (d1 - d2))
147151
# deriv for d2
148152
derivs[dim_offset + 1] += epsilon * (-1.0 / (d1 - d2) + 1.0 / (d2 - d3))
149153
# deriv for d3
150-
derivs[dim_offset + 2] += epsilon * (-1.0 / (d2 - d3) + 1.0 / d3)
154+
derivs[dim_offset + 2] += epsilon * (-1.0 / (d2 - d3) + 1.0 / (d3 - d4))
155+
# deriv for d4
156+
derivs[dim_offset + 3] += epsilon * (-1.0 / (d3 - d4) + 1.0 / d4)
151157
return (objf, derivs)
152158

153159

scripts/validate_metaparameter_derivs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
deriv_line[0:-1]))
6464

6565
for o in range(2, args.ngram_order + 1):
66-
for n in range(3):
66+
for n in range(4):
6767
line = f.readline()
6868
deriv_line = deriv_f.readline()
6969
try:

scripts/validate_metaparameters.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -51,27 +51,22 @@
5151
args.metaparameter_file))
5252

5353
for o in range(2, args.ngram_order + 1):
54-
line1 = f.readline()
55-
line2 = f.readline()
56-
line3 = f.readline()
54+
lines = []
55+
values = []
56+
for n in range(4):
57+
lines.append(f.readline())
5758
try:
58-
[ name1, value1 ] = line1.split()
59-
[ name2, value2 ] = line2.split()
60-
[ name3, value3 ] = line3.split()
61-
value1 = float(value1)
62-
value2 = float(value2)
63-
value3 = float(value3)
64-
assert name1 == "order{0}_D1".format(o)
65-
assert name2 == "order{0}_D2".format(o)
66-
assert name3 == "order{0}_D3".format(o)
67-
assert 1.0 > value1 and value1 > value2 and value2 > value3 and value3 > 0.0
68-
except:
59+
for n in range(4):
60+
[ name, value ] = lines[n].split()
61+
assert name == "order{0}_D{1}".format(o, n + 1)
62+
value = float(value)
63+
values.append(value)
64+
assert 1.0 > value and value > 0.0 and (n == 0 or value < values[n-1])
65+
except Exception as e:
6966
sys.exit("validate_metaparameters.py: bad values for {0}'th order "
70-
"n-gram discounting parameters: '{1}', '{2}', '{3}',"
71-
" in file {4}".format(o, line1[0:-1], line2[0:-1], line3[0:-1],
72-
args.metaparameter_file))
67+
"n-gram discounting parameters: in file {1}: {2}".format(
68+
o, args.metaparameter_file, str(e)))
7369

7470
if f.readline() != '':
7571
sys.exit("validate_metaparameters.py: junk at end of "
7672
"metaparameters file {0}".format(args.metaparameter_file))
77-

src/count.h

+3-5
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,11 @@ namespace pocolm {
2828

2929
/**
3030
This class is used to store a special type of count that we use in estimating
31-
these language models. You can think of it as a type of 'extended' float that
32-
stores the sum of a bunch of individual small counts or parts of counts.
31+
these language models. You can think of it as a type of 'extended' float
32+
that stores the sum of a bunch of individual small counts or parts of counts.
3333
In addition to storing the total count, it also stores the top-1 "part"
34-
(i.e. the largest of the component parts), and also the runners up, which
34+
(i.e. the largest of the component parts), and also the two runners up, which
3535
we call top-2 and top-3.
36-
37-
3836
*/
3937
class Count {
4038
public:

0 commit comments

Comments
 (0)