- """NVDM Tensorflow implementation -- Yishu Miao"""
+ """NVDM Tensorflow implementation by Yishu Miao"""
from __future__ import print_function

import numpy as np

flags = tf.app.flags
flags.DEFINE_string('data_dir', 'data/20news', 'Data dir path.')
- flags.DEFINE_float('learning_rate', 1e-5, 'Learning rate.')
+ flags.DEFINE_float('learning_rate', 5e-5, 'Learning rate.')
flags.DEFINE_integer('batch_size', 64, 'Batch size.')
flags.DEFINE_integer('n_hidden', 500, 'Size of each hidden layer.')
flags.DEFINE_integer('n_topic', 50, 'Size of stochastic vector.')
flags.DEFINE_integer('n_sample', 1, 'Number of samples.')
flags.DEFINE_integer('vocab_size', 2000, 'Vocabulary size.')
flags.DEFINE_boolean('test', False, 'Process test data.')
+ flags.DEFINE_string('non_linearity', 'tanh', 'Non-linearity of the MLP.')
FLAGS = flags.FLAGS

class NVDM(object):
  """ Neural Variational Document Model -- BOW VAE.
  """
  def __init__(self,
-              vocab_size=2000,
-              n_hidden=500,
-              n_topic=50,
-              n_sample=1,
-              learning_rate=1e-5,
-              batch_size=64,
-              non_linearity=tf.nn.tanh):
+              vocab_size,
+              n_hidden,
+              n_topic,
+              n_sample,
+              learning_rate,
+              batch_size,
+              non_linearity):
    self.vocab_size = vocab_size
    self.n_hidden = n_hidden
    self.n_topic = n_topic
@@ -45,12 +46,12 @@ def __init__(self,

    # encoder
    with tf.variable_scope('encoder'):
-     self.enc_vec = utils.mlp(self.x, [self.n_hidden, self.n_hidden])
+     self.enc_vec = utils.mlp(self.x, [self.n_hidden], self.non_linearity)
      self.mean = utils.linear(self.enc_vec, self.n_topic, scope='mean')
      self.logsigm = utils.linear(self.enc_vec,
                                  self.n_topic,
                                  bias_start_zero=True,
-                                 matrix_start_zero=False,
+                                 matrix_start_zero=True,
                                  scope='logsigm')
      self.kld = -0.5 * tf.reduce_sum(1 - tf.square(self.mean) + 2 * self.logsigm - tf.exp(2 * self.logsigm), 1)
      self.kld = self.mask * self.kld  # mask paddings
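
For reference, the `self.kld` expression in this hunk is the closed-form KL divergence between the diagonal Gaussian posterior N(mean, exp(logsigm)^2) and the standard normal prior; with mu = mean, log sigma = logsigm, and the sum running over the n_topic dimensions it reads

KL(q \,\|\, p) = -\frac{1}{2} \sum_{k} \left( 1 - \mu_k^2 + 2\log\sigma_k - e^{2\log\sigma_k} \right)

which matches the tf.reduce_sum line term for term. The switch to matrix_start_zero=True appears to zero-initialise the weight matrix of the logsigm layer, so the posterior starts at unit variance; that reading is inferred from the argument names, not stated in the commit.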
@@ -95,14 +96,11 @@ def train(sess, model,
          training_epochs=1000,
          alternate_epochs=10):
  """train nvdm model."""
- data_set, data_count = utils.data_set(train_url)
+ train_set, train_count = utils.data_set(train_url)
  test_set, test_count = utils.data_set(test_url)
  # hold-out development dataset
- divide = int(0.9 * len(data_set))
- train_set = data_set[:divide]
- train_count = data_count[:divide]
- dev_set = data_set[divide:]
- dev_count = data_count[divide:]
+ dev_set = test_set[:50]
+ dev_count = test_count[:50]

  dev_batches = utils.create_batches(len(dev_set), batch_size, shuffle=False)
  test_batches = utils.create_batches(len(test_set), batch_size, shuffle=False)
@@ -120,8 +118,10 @@ def train(sess, model,
        print_mode = 'updating encoder'
      for i in xrange(alternate_epochs):
        loss_sum = 0.0
+       ppx_sum = 0.0
        kld_sum = 0.0
        word_count = 0
+       doc_count = 0
        for idx_batch in train_batches:
          data_batch, count_batch, mask = utils.fetch_data(
            train_set, train_count, idx_batch, FLAGS.vocab_size)
@@ -130,39 +130,55 @@ def train(sess, model,
                                     [model.objective, model.kld]),
                                    input_feed)
          loss_sum += np.sum(loss)
-         kld_sum += np.sum(kld)/np.sum(mask)
+         kld_sum += np.sum(kld) / np.sum(mask)
          word_count += np.sum(count_batch)
+         # to avoid nan error
+         count_batch = np.add(count_batch, 1e-12)
+         # per document loss
+         ppx_sum += np.sum(np.divide(loss, count_batch))
+         doc_count += np.sum(mask)
        print_ppx = np.exp(loss_sum / word_count)
+       print_ppx_perdoc = np.exp(ppx_sum / doc_count)
        print_kld = kld_sum / len(train_batches)
        print('| Epoch train: {:d} |'.format(epoch + 1),
              print_mode, '{:d}'.format(i),
-             '| Perplexity: {:.5f}'.format(print_ppx),
+             '| Corpus ppx: {:.5f}'.format(print_ppx),  # perplexity for all docs
+             '| Per doc ppx: {:.5f}'.format(print_ppx_perdoc),  # perplexity for per doc
              '| KLD: {:.5}'.format(print_kld))
    #-------------------------------
    # dev
    loss_sum = 0.0
    kld_sum = 0.0
+   ppx_sum = 0.0
    word_count = 0
+   doc_count = 0
    for idx_batch in dev_batches:
      data_batch, count_batch, mask = utils.fetch_data(
        dev_set, dev_count, idx_batch, FLAGS.vocab_size)
      input_feed = {model.x.name: data_batch, model.mask.name: mask}
      loss, kld = sess.run([model.objective, model.kld],
                           input_feed)
      loss_sum += np.sum(loss)
-     kld_sum += np.sum(kld)/np.sum(mask)
+     kld_sum += np.sum(kld) / np.sum(mask)
      word_count += np.sum(count_batch)
+     count_batch = np.add(count_batch, 1e-12)
+     ppx_sum += np.sum(np.divide(loss, count_batch))
+     doc_count += np.sum(mask)
    print_ppx = np.exp(loss_sum / word_count)
+   print_ppx_perdoc = np.exp(ppx_sum / doc_count)
    print_kld = kld_sum / len(dev_batches)
    print('| Epoch dev: {:d} |'.format(epoch + 1),
          '| Perplexity: {:.9f}'.format(print_ppx),
+         '| Per doc ppx: {:.5f}'.format(print_ppx_perdoc),
          '| KLD: {:.5}'.format(print_kld))
    #-------------------------------
    # test
    if FLAGS.test:
      loss_sum = 0.0
      kld_sum = 0.0
+     ppx_sum = 0.0
      word_count = 0
+     doc_count = 0
      for idx_batch in test_batches:
        data_batch, count_batch, mask = utils.fetch_data(
          test_set, test_count, idx_batch, FLAGS.vocab_size)
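
For reference, the two perplexities printed after this change are normalised differently (this is a reading of the code above, not wording from the commit): the corpus perplexity exponentiates total loss over total word count, while the per-document perplexity averages each document's own per-word loss before exponentiating,

\text{corpus ppx} = \exp\left(\frac{\sum_d \mathcal{L}_d}{\sum_d N_d}\right), \qquad \text{per-doc ppx} = \exp\left(\frac{1}{D} \sum_d \frac{\mathcal{L}_d}{N_d}\right)

with \mathcal{L}_d the variational loss of document d, N_d its word count, and D the number of real (unmasked) documents accumulated in doc_count. The 1e-12 added to count_batch only guards the division for padded, empty documents.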
@@ -172,20 +188,32 @@ def train(sess, model,
        loss_sum += np.sum(loss)
        kld_sum += np.sum(kld)/np.sum(mask)
        word_count += np.sum(count_batch)
+       count_batch = np.add(count_batch, 1e-12)
+       ppx_sum += np.sum(np.divide(loss, count_batch))
+       doc_count += np.sum(mask)
      print_ppx = np.exp(loss_sum / word_count)
+     print_ppx_perdoc = np.exp(ppx_sum / doc_count)
      print_kld = kld_sum / len(test_batches)
      print('| Epoch test: {:d} |'.format(epoch + 1),
            '| Perplexity: {:.9f}'.format(print_ppx),
+           '| Per doc ppx: {:.5f}'.format(print_ppx_perdoc),
            '| KLD: {:.5}'.format(print_kld))

def main(argv=None):
+ if FLAGS.non_linearity == 'tanh':
+   non_linearity = tf.nn.tanh
+ elif FLAGS.non_linearity == 'sigmoid':
+   non_linearity = tf.nn.sigmoid
+ else:
+   non_linearity = tf.nn.relu
+
  nvdm = NVDM(vocab_size=FLAGS.vocab_size,
              n_hidden=FLAGS.n_hidden,
              n_topic=FLAGS.n_topic,
              n_sample=FLAGS.n_sample,
              learning_rate=FLAGS.learning_rate,
              batch_size=FLAGS.batch_size,
-             non_linearity=tf.nn.relu)
+             non_linearity=non_linearity)
  sess = tf.Session()
  init = tf.initialize_all_variables()
  sess.run(init)
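
With the new --non_linearity flag wired through main(), a typical invocation would look something like the lines below; the script filename nvdm.py is an assumption, as the diff does not show the file name:

python nvdm.py --data_dir data/20news --learning_rate 5e-5 --batch_size 64 --n_topic 50 --non_linearity tanh
python nvdm.py --data_dir data/20news --non_linearity sigmoid --test=True

Any value other than 'tanh' or 'sigmoid' falls back to tf.nn.relu in main().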