
Commit 7835130

pltrdy authored and vince62s committed
Implementing coverage loss of abisee (2017) (#1464)
* Implementing coverage loss of abisee (2017)
* fix lambda_coverage value
1 parent 093491b commit 7835130
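For context (not part of the commit itself): the coverage loss of See et al. (2017), "Get To The Point: Summarization with Pointer-Generator Networks", penalizes the decoder for re-attending to source positions it has already covered. At each target step t it adds sum_i min(a_i^t, c_i^t), where a^t is the attention distribution over the source and c^t is the coverage vector, the sum of the attention distributions of all previous steps, weighted by lambda_coverage. A minimal PyTorch sketch of the idea, standalone and independent of the OpenNMT-py classes changed below:

import torch

def coverage_loss(attns, lambda_coverage=1.0):
    # attns: list of per-step attention distributions over the source,
    # each of shape (batch, src_len) and summing to 1 along src_len.
    coverage = torch.zeros_like(attns[0])               # c_1 = 0
    loss = attns[0].new_zeros(())
    for a_t in attns:
        loss = loss + torch.min(a_t, coverage).sum()    # penalize overlap with past attention
        coverage = coverage + a_t                       # c_{t+1} = c_t + a_t
    return lambda_coverage * loss

In the paper the coverage weight is 1.0; this commit makes it configurable and, with the follow-up fix to the default, disabled unless requested (lambda_coverage=0.0).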


3 files changed: +60 -17 lines changed


onmt/modules/copy_generator.py

+19 -10
@@ -2,7 +2,7 @@
 import torch.nn as nn
 
 from onmt.utils.misc import aeq
-from onmt.utils.loss import LossComputeBase
+from onmt.utils.loss import NMTLossCompute
 
 
 def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None,
@@ -177,10 +177,12 @@ def forward(self, scores, align, target):
         return loss
 
 
-class CopyGeneratorLossCompute(LossComputeBase):
+class CopyGeneratorLossCompute(NMTLossCompute):
     """Copy Generator Loss Computation."""
-    def __init__(self, criterion, generator, tgt_vocab, normalize_by_length):
-        super(CopyGeneratorLossCompute, self).__init__(criterion, generator)
+    def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
+                 lambda_coverage=0.0):
+        super(CopyGeneratorLossCompute, self).__init__(
+            criterion, generator, lambda_coverage=lambda_coverage)
         self.tgt_vocab = tgt_vocab
         self.normalize_by_length = normalize_by_length
 
@@ -190,14 +192,17 @@ def _make_shard_state(self, batch, output, range_, attns):
             raise AssertionError("using -copy_attn you need to pass in "
                                  "-dynamic_dict during preprocess stage.")
 
-        return {
-            "output": output,
-            "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
+        shard_state = super(CopyGeneratorLossCompute, self)._make_shard_state(
+            batch, output, range_, attns)
+
+        shard_state.update({
             "copy_attn": attns.get("copy"),
             "align": batch.alignment[range_[0] + 1: range_[1]]
-        }
+        })
+        return shard_state
 
-    def _compute_loss(self, batch, output, target, copy_attn, align):
+    def _compute_loss(self, batch, output, target, copy_attn, align,
+                      std_attn=None, coverage_attn=None):
         """Compute the loss.
 
         The args must match :func:`self._make_shard_state()`.
@@ -209,14 +214,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align):
             copy_attn: the copy attention value.
             align: the align info.
         """
-
         target = target.view(-1)
         align = align.view(-1)
         scores = self.generator(
             self._bottle(output), self._bottle(copy_attn), batch.src_map
         )
         loss = self.criterion(scores, align, target)
 
+        if self.lambda_coverage != 0.0:
+            coverage_loss = self._compute_coverage_loss(std_attn,
+                                                        coverage_attn)
+            loss += coverage_loss
+
         # this block does not depend on the loss value computed above
         # and is used only for stats
         scores_data = collapse_copy_scores(
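A note on the new std_attn=None / coverage_attn=None defaults (my reading of the surrounding sharding code, not something stated in the diff): the dict built by _make_shard_state is forwarded to _compute_loss as keyword arguments, and the coverage entries are only added to it when lambda_coverage != 0.0, so the extra parameters have to be optional. A tiny standalone illustration of that calling convention, using a hypothetical stand-in rather than the real method:

# Hypothetical stand-in; only the keyword-argument convention matters here.
def _compute_loss(batch, output, target, copy_attn, align,
                  std_attn=None, coverage_attn=None):
    return coverage_attn is not None      # "is the coverage penalty active?"

shard = {"output": 0, "target": 0, "copy_attn": 0, "align": 0}
print(_compute_loss(batch=None, **shard))     # False: no coverage keys in the shard
shard.update({"std_attn": 0, "coverage_attn": 0})
print(_compute_loss(batch=None, **shard))     # True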

onmt/opts.py

+2 -2
@@ -176,8 +176,8 @@ def model_opts(parser):
               help="Divide copy loss by length of sequence")
     group.add('--coverage_attn', '-coverage_attn', action="store_true",
               help='Train a coverage attention layer.')
-    group.add('--lambda_coverage', '-lambda_coverage', type=float, default=1,
-              help='Lambda value for coverage.')
+    group.add('--lambda_coverage', '-lambda_coverage', type=float, default=0.0,
+              help='Lambda value for coverage loss of See et al (2017)')
     group.add('--loss_scale', '-loss_scale', type=float, default=0,
               help="For FP16 training, the static loss scale to use. If not "
                    "set, the loss scale is dynamically computed.")

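Usage note (not part of the diff): with the new default of 0.0 the coverage penalty is off unless requested explicitly, and the loss.py change below asserts that -coverage_attn is set whenever -lambda_coverage is non-zero. A hedged example invocation, assuming the usual train.py entry point and placeholder -data / -save_model paths:

python train.py -data data/demo -save_model demo-model \
    -coverage_attn -lambda_coverage 1.0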
onmt/utils/loss.py

+39 -5
@@ -25,6 +25,11 @@ def build_loss_compute(model, tgt_field, opt, train=True):
 
     padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token]
     unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token]
+
+    if opt.lambda_coverage != 0:
+        assert opt.coverage_attn, "--coverage_attn needs to be set in " \
+            "order to use --lambda_coverage != 0"
+
     if opt.copy_attn:
         criterion = onmt.modules.CopyGeneratorLoss(
             len(tgt_field.vocab), opt.copy_attn_force,
@@ -47,10 +52,12 @@
     loss_gen = model.generator[0] if use_raw_logits else model.generator
     if opt.copy_attn:
         compute = onmt.modules.CopyGeneratorLossCompute(
-            criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength
+            criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength,
+            lambda_coverage=opt.lambda_coverage
         )
     else:
-        compute = NMTLossCompute(criterion, loss_gen)
+        compute = NMTLossCompute(
+            criterion, loss_gen, lambda_coverage=opt.lambda_coverage)
     compute.to(device)
 
     return compute
@@ -218,26 +225,53 @@ class NMTLossCompute(LossComputeBase):
     Standard NMT Loss Computation.
     """
 
-    def __init__(self, criterion, generator, normalization="sents"):
+    def __init__(self, criterion, generator, normalization="sents",
+                 lambda_coverage=0.0):
         super(NMTLossCompute, self).__init__(criterion, generator)
+        self.lambda_coverage = lambda_coverage
 
     def _make_shard_state(self, batch, output, range_, attns=None):
-        return {
+        shard_state = {
             "output": output,
             "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
         }
+        if self.lambda_coverage != 0.0:
+            coverage = attns.get("coverage", None)
+            std = attns.get("std", None)
+            assert attns is not None
+            assert std is not None, "lambda_coverage != 0.0 requires " \
+                "attention mechanism"
+            assert coverage is not None, "lambda_coverage != 0.0 requires " \
+                "coverage attention"
+
+            shard_state.update({
+                "std_attn": attns.get("std"),
+                "coverage_attn": coverage
+            })
+        return shard_state
+
+    def _compute_loss(self, batch, output, target, std_attn=None,
+                      coverage_attn=None):
 
-    def _compute_loss(self, batch, output, target):
         bottled_output = self._bottle(output)
 
         scores = self.generator(bottled_output)
         gtruth = target.view(-1)
 
         loss = self.criterion(scores, gtruth)
+        if self.lambda_coverage != 0.0:
+            coverage_loss = self._compute_coverage_loss(
+                std_attn=std_attn, coverage_attn=coverage_attn)
+            loss += coverage_loss
         stats = self._stats(loss.clone(), scores, gtruth)
 
         return loss, stats
 
+    def _compute_coverage_loss(self, std_attn, coverage_attn):
+        covloss = torch.min(std_attn, coverage_attn).sum(2).view(-1)
+        covloss *= self.lambda_coverage
+        return covloss
+
 
 def filter_shard_state(state, shard_size=None):
     for k, v in state.items():
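A shape note on _compute_coverage_loss (an assumption about the attention tensors, not spelled out in the diff): std_attn and coverage_attn are taken here to be (tgt_len, batch, src_len), so .sum(2) sums the min-overlap over source positions and .view(-1) leaves one penalty value per (target step, sentence) rather than a single scalar. A standalone sketch under that assumption:

import torch

tgt_len, batch, src_len = 4, 3, 6
lambda_coverage = 1.0

# Assumed shapes: (tgt_len, batch, src_len); attention rows sum to 1 over src_len.
std_attn = torch.softmax(torch.randn(tgt_len, batch, src_len), dim=-1)
coverage_attn = torch.rand(tgt_len, batch, src_len)

# Same computation as NMTLossCompute._compute_coverage_loss above.
covloss = torch.min(std_attn, coverage_attn).sum(2).view(-1)
covloss *= lambda_coverage
print(covloss.shape)    # torch.Size([12]), i.e. tgt_len * batch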
