
Commit 19bc3d6

Author: ysmiao (committed)
add per doc ppx
1 parent 6330256 commit 19bc3d6

File tree: 2 files changed, +51 -21 lines


nvdm.py (+49, -21)
@@ -1,4 +1,4 @@
-"""NVDM Tensorflow implementation --Yishu Miao"""
+"""NVDM Tensorflow implementation by Yishu Miao"""
 from __future__ import print_function
 
 import numpy as np
@@ -12,26 +12,27 @@
 
 flags = tf.app.flags
 flags.DEFINE_string('data_dir', 'data/20news', 'Data dir path.')
-flags.DEFINE_float('learning_rate', 1e-5, 'Learning rate.')
+flags.DEFINE_float('learning_rate', 5e-5, 'Learning rate.')
 flags.DEFINE_integer('batch_size', 64, 'Batch size.')
 flags.DEFINE_integer('n_hidden', 500, 'Size of each hidden layer.')
 flags.DEFINE_integer('n_topic', 50, 'Size of stochastic vector.')
 flags.DEFINE_integer('n_sample', 1, 'Number of samples.')
 flags.DEFINE_integer('vocab_size', 2000, 'Vocabulary size.')
 flags.DEFINE_boolean('test', False, 'Process test data.')
+flags.DEFINE_string('non_linearity', 'tanh', 'Non-linearity of the MLP.')
 FLAGS = flags.FLAGS
 
 class NVDM(object):
     """ Neural Variational Document Model -- BOW VAE.
     """
     def __init__(self,
-                 vocab_size=2000,
-                 n_hidden=500,
-                 n_topic=50,
-                 n_sample=1,
-                 learning_rate=1e-5,
-                 batch_size=64,
-                 non_linearity=tf.nn.tanh):
+                 vocab_size,
+                 n_hidden,
+                 n_topic,
+                 n_sample,
+                 learning_rate,
+                 batch_size,
+                 non_linearity):
         self.vocab_size = vocab_size
         self.n_hidden = n_hidden
         self.n_topic = n_topic
@@ -45,12 +46,12 @@ def __init__(self,
 
         # encoder
         with tf.variable_scope('encoder'):
-            self.enc_vec = utils.mlp(self.x, [self.n_hidden, self.n_hidden])
+            self.enc_vec = utils.mlp(self.x, [self.n_hidden], self.non_linearity)
             self.mean = utils.linear(self.enc_vec, self.n_topic, scope='mean')
             self.logsigm = utils.linear(self.enc_vec,
                                         self.n_topic,
                                         bias_start_zero=True,
-                                        matrix_start_zero=False,
+                                        matrix_start_zero=True,
                                         scope='logsigm')
            self.kld = -0.5 * tf.reduce_sum(1 - tf.square(self.mean) + 2 * self.logsigm - tf.exp(2 * self.logsigm), 1)
            self.kld = self.mask*self.kld  # mask paddings
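Side note on the unchanged `self.kld` line shown above for context: it is the closed-form KL divergence between the diagonal Gaussian posterior N(mu, sigma^2) and the standard normal prior, with `logsigm` holding log(sigma). In LaTeX, the standard identity the code matches term by term (2*logsigm = log sigma^2, exp(2*logsigm) = sigma^2) is

D_{KL}\left(\mathcal{N}(\mu,\sigma^2 I)\,\|\,\mathcal{N}(0,I)\right) = -\tfrac{1}{2}\sum_{k=1}^{K}\left(1 - \mu_k^2 + 2\log\sigma_k - \sigma_k^2\right)

Assuming `matrix_start_zero` zero-initialises the weight matrix of `utils.linear` (as the name suggests), the switch to `matrix_start_zero=True` makes `logsigm` start at 0, i.e. sigma starts at 1, matching the prior's unit variance.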
@@ -95,14 +96,11 @@ def train(sess, model,
           training_epochs=1000,
           alternate_epochs=10):
   """train nvdm model."""
-  data_set, data_count = utils.data_set(train_url)
+  train_set, train_count = utils.data_set(train_url)
   test_set, test_count = utils.data_set(test_url)
   # hold-out development dataset
-  divide = int(0.9*len(data_set))
-  train_set = data_set[:divide]
-  train_count = data_count[:divide]
-  dev_set = data_set[divide:]
-  dev_count = data_count[divide:]
+  dev_set = test_set[:50]
+  dev_count = test_count[:50]
 
   dev_batches = utils.create_batches(len(dev_set), batch_size, shuffle=False)
   test_batches = utils.create_batches(len(test_set), batch_size, shuffle=False)
@@ -120,8 +118,10 @@ def train(sess, model,
         print_mode = 'updating encoder'
       for i in xrange(alternate_epochs):
         loss_sum = 0.0
+        ppx_sum = 0.0
         kld_sum = 0.0
         word_count = 0
+        doc_count = 0
         for idx_batch in train_batches:
           data_batch, count_batch, mask = utils.fetch_data(
             train_set, train_count, idx_batch, FLAGS.vocab_size)
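For orientation only (this context sits outside the diff): the hunk above is inside NVDM's alternating optimisation, where encoder and decoder variables are updated in turn for `alternate_epochs` inner epochs each, which is why `print_mode` is set to 'updating encoder'. A rough, self-contained sketch of that surrounding schedule; the optimiser attribute names `optim_enc` and `optim_dec` are assumptions, not taken from this commit:

# Sketch of the alternating schedule around the hunk above (op names assumed).
class StubModel(object):
    optim_enc = 'encoder_update_op'   # placeholder for the encoder train op
    optim_dec = 'decoder_update_op'   # placeholder for the decoder train op

model = StubModel()
training_epochs, alternate_epochs = 1, 10

for epoch in range(training_epochs):
    for switch in range(2):            # alternate which half of the model is trained
        if switch == 0:
            optim = model.optim_dec
            print_mode = 'updating decoder'
        else:
            optim = model.optim_enc
            print_mode = 'updating encoder'
        for i in range(alternate_epochs):
            pass                       # the per-batch loop from this diff runs here with `optim`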
@@ -130,39 +130,55 @@ def train(sess, model,
                                [model.objective, model.kld]),
                                input_feed)
           loss_sum += np.sum(loss)
-          kld_sum += np.sum(kld)/np.sum(mask)
+          kld_sum += np.sum(kld) / np.sum(mask)
           word_count += np.sum(count_batch)
+          # to avoid nan error
+          count_batch = np.add(count_batch, 1e-12)
+          # per document loss
+          ppx_sum += np.sum(np.divide(loss, count_batch))
+          doc_count += np.sum(mask)
         print_ppx = np.exp(loss_sum / word_count)
+        print_ppx_perdoc = np.exp(ppx_sum / doc_count)
         print_kld = kld_sum/len(train_batches)
         print('| Epoch train: {:d} |'.format(epoch+1),
               print_mode, '{:d}'.format(i),
-              '| Perplexity: {:.5f}'.format(print_ppx),
+              '| Corpus ppx: {:.5f}'.format(print_ppx),  # perplexity for all docs
+              '| Per doc ppx: {:.5f}'.format(print_ppx_perdoc),  # perplexity for per doc
               '| KLD: {:.5}'.format(print_kld))
     #-------------------------------
     # dev
     loss_sum = 0.0
     kld_sum = 0.0
+    ppx_sum = 0.0
     word_count = 0
+    doc_count = 0
     for idx_batch in dev_batches:
       data_batch, count_batch, mask = utils.fetch_data(
         dev_set, dev_count, idx_batch, FLAGS.vocab_size)
       input_feed = {model.x.name: data_batch, model.mask.name: mask}
       loss, kld = sess.run([model.objective, model.kld],
                            input_feed)
       loss_sum += np.sum(loss)
-      kld_sum += np.sum(kld)/np.sum(mask)
+      kld_sum += np.sum(kld) / np.sum(mask)
       word_count += np.sum(count_batch)
+      count_batch = np.add(count_batch, 1e-12)
+      ppx_sum += np.sum(np.divide(loss, count_batch))
+      doc_count += np.sum(mask)
     print_ppx = np.exp(loss_sum / word_count)
+    print_ppx_perdoc = np.exp(ppx_sum / doc_count)
    print_kld = kld_sum/len(dev_batches)
    print('| Epoch dev: {:d} |'.format(epoch+1),
          '| Perplexity: {:.9f}'.format(print_ppx),
+          '| Per doc ppx: {:.5f}'.format(print_ppx_perdoc),
          '| KLD: {:.5}'.format(print_kld))
    #-------------------------------
    # test
    if FLAGS.test:
      loss_sum = 0.0
      kld_sum = 0.0
+      ppx_sum = 0.0
      word_count = 0
+      doc_count = 0
      for idx_batch in test_batches:
        data_batch, count_batch, mask = utils.fetch_data(
          test_set, test_count, idx_batch, FLAGS.vocab_size)
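The additions above are the substance of the commit: besides the existing corpus-level perplexity exp(loss_sum / word_count), the train, dev, and test loops now also report a per-document perplexity, the exponential of the average over documents of loss_d / N_d, with a 1e-12 guard so padded documents (word count 0) do not produce nan. A minimal standalone sketch of the two quantities for a single batch, assuming `loss` holds the per-document negative variational bound, `count_batch` the per-document word counts, and `mask` the 0/1 padding indicator, as in the diff:

import numpy as np

def batch_perplexities(loss, count_batch, mask):
    """Corpus-level and per-document perplexity for one batch (sketch)."""
    word_count = np.sum(count_batch)
    doc_count = np.sum(mask)
    # corpus ppx: total loss normalised by the total number of words
    corpus_ppx = np.exp(np.sum(loss) / word_count)
    # per-doc ppx: normalise each document's loss by its own length first,
    # then average over real (unmasked) documents; 1e-12 avoids 0/0 on padding
    per_doc_rates = np.divide(loss, np.add(count_batch, 1e-12))
    per_doc_ppx = np.exp(np.sum(per_doc_rates) / doc_count)
    return corpus_ppx, per_doc_ppx

# Example: two real documents and one padded slot
loss = np.array([120.0, 80.0, 0.0])
count_batch = np.array([40.0, 20.0, 0.0])
mask = np.array([1.0, 1.0, 0.0])
print(batch_perplexities(loss, count_batch, mask))  # (exp(200/60), exp((3+4)/2))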
@@ -172,20 +188,32 @@ def train(sess, model,
        loss_sum += np.sum(loss)
        kld_sum += np.sum(kld)/np.sum(mask)
        word_count += np.sum(count_batch)
+        count_batch = np.add(count_batch, 1e-12)
+        ppx_sum += np.sum(np.divide(loss, count_batch))
+        doc_count += np.sum(mask)
      print_ppx = np.exp(loss_sum / word_count)
+      print_ppx_perdoc = np.exp(ppx_sum / doc_count)
      print_kld = kld_sum/len(test_batches)
      print('| Epoch test: {:d} |'.format(epoch+1),
            '| Perplexity: {:.9f}'.format(print_ppx),
+            '| Per doc ppx: {:.5f}'.format(print_ppx_perdoc),
            '| KLD: {:.5}'.format(print_kld))
 
 def main(argv=None):
+  if FLAGS.non_linearity == 'tanh':
+    non_linearity = tf.nn.tanh
+  elif FLAGS.non_linearity == 'sigmoid':
+    non_linearity = tf.nn.sigmoid
+  else:
+    non_linearity = tf.nn.relu
+
   nvdm = NVDM(vocab_size=FLAGS.vocab_size,
               n_hidden=FLAGS.n_hidden,
               n_topic=FLAGS.n_topic,
               n_sample=FLAGS.n_sample,
               learning_rate=FLAGS.learning_rate,
               batch_size=FLAGS.batch_size,
-              non_linearity=tf.nn.relu)
+              non_linearity=non_linearity)
   sess = tf.Session()
   init = tf.initialize_all_variables()
   sess.run(init)
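With `non_linearity` now a flag, the activation is selected at run time instead of being hard-coded to tf.nn.relu; any value other than 'tanh' or 'sigmoid' falls back to relu. A hypothetical invocation using only flags defined in this file (paths and values illustrative):

python nvdm.py --data_dir data/20news --learning_rate 5e-5 --non_linearity sigmoid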

utils.py (+2, -0)
@@ -55,6 +55,8 @@ def fetch_data(data, count, idx_batch, vocab_size):
        data_batch[i, word_id] = freq
      count_batch.append(count[doc_id])
      mask[i]=1.0
+    else:
+      count_batch.append(0)
  return data_batch, count_batch, mask
 
 def variable_parser(var_list, prefix):
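The utils.py change keeps `count_batch` index-aligned with the rows of `data_batch` when a batch is padded: padded slots now contribute a count of 0 (their mask is already 0), which is what the 1e-12 guard in nvdm.py relies on when dividing `loss` by `count_batch`. A standalone sketch of the padding behaviour, assuming the convention that -1 in `idx_batch` marks a padded slot and that each document is a {word_id: frequency} dict (neither assumption is shown in this diff):

import numpy as np

def fetch_data_sketch(data, count, idx_batch, vocab_size):
    """Padding-aware batching (sketch of utils.fetch_data after this commit)."""
    data_batch = np.zeros((len(idx_batch), vocab_size))
    count_batch = []
    mask = np.zeros(len(idx_batch))
    for i, doc_id in enumerate(idx_batch):
        if doc_id != -1:                       # real document
            for word_id, freq in data[doc_id].items():
                data_batch[i, word_id] = freq
            count_batch.append(count[doc_id])
            mask[i] = 1.0
        else:                                  # padded slot (new in this commit)
            count_batch.append(0)              # keep count_batch aligned with rows
    return data_batch, count_batch, mask

# Example: a batch of one real document and one padded slot
docs = [{0: 2, 3: 1}]
counts = [3]
print(fetch_data_sketch(docs, counts, [0, -1], vocab_size=5)[1])  # [3, 0]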
