Skip to content

Commit

Permalink
Add "soft" argument of "--print-alignment" (facebookresearch#2985)
Browse files Browse the repository at this point in the history
Summary:
If the argument is set to "soft", print probability for each source
token, like this:

A-0        0.365083,0.328207,0.306710 0.442428,0.340282,0.217290
0.378712,0.367315,0.253973 0.321335,0.425601,0.253064

Within each target token's list, the probabilities for the source tokens are separated by commas (,), and the lists for different target tokens are separated by spaces ( ).

This option is modeled after the equivalent option in Marian NMT.

# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙌

Pull Request resolved: facebookresearch#2985

Reviewed By: alexeib

Differential Revision: D25344394

Pulled By: myleott

fbshipit-source-id: 659eb8f7af1ccdafacaaa91ce5ddf5d71cb3e775
  • Loading branch information
Hiroyuki Deguchi authored and facebook-github-bot committed Dec 12, 2020
1 parent 39e722c commit 032a404
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 7 deletions.
9 changes: 6 additions & 3 deletions fairseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
GENERATION_DECODING_FORMAT_CHOICES,
LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES,
PRINT_ALIGNMENT_CHOICES,
ZERO_SHARDING_CHOICES,
)

Expand Down Expand Up @@ -737,10 +738,12 @@ class GenerationConfig(FairseqDataclass):
default=-1.0,
metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
)
print_alignment: bool = field(
default=False,
print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(
default=None,
metadata={
"help": "if set, uses attention feedback to compute and print alignment to source tokens"
"help": "if set, uses attention feedback to compute and print alignment to source tokens "
"(valid options are: hard, soft, otherwise treated as hard alignment)",
"argparse_const": "hard",
},
)
print_step: bool = field(
Expand Down
1 change: 1 addition & 0 deletions fairseq/dataclass/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ def ChoiceEnum(choices: List[str]):
)
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"])
9 changes: 7 additions & 2 deletions fairseq/sequence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ def reorder_incremental_state(


class SequenceGeneratorWithAlignment(SequenceGenerator):
def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs):
def __init__(self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs):
"""Generates translations of a given source sentence.
Produces alignments following "Jointly Learning to Align and
Expand All @@ -917,6 +917,11 @@ def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs):
super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
self.left_pad_target = left_pad_target

if print_alignment == "hard":
self.extract_alignment = utils.extract_hard_alignment
elif print_alignment == "soft":
self.extract_alignment = utils.extract_soft_alignment

@torch.no_grad()
def generate(self, models, sample, **kwargs):
finalized = super()._generate(sample, **kwargs)
Expand Down Expand Up @@ -945,7 +950,7 @@ def generate(self, models, sample, **kwargs):

# Process the attn matrix to extract hard alignments.
for i in range(bsz * beam_size):
alignment = utils.extract_hard_alignment(
alignment = self.extract_alignment(
attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
)
finalized[i // beam_size][i % beam_size]["alignment"] = alignment
Expand Down
4 changes: 3 additions & 1 deletion fairseq/tasks/fairseq_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,12 +376,14 @@ def build_generator(
else:
search_strategy = search.BeamSearch(self.target_dictionary)

extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
if seq_gen_cls is None:
if getattr(args, "print_alignment", False):
seq_gen_cls = SequenceGeneratorWithAlignment
extra_gen_cls_kwargs['print_alignment'] = args.print_alignment
else:
seq_gen_cls = SequenceGenerator
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}

return seq_gen_cls(
models,
self.target_dictionary,
Expand Down
17 changes: 17 additions & 0 deletions fairseq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,23 @@ def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
return alignment


def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Extract a soft alignment: the attention distribution over source tokens.

    For every non-pad target token, return the attention probabilities over
    all non-pad source tokens, formatted as 6-decimal strings.

    Args:
        attn: 2-D attention matrix (target positions x source positions).
        src_sent: 1-D tensor of source token ids.
        tgt_sent: 1-D tensor of target token ids.
        pad: padding token id; padded positions are excluded on both sides.
        eos: accepted for interface parity with ``extract_hard_alignment``
            but unused here — soft alignments keep every non-pad position.

    Returns:
        A list with one entry per non-pad target token; each entry is a list
        of probability strings (one per non-pad source token). Empty when
        either side has no non-pad tokens.
    """
    target_positions = (tgt_sent != pad).nonzero(as_tuple=False)
    source_positions = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1)
    if len(target_positions) == 0 or len(source_positions) == 0:
        return []
    # Broadcasting the (T, 1) row indices against the (S,) column indices
    # selects the (T, S) sub-matrix of attention restricted to real tokens.
    probs = attn[target_positions, source_positions]
    return [
        ["{:.6f}".format(p) for p in row.tolist()]
        for row in probs
    ]


def new_arange(x, *size):
"""
Return a Tensor of `size` filled with a range function on the device of x.
Expand Down
15 changes: 14 additions & 1 deletion fairseq_cli/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def decode_fn(x):
file=output_file,
)

if cfg.generation.print_alignment:
if cfg.generation.print_alignment == "hard":
print(
"A-{}\t{}".format(
sample_id,
Expand All @@ -312,6 +312,19 @@ def decode_fn(x):
),
file=output_file,
)
if cfg.generation.print_alignment == "soft":
print(
"A-{}\t{}".format(
sample_id,
" ".join(
[
",".join(src_probs)
for src_probs in alignment
]
),
),
file=output_file,
)

if cfg.generation.print_step:
print(
Expand Down

0 comments on commit 032a404

Please sign in to comment.