@@ -25,6 +25,11 @@ def build_loss_compute(model, tgt_field, opt, train=True):
 
     padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token]
     unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token]
+
+    if opt.lambda_coverage != 0:
+        assert opt.coverage_attn, "--coverage_attn needs to be set in " \
+            "order to use --lambda_coverage != 0"
+
     if opt.copy_attn:
         criterion = onmt.modules.CopyGeneratorLoss(
             len(tgt_field.vocab), opt.copy_attn_force,
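
Note (sketch, not part of the diff): the new guard fails fast when a nonzero
--lambda_coverage is requested without --coverage_attn. A minimal runnable
illustration of that check, with a stand-in argparse Namespace for the parsed
options:

    # Sketch only: mirrors the guard added above. `opt` is a stand-in for
    # OpenNMT-py's parsed training options.
    from argparse import Namespace

    def check_coverage_opts(opt):
        if opt.lambda_coverage != 0:
            assert opt.coverage_attn, "--coverage_attn needs to be set in " \
                "order to use --lambda_coverage != 0"

    check_coverage_opts(Namespace(lambda_coverage=1.0, coverage_attn=True))   # ok
    check_coverage_opts(Namespace(lambda_coverage=0.0, coverage_attn=False))  # ok
    # Namespace(lambda_coverage=1.0, coverage_attn=False) raises AssertionError
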
@@ -47,10 +52,12 @@ def build_loss_compute(model, tgt_field, opt, train=True):
     loss_gen = model.generator[0] if use_raw_logits else model.generator
     if opt.copy_attn:
         compute = onmt.modules.CopyGeneratorLossCompute(
-            criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength
+            criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength,
+            lambda_coverage=opt.lambda_coverage
         )
     else:
-        compute = NMTLossCompute(criterion, loss_gen)
+        compute = NMTLossCompute(
+            criterion, loss_gen, lambda_coverage=opt.lambda_coverage)
     compute.to(device)
 
     return compute
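
Both loss-compute classes now take the penalty weight at construction time. A
hedged sketch of the non-copy path below, with stand-in criterion and generator
modules (the real ones are derived from the model and opt earlier in this
function; the import path is assumed from the upstream layout):

    # Sketch only: building NMTLossCompute the way the diff now does.
    import torch.nn as nn
    from onmt.utils.loss import NMTLossCompute  # assumed module path

    hidden, vocab_size, padding_idx = 16, 7, 1
    criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum')
    loss_gen = nn.Sequential(nn.Linear(hidden, vocab_size),
                             nn.LogSoftmax(dim=-1))
    compute = NMTLossCompute(criterion, loss_gen, lambda_coverage=0.5)
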
@@ -218,26 +225,53 @@ class NMTLossCompute(LossComputeBase):
     Standard NMT Loss Computation.
     """
 
-    def __init__(self, criterion, generator, normalization="sents"):
+    def __init__(self, criterion, generator, normalization="sents",
+                 lambda_coverage=0.0):
         super(NMTLossCompute, self).__init__(criterion, generator)
+        self.lambda_coverage = lambda_coverage
 
     def _make_shard_state(self, batch, output, range_, attns=None):
-        return {
+        shard_state = {
             "output": output,
             "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
         }
+        if self.lambda_coverage != 0.0:
+            assert attns is not None
+            coverage = attns.get("coverage", None)
+            std = attns.get("std", None)
+            assert std is not None, "lambda_coverage != 0.0 requires " \
+                "attention mechanism"
+            assert coverage is not None, "lambda_coverage != 0.0 requires " \
+                "coverage attention"
+
+            shard_state.update({
+                "std_attn": std,
+                "coverage_attn": coverage
+            })
+        return shard_state
+
+    def _compute_loss(self, batch, output, target, std_attn=None,
+                      coverage_attn=None):
 
-    def _compute_loss(self, batch, output, target):
         bottled_output = self._bottle(output)
 
         scores = self.generator(bottled_output)
         gtruth = target.view(-1)
 
         loss = self.criterion(scores, gtruth)
+        if self.lambda_coverage != 0.0:
+            coverage_loss = self._compute_coverage_loss(
+                std_attn=std_attn, coverage_attn=coverage_attn)
+            loss += coverage_loss
         stats = self._stats(loss.clone(), scores, gtruth)
 
         return loss, stats
 
+    def _compute_coverage_loss(self, std_attn, coverage_attn):
+        covloss = torch.min(std_attn, coverage_attn).sum(2).view(-1)
+        covloss *= self.lambda_coverage
+        return covloss
+
 
 def filter_shard_state(state, shard_size=None):
     for k, v in state.items():
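
For reference, _compute_coverage_loss appears to implement the coverage penalty
of See et al. (2017), covloss_t = sum_i min(a_i^t, c_i^t), where a is the
current attention distribution and c is the attention mass accumulated over
earlier steps. A standalone sketch on dummy tensors; the
(tgt_len, batch, src_len) shape is an assumption inferred from the .sum(2)
over the source dimension:

    # Sketch only: the new penalty on dummy data.
    import torch

    def compute_coverage_loss(std_attn, coverage_attn, lambda_coverage=1.0):
        # sum_i min(a_i, c_i) per (step, batch) pair, then scaled.
        covloss = torch.min(std_attn, coverage_attn).sum(2).view(-1)
        return covloss * lambda_coverage

    tgt_len, batch, src_len = 4, 2, 5
    std = torch.softmax(torch.randn(tgt_len, batch, src_len), dim=2)
    cov = std.cumsum(0) - std  # coverage: attention mass from earlier steps
    print(compute_coverage_loss(std, cov).shape)  # torch.Size([8])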