diff --git a/CHANGELOG.md b/CHANGELOG.md index a3e112741..f578e50d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [1.18.15] +### Added +- Added decoding with target-side lexical constraints (documentation in `tutorials/constraints`). + ## [1.18.14] ### Added - Introduced Sockeye Autopilot for single-command end-to-end system building. diff --git a/docs/modules.rst b/docs/modules.rst index 2355af7df..ec193bc34 100644 --- a/docs/modules.rst +++ b/docs/modules.rst @@ -106,6 +106,13 @@ sockeye.layers module :members: :show-inheritance: +sockeye.lexical_constraints module +---------------------------------- + +.. automodule:: sockeye.lexical_constraints + :members: + :show-inheritance: + sockeye.lexicon module ---------------------- diff --git a/sockeye/__init__.py b/sockeye/__init__.py index c5d313de6..5cb878d74 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '1.18.14' +__version__ = '1.18.15' diff --git a/sockeye/constants.py b/sockeye/constants.py index f6aba22fd..20de9cd50 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -192,8 +192,12 @@ # Inference Input JSON constants JSON_TEXT_KEY = "text" JSON_FACTORS_KEY = "factors" +JSON_CONSTRAINTS_KEY = "constraints" JSON_ENCODING = "utf-8" +# Lexical constraints +BANK_ADJUSTMENT = 'even' + VERSION_NAME = "version" CONFIG_NAME = "config" LOG_NAME = "log" diff --git a/sockeye/inference.py b/sockeye/inference.py index c184a061f..197a4d7cd 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -14,11 +14,15 @@ """ Code for inference/translation """ +import copy import itertools import json import logging +import math import os +import sys import time + from collections import defaultdict from functools import lru_cache, partial from typing import Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union, Set @@ -34,6 +38,8 @@ from . import vocab from .log import is_python34 +from . import lexical_constraints as constrained + logger = logging.getLogger(__name__) @@ -292,6 +298,7 @@ def run_encoder(self, self.encoder_module.forward(data_batch=batch, is_train=False) decoder_states = self.encoder_module.get_outputs() + # replicate encoder/init module results beam size times decoder_states = [mx.nd.repeat(s, repeats=self.beam_size, axis=0) for s in decoder_states] return ModelState(decoder_states) @@ -558,23 +565,26 @@ class TranslatorInput: :param sentence_id: Sentence id. :param tokens: List of input tokens. :param factors: Optional list of additional factor sequences. + :param constraints: Optional list of target-side constraints. :param chunk_id: Chunk id. Defaults to -1. 
""" - __slots__ = ('sentence_id', 'tokens', 'factors', 'chunk_id') + __slots__ = ('sentence_id', 'tokens', 'factors', 'constraints', 'chunk_id') def __init__(self, sentence_id: int, tokens: Tokens, factors: Optional[List[Tokens]] = None, + constraints: Optional[List[Tokens]] = None, chunk_id: int = -1) -> None: self.sentence_id = sentence_id self.chunk_id = chunk_id self.tokens = tokens self.factors = factors + self.constraints = constraints def __str__(self): - return 'TranslatorInput(%d, %s, %s, %d)' % (self.sentence_id, self.tokens, self.factors, self.chunk_id) + return 'TranslatorInput(%d, %s, factors=%s, constraints=%s, chunk_id=%d)' % (self.sentence_id, self.tokens, self.factors, self.constraints, self.chunk_id) def __len__(self): return len(self.tokens) @@ -593,22 +603,34 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]: :param chunk_size: The maximum size of a chunk. :return: A generator of TranslatorInputs, one for each chunk created. """ + + if len(self.tokens) > chunk_size and self.constraints is not None: + logger.warning( + 'Input %d has length (%d) that exceeds max input length (%d), ' + 'triggering internal splitting. Placing all target-side constraints ' + 'with the first chunk, which is probably wrong.', + self.sentence_id, len(self.tokens), chunk_size) + for chunk_id, i in enumerate(range(0, len(self), chunk_size)): factors = [factor[i:i + chunk_size] for factor in self.factors] if self.factors is not None else None + # Constrained decoding is not supported for chunked TranslatorInputs. As a fall-back, constraints are + # assigned to the first chunk + constraints = self.constraints if chunk_id == 0 else None yield TranslatorInput(sentence_id=self.sentence_id, tokens=self.tokens[i:i + chunk_size], factors=factors, + constraints=constraints, chunk_id=chunk_id) def with_eos(self) -> 'TranslatorInput': """ :return: A new translator input with EOS appended to the tokens and factors. 
""" - return TranslatorInput(self.sentence_id, - self.tokens + [C.EOS_SYMBOL], - [factor + [C.EOS_SYMBOL] for factor in self.factors] - if self.factors is not None else None, - self.chunk_id) + return TranslatorInput(sentence_id=self.sentence_id, + tokens=self.tokens + [C.EOS_SYMBOL], + factors=[factor + [C.EOS_SYMBOL] for factor in self.factors] if self.factors is not None else None, + constraints=self.constraints, + chunk_id=self.chunk_id) class BadTranslatorInput(TranslatorInput): @@ -656,7 +678,12 @@ def make_input_from_json_string(sentence_id: int, json_string: str) -> Translato return _bad_input(sentence_id, reason=json_string) else: factors = None - return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors) + constraints = jobj.get(C.JSON_CONSTRAINTS_KEY) + if isinstance(constraints, list) and len(constraints) > 0: + constraints = [list(data_io.get_tokens(constraint)) for constraint in constraints] + else: + constraints = None + return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors, constraints=constraints) except Exception as e: logger.exception(e, exc_info=True) if not is_python34() else logger.error(e) # type: ignore @@ -1025,7 +1052,7 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu elif len(trans_input.tokens) == 0: translated_chunks.append(TranslatedChunk(id=input_idx, chunk_id=0, translation=empty_translation())) else: - # TODO(tdomhan): Remove branch without EOS with next major version bump, as future models with always be trained with source side EOS symbols + # TODO(tdomhan): Remove branch without EOS with next major version bump, as future models will always be trained with source side EOS symbols if self.source_with_eos: max_input_length_without_eos = self.max_input_length - C.SPACE_FOR_XOS # oversized input @@ -1056,6 +1083,12 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu else: input_chunks.append(trans_input) + if trans_input.constraints is not None: + logger.info("Input %d has %d %s: %s", trans_input.sentence_id, + len(trans_input.constraints), + "constraint" if len(trans_input.constraints) == 1 else "constraints", + ", ".join(" ".join(x) for x in trans_input.constraints)) + # Sort longest to shortest (to rather fill batches of shorter than longer sequences) input_chunks = sorted(input_chunks, key=lambda chunk: len(chunk.tokens), reverse=True) @@ -1091,17 +1124,20 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu return results - def _get_inference_input(self, trans_inputs: List[TranslatorInput]) -> Tuple[mx.nd.NDArray, int]: + def _get_inference_input(self, trans_inputs: List[TranslatorInput]) -> Tuple[mx.nd.NDArray, int, List[Optional[constrained.RawConstraintList]]]: """ - Returns NDArray of source ids (shape=(batch_size, bucket_key, num_factors)) and corresponding bucket_key. - Also checks correctness of translator inputs. + Assembles the numerical data for the batch. + This comprises an NDArray for the source sentences, the bucket key (padded source length), and a list of + raw constraint lists, one for each sentence in the batch. Each raw constraint list contains phrases in + the form of lists of integers in the target language vocabulary. :param trans_inputs: List of TranslatorInputs. - :return NDArray of source ids and bucket key. + :return NDArray of source ids (shape=(batch_size, bucket_key, num_factors)), bucket key, a list of raw constraint lists. 
""" - bucket_key = data_io.get_bucket(max(len(inp.tokens) for inp in trans_inputs), self.buckets_source) + bucket_key = data_io.get_bucket(max(len(inp.tokens) for inp in trans_inputs), self.buckets_source) source = mx.nd.zeros((len(trans_inputs), bucket_key, self.num_source_factors), ctx=self.context) + raw_constraints = [None for x in range(self.batch_size)] # type: List[Optional[constrained.RawConstraintList]] for j, trans_input in enumerate(trans_inputs): num_tokens = len(trans_input) @@ -1114,9 +1150,13 @@ def _get_inference_input(self, trans_inputs: List[TranslatorInput]) -> Tuple[mx. self.num_source_factors) for i, factor in enumerate(factors[:self.num_source_factors - 1], start=1): # fill in as many factors as there are tokens + source[j, :num_tokens, i] = data_io.tokens2ids(factor, self.source_vocabs[i])[:num_tokens] - return source, bucket_key + if trans_input.constraints is not None: + raw_constraints[j] = [data_io.tokens2ids(phrase, self.vocab_target) for phrase in trans_input.constraints] + + return source, bucket_key, raw_constraints def _make_result(self, trans_input: TranslatorInput, @@ -1162,16 +1202,19 @@ def _concat_translations(self, translations: List[Translation]) -> Translation: def _translate_nd(self, source: mx.nd.NDArray, - source_length: int) -> List[Translation]: + source_length: int, + raw_constraints: List[Optional[constrained.RawConstraintList]]) -> List[Translation]: """ Translates source of source_length, given a bucket_key. :param source: Source ids. Shape: (batch_size, bucket_key, num_factors). :param source_length: Bucket key. + :param raw_constraints: A list of optional constraint lists. :return: Sequence of translations. """ - return self._get_best_from_beam(*self._beam_search(source, source_length)) + + return self._get_best_from_beam(*self._beam_search(source, source_length, raw_constraints)) def _encode(self, sources: mx.nd.NDArray, source_length: int) -> List[ModelState]: """ @@ -1279,16 +1322,24 @@ def _prune(self, def _beam_search(self, source: mx.nd.NDArray, - source_length: int) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, - mx.nd.NDArray, mx.nd.NDArray, Optional[List[BeamHistory]]]: + source_length: int, + raw_constraint_list: List[Optional[constrained.RawConstraintList]]) -> Tuple[mx.nd.NDArray, + mx.nd.NDArray, + mx.nd.NDArray, + mx.nd.NDArray, + mx.nd.NDArray, + List[Optional[constrained.ConstrainedHypothesis]], + Optional[List[BeamHistory]]]: """ Translates multiple sentences using beam search. :param source: Source ids. Shape: (batch_size, bucket_key). :param source_length: Max source length. + :param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs) that must appear in each output. :return List of lists of word ids, list of attentions, array of accumulated length-normalized negative log-probs. """ + # Length of encoded sequence (may differ from initial input length) encoded_source_length = self.models[0].encoder.get_encoded_seq_len(source_length) utils.check_condition(all(encoded_source_length == @@ -1338,8 +1389,15 @@ def _beam_search(self, # TODO: See note in method about migrating to pure MXNet when set operations are supported. # We currently convert source to NumPy and target ids back to NDArray. 
source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0] - vocab_slice_ids = mx.nd.array(self.restrict_lexicon.get_trg_ids(source_words.astype("int32").asnumpy()), - ctx=self.context) + vocab_slice_ids = self.restrict_lexicon.get_trg_ids(source_words.astype("int32").asnumpy()) + if any(raw_constraint_list): + # Add the constraint IDs to the list of permissibled IDs, and then project them into the reduced space + constraint_ids = np.array([word_id for sent in raw_constraint_list for phr in sent for word_id in phr]) + vocab_slice_ids = np.lib.arraysetops.union1d(vocab_slice_ids, constraint_ids) + full_to_reduced = dict((val, i) for i, val in enumerate(vocab_slice_ids)) + raw_constraint_list = [[[full_to_reduced[x] for x in phr] for phr in sent] for sent in raw_constraint_list] + + vocab_slice_ids = mx.nd.array(vocab_slice_ids, ctx=self.context) if vocab_slice_ids.shape[0] < self.beam_size + 1: # This fixes an edge case for toy models, where the number of vocab ids from the lexicon is @@ -1360,6 +1418,9 @@ def _beam_search(self, # (0) encode source sentence, returns a list model_states = self._encode(source, source_length) + # Initialize the beam to track constraint sets, where target-side lexical constraints are present + constraints = constrained.init_batch(raw_constraint_list, self.beam_size, self.start_id, self.vocab_target[C.EOS_SYMBOL]) + # Records items in the beam that are inactive. At the beginning (t==1), there is only one valid or active # item on the beam for each sentence inactive = mx.nd.ones((self.batch_size * self.beam_size), dtype='int32', ctx=self.context) @@ -1386,6 +1447,23 @@ def _beam_search(self, # far as the active beam size for each sentence. best_hyp_indices[:], best_word_indices[:], scores_accumulated[:, 0] = self.topk(scores) + # Constraints for constrained decoding are processed sentence by sentence + if any(raw_constraint_list): + best_hyp_indices, best_word_indices, scores_accumulated, \ + constraints, inactive = constrained.topk(self.batch_size, + self.beam_size, + inactive, + scores, + constraints, + best_hyp_indices, + best_word_indices, + scores_accumulated, + self.context) + + else: + # All rows are now active (after special treatment of start state at t=1) + inactive[:] = 0 + # Map from restricted to full vocab ids if needed if self.restrict_lexicon: best_word_indices[:] = vocab_slice_ids.take(best_word_indices) @@ -1398,8 +1476,6 @@ def _beam_search(self, newly_finished = all_finished - finished scores_accumulated = mx.nd.where(newly_finished, scores_accumulated / self.length_penalty(lengths), scores_accumulated) finished = all_finished - # All rows are now active (after special treatment of start state at t=1) - inactive[:] = 0 # (5) Prune out low-probability hypotheses. Pruning works by setting entries `inactive`. if self.beam_prune > 0.0: @@ -1451,6 +1527,8 @@ def _beam_search(self, for ms in model_states: ms.sort_state(best_hyp_indices) + logger.debug("Finished after %d / %d steps.", t + 1, max_output_length) + # (9) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them). 
folded_accumulated_scores = scores_accumulated.reshape((self.batch_size, self.beam_size * scores_accumulated.shape[-1])) indices = mx.nd.argsort(folded_accumulated_scores, axis=1) @@ -1461,49 +1539,86 @@ def _beam_search(self, attentions = mx.nd.take(attentions, best_hyp_indices) scores_accumulated[:] = mx.nd.take(scores_accumulated, best_hyp_indices) finished = mx.nd.take(finished, best_hyp_indices) + constraints = [constraints[int(x.asscalar())] for x in best_hyp_indices] - return sequences, attentions, scores_accumulated, lengths, beam_histories + return sequences, attentions, scores_accumulated, lengths, finished, constraints, beam_histories def _get_best_from_beam(self, sequences: mx.nd.NDArray, attention_lists: mx.nd.NDArray, - accumulated_scores: mx.nd.NDArray, + seq_scores: mx.nd.NDArray, lengths: mx.nd.NDArray, - beam_histories: Optional[List[BeamHistory]]) -> List[Translation]: + finished: mx.nd.NDArray, + constraints: List[Optional[constrained.ConstrainedHypothesis]], + beam_histories: Optional[List[BeamHistory]] = None) -> List[Translation]: """ Return the best (aka top) entry from the n-best list. - :param sequences: Array of word ids. Shape: (batch_size * beam_size, bucket_key). + :param sequences: Array of word ids. Shape: (batch * beam, bucket_key). :param attention_lists: Array of attentions over source words. - Shape: (batch_size * self.beam_size, max_output_length, encoded_source_length). - :param accumulated_scores: Array of length-normalized negative log-probs. - :return: Top sequence, top attention matrix, top accumulated score (length-normalized - negative log-probs) and length. + Shape: (batch * beam, max_output_length, encoded_source_length). + :param seq_scores: Array of length-normalized negative log-probs.. + Shape: (batch * beam, 1) + :param lengths: The lengths of all items in the beam. Shape: (batch * beam). + :param finished: Marks completed items in the beam. Shape: (batch * beam). + :param constraints: The constraints for all items in the beam. Shape: (batch * beam). + :param beam_histories: The beam histories for each sentence in the batch. + :return: List of Translation objects containing all relevant information. 
""" utils.check_condition(sequences.shape[0] == attention_lists.shape[0] \ - == accumulated_scores.shape[0] == lengths.shape[0], "Shape mismatch") - # sequences & accumulated scores are in latest 'k-best order', thus 0th element is best - best = 0 - result = [] - for sent in range(self.batch_size): - idx = sent * self.beam_size + best - length = int(lengths[idx].asscalar()) - sequence = sequences[idx][:length].asnumpy().tolist() - # attention_matrix: (target_seq_len, source_seq_len) - attention_matrix = np.stack(attention_lists[idx].asnumpy()[:length, :], axis=0) - score = accumulated_scores[idx].asscalar() - if beam_histories is not None: - history = beam_histories[sent] - result.append(Translation(sequence, attention_matrix, score, [history])) - else: - result.append(Translation(sequence, attention_matrix, score)) - return result + == seq_scores.shape[0] == lengths.shape[0], "Shape mismatch") + + # Initialize the best_ids to the first item in each batch + best_ids = mx.nd.arange(0, self.batch_size * self.beam_size, self.beam_size, ctx=self.context) + + if any(constraints): + # For constrained decoding, select from items that have met all constraints (might not be finished) + unmet = mx.nd.array([c.num_needed() if c is not None else 0 for c in constraints], ctx=self.context) + filtered = mx.nd.where(unmet == 0, seq_scores[:, 0], self.inf_array_long) + filtered = filtered.reshape((self.batch_size, self.beam_size)) + best_ids += mx.nd.argmin(filtered, axis=1) + + histories = beam_histories if beam_histories is not None else [None] * self.batch_size + return [self._assemble_translation(*x) for x in zip(range(self.batch_size), + sequences[best_ids], + lengths[best_ids], + attention_lists[best_ids], + seq_scores[best_ids], + histories)] + + def _assemble_translation(self, + sentno: int, + sequence: mx.nd.NDArray, + length: mx.nd.NDArray, + attention_lists: mx.nd.NDArray, + seq_score: mx.nd.NDArray, + beam_history: List[Optional[BeamHistory]]) -> Translation: + """ + Takes a set of data pertaining to a single translated item, performs slightly different + processing on each, and merges it into a Translation object. + + :param sentno: The sentence number in the batch. + :param sequence: Array of word ids. Shape: (batch_size, bucket_key). + :param length: The length of the translated segment. + :param attention_lists: Array of attentions over source words. + Shape: (batch_size * self.beam_size, max_output_length, encoded_source_length). + :param seq_scores: Array of length-normalized negative log-probs. + :param beam_histories: The beam histories for each sentence in the batch. + :return: A Translation object. + """ + + length = int(length.asscalar()) + sequence = sequence[:length].asnumpy().tolist() + attention_matrix = np.stack(attention_lists.asnumpy()[:length, :], axis=0) + score = seq_score.asscalar() + return Translation(sequence, attention_matrix, score, beam_history) def _print_beam(self, sequences: mx.nd.NDArray, accumulated_scores: mx.nd.NDArray, finished: mx.nd.NDArray, inactive: mx.nd.NDArray, + constraints: List[Optional[constrained.ConstrainedHypothesis]], timestep: int) -> None: """ Prints the beam for debugging purposes. 
@@ -1519,5 +1634,6 @@ def _print_beam(self,
             # for each hypothesis, print its entire history
             score = accumulated_scores[i].asscalar()
             word_ids = [int(x.asscalar()) for x in sequences[i]]
+            unmet = constraints[i].num_needed() if constraints[i] is not None else -1
             hypothesis = '----------' if inactive[i] else ' '.join([self.vocab_target_inv[x] for x in word_ids if x != 0])
-            logger.info('%d %d %d %.2f %s', i, finished[i].asscalar(), inactive[i].asscalar(), score, hypothesis)
+            logger.info('%d %d %d %d %.2f %s', i+1, finished[i].asscalar(), inactive[i].asscalar(), unmet, score, hypothesis)
diff --git a/sockeye/lexical_constraints.py b/sockeye/lexical_constraints.py
new file mode 100644
index 000000000..13c3642b8
--- /dev/null
+++ b/sockeye/lexical_constraints.py
@@ -0,0 +1,474 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not
+# use this file except in compliance with the License. A copy of the License
+# is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import copy
+import logging
+import re
+import time
+
+from typing import Dict, List, Optional, Tuple, Set
+from operator import attrgetter
+
+from . import constants as C
+from . import utils
+
+import mxnet as mx
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+# Represents a list of raw constraints for a sentence. Each constraint is a list of target-word IDs.
+RawConstraintList = List[List[int]]
+
+class ConstrainedHypothesis:
+    """
+    Represents a set of words and phrases that must appear in the output.
+    A constraint is of two types: sequence or non-sequence.
+    A non-sequence constraint is a single word and can therefore be followed by anything,
+    whereas a sequence constraint must be followed by a particular word (the next word in the sequence).
+    This class also records which constraints have been met.
+
+    A list of raw constraints is maintained internally as two parallel arrays. The following raw constraint
+    represents two phrases that must appear in the output: 14 and 19 35 14.
+
+        raw constraint: [[14], [19, 35, 14]]
+
+    This is represented internally as:
+
+        constraints: [14 19 35 14]
+        is_sequence: [ 0  1  1  0]
+
+    :param constraint_list: A list of zero or more raw constraints (each represented as a list of integers).
+    :param eos_id: The end-of-sentence ID.
+    """
+    def __init__(self,
+                 constraint_list: RawConstraintList,
+                 eos_id: int) -> None:
+
+        # `constraints` records the words of the constraints, as a list (duplicates allowed).
+        # `is_sequence` is a parallel array that records, for each corresponding constraint,
+        # whether the current word is the non-final word of a phrasal constraint.
+        self.constraints = []  # type: List[int]
+        self.is_sequence = []  # type: List[bool]
+        for phrase in constraint_list:
+            self.constraints += phrase
+            self.is_sequence += [True] * len(phrase)
+            self.is_sequence[-1] = False
+
+        self.eos_id = eos_id
+
+        # no constraints have been met
+        self.met = [False for x in self.constraints]
+        self.last_met = -1
+
+    def __len__(self) -> int:
+        """
+        :return: The number of constraints.
+        """
+        return len(self.constraints)
+
+    def __str__(self) -> str:
+        s = []
+        for i, word_id in enumerate(self.constraints):
+            s.append(str(word_id) if self.met[i] is False else 'X')
+            if self.is_sequence[i]:
+                s.append('->')
+        return ' '.join(s)
+
+    def size(self) -> int:
+        """
+        :return: the number of constraints
+        """
+        return len(self.constraints)
+
+    def num_met(self) -> int:
+        """
+        :return: the number of constraints that have been met.
+        """
+        return sum(self.met)
+
+    def num_needed(self) -> int:
+        """
+        :return: the number of un-met constraints.
+        """
+        return self.size() - self.num_met()
+
+    def allowed(self) -> Set[int]:
+        """
+        Returns the set of constrained words that could follow this one.
+        For unfinished phrasal constraints, it is the next word in the phrase.
+        In other cases, it is the list of all unmet constraints.
+        If all constraints are met, an empty set is returned.
+
+        :return: The set of word IDs that would advance an unmet constraint (empty if all constraints are met).
+        """
+        items = set()  # type: Set[int]
+        # Add extensions of a started-but-incomplete sequential constraint
+        if self.last_met != -1 and self.is_sequence[self.last_met] == 1:
+            word_id = self.constraints[self.last_met + 1]
+            if word_id != self.eos_id or self.num_needed() == 1:
+                items.add(word_id)
+
+        # Add all constraints that aren't non-initial sequences
+        else:
+            for i, word_id in enumerate(self.constraints):
+                if not self.met[i] and (i == 0 or not self.is_sequence[i - 1]):
+                    if word_id != self.eos_id or self.num_needed() == 1:
+                        items.add(word_id)
+
+        return items
+
+    def finished(self) -> bool:
+        """
+        Return true if all the constraints have been met.
+
+        :return: True if all the constraints are met.
+        """
+        return self.num_needed() == 0
+
+    def is_valid(self, wordid) -> bool:
+        """
+        Ensures the EOS symbol is only generated when the hypothesis is completed.
+
+        :param wordid: The word ID to validate.
+        :return: True if all constraints are already met, the word ID is not the EOS ID, or the EOS ID is the only remaining unmet constraint.
+        """
+        return self.finished() or wordid != self.eos_id or (self.num_needed() == 1 and self.eos_id in self.allowed())
+
+    def advance(self, word_id: int) -> 'ConstrainedHypothesis':
+        """
+        Updates the constraints object based on advancing on word_id.
+        There is a complication, in that we may have started but not
+        yet completed a multi-word constraint. We need to allow constraints
+        to be added as unconstrained words, so if the next word is
+        invalid, we must "back out" of the current (incomplete) phrase,
+        re-setting all of its words as unmet.
+
+        :param word_id: The word ID to advance on.
+        :return: A deep copy of the object, advanced on word_id.
+        """
+
+        obj = copy.deepcopy(self)
+
+        # First, check if we're updating a sequential constraint.
+        if obj.last_met != -1 and obj.is_sequence[obj.last_met] == 1:
+            if word_id == obj.constraints[obj.last_met + 1]:
+                # Here, the word matches what we expect next in the constraint, so we update everything
+                obj.met[obj.last_met + 1] = True
+                obj.last_met += 1
+            else:
+                # Here, the word is not the expected next word of the constraint, so we back out of the constraint.
+ index = obj.last_met + while obj.is_sequence[index]: + obj.met[index] = False + index -= 1 + obj.last_met = -1 + + # If not, check whether we're meeting a single-word constraint + else: + # Build a list from all constraints of tuples of the + # form (constraint, whether it's a non-initial sequential, whether it's been met) + constraint_tuples = list(zip(obj.constraints, [False] + obj.is_sequence[:-1], obj.met)) + # We are searching for an unmet constraint (word_id) that is not the middle of a phrase and is not met + query = (word_id, False, False) + try: + pos = constraint_tuples.index(query) + obj.met[pos] = True + obj.last_met = pos + except ValueError: + # query not found; identical but duplicated object will be returned + pass + + return obj + + +def init_batch(raw_constraints: List[Optional[RawConstraintList]], + beam_size: int, + start_id: int, + eos_id: int) -> List[Optional[ConstrainedHypothesis]]: + """ + :param raw_constraints: The list of raw constraints (list of list of IDs). + :param beam_size: The beam size. + :param start_id: The target-language vocabulary ID of the SOS symbol. + :param eos_id: The target-language vocabulary ID of the EOS symbol. + :return: A list of ConstrainedHypothesis objects (shape: (batch_size * beam_size,)). + """ + constraints = [None] * (len(raw_constraints) * beam_size) # type: List[Optional[ConstrainedHypothesis]] + if any(raw_constraints): + for i, raw_list in enumerate(raw_constraints): + num_constraints = sum([len(phrase) for phrase in raw_list]) if raw_list is not None else 0 + if num_constraints > 0: + hyp = ConstrainedHypothesis(raw_list, eos_id) + idx = i * beam_size + constraints[idx:idx+beam_size] = [hyp.advance(start_id) for x in range(beam_size)] + + return constraints + + +def get_bank_sizes(num_constraints: int, + beam_size: int, + candidate_counts: List[int]) -> List[int]: + """ + Evenly distributes the beam across the banks, where each bank is a portion of the beam devoted + to hypotheses having met the same number of constraints, 0..num_constraints. + After the assignment, banks with more slots than candidates are adjusted. + + :param num_constraints: The number of constraints. + :param beam_size: The beam size. + :param candidate_counts: The empirical counts of number of candidates in each bank. + :return: A distribution over banks. + """ + + num_banks = num_constraints + 1 + bank_size = beam_size // num_banks + remainder = beam_size - bank_size * num_banks + + # Distribute any remainder to the end + assigned = [bank_size for x in range(num_banks)] + assigned[-1] += remainder + + # Now, moving right to left, push extra allocation to earlier buckets. + # This encodes a bias for higher buckets, but if no candidates are found, space + # will be made in lower buckets. This may not be the best strategy, but it is important + # that you start pushing from the bucket that is assigned the remainder, for cases where + # num_constraints >= beam_size. + for i in reversed(range(num_banks)): + overfill = assigned[i] - candidate_counts[i] + if overfill > 0: + assigned[i] -= overfill + assigned[(i - 1) % num_banks] += overfill + + return assigned + + +class ConstrainedCandidate: + """ + Object used to hold candidates for the beam in topk(). + + :param row: The row in the scores matrix. + :param col: The column (word ID) in the scores matrix. + :param score: the associated accumulated score. + :param hypothesis: The ConstrainedHypothesis containing information about met constraints. 
+ """ + + __slots__ = ('row', 'col', 'score', 'hypothesis') + + def __init__(self, + row: int, + col: int, + score: float, + hypothesis: ConstrainedHypothesis) -> None: + self.row = row + self.col = col + self.score = score + self.hypothesis = hypothesis + + def __hash__(self): + return hash((self.row, self.col)) + + def __eq__(self, other): + return self.row == other.row and self.col == other.col + + def __str__(self): + return '({}, {}, {}, {})'.format(self.row, self.col, self.score, self.hypothesis.num_met()) + + +def topk(batch_size: int, + beam_size: int, + inactive: mx.ndarray, + scores: mx.ndarray, + hypotheses: List[ConstrainedHypothesis], + best_ids: mx.ndarray, + best_word_ids: mx.ndarray, + seq_scores: mx.ndarray, + context: mx.context.Context) -> Tuple[np.array, np.array, np.array, List[ConstrainedHypothesis], mx.nd.array]: + """ + Builds a new topk list such that the beam contains hypotheses having completed different numbers of constraints. + These items are built from three different types: (1) the best items across the whole + scores matrix, (2) the set of words that must follow existing constraints, and (3) k-best items from each row. + + :param batch_size: The number of segments in the batch. + :param beam_size: The length of the beam for each segment. + :param inactive: Array listing inactive rows (shape: (beam_size,)). + :param scores: The scores array (shape: (beam_size, target_vocab_size)). + :param hypotheses: The list of hypothesis objects. + :param best_ids: The current list of best hypotheses (shape: (beam_size,)). + :param best_word_ids: The parallel list of best word IDs (shape: (beam_size,)). + :param seq_scores: (shape: (beam_size, 1)). + :param context: The MXNet device context. + :return: A tuple containing the best hypothesis rows, the best hypothesis words, the scores, + the updated constrained hypotheses, and the updated set of inactive hypotheses. + """ + + for sentno in range(batch_size): + rows = slice(sentno * beam_size, (sentno + 1) * beam_size) + if hypotheses[rows.start] is not None and hypotheses[rows.start].size() > 0: + best_ids[rows], best_word_ids[rows], seq_scores[rows], \ + hypotheses[rows], inactive[rows] = _topk(beam_size, + inactive[rows], + scores[rows], + hypotheses[rows], + best_ids[rows] - rows.start, + best_word_ids[rows], + seq_scores[rows], + context) + + # offsetting since the returned smallest_k() indices were slice-relative + best_ids[rows] += rows.start + else: + # If there are no constraints for this sentence in the batch, everything stays + # the same, except we need to mark all hypotheses as active + inactive[rows] = 0 + + return (best_ids, best_word_ids, seq_scores, hypotheses, inactive) + +def _topk(beam_size: int, + inactive: mx.ndarray, + scores: mx.ndarray, + hypotheses: List[ConstrainedHypothesis], + best_ids: mx.ndarray, + best_word_ids: mx.ndarray, + sequence_scores: mx.ndarray, + context: mx.context.Context) -> Tuple[np.array, np.array, np.array, List[ConstrainedHypothesis], mx.nd.array]: + """ + Builds a new topk list such that the beam contains hypotheses having completed different numbers of constraints. + These items are built from three different types: (1) the best items across the whole + scores matrix, (2) the set of words that must follow existing constraints, and (3) k-best items from each row. + + :param beam_size: The length of the beam for each segment. + :param inactive: Array listing inactive rows (shape: (beam_size,)). + :param scores: The scores array (shape: (beam_size, target_vocab_size)). 
+ :param hypotheses: The list of hypothesis objects. + :param best_ids: The current list of best hypotheses (shape: (beam_size,)). + :param best_word_ids: The parallel list of best word IDs (shape: (beam_size,)). + :param sequence_scores: (shape: (beam_size, 1)). + :param context: The MXNet device context. + :return: A tuple containing the best hypothesis rows, the best hypothesis words, the scores, + the updated constrained hypotheses, and the updated set of inactive hypotheses. + """ + + num_constraints = hypotheses[0].size() + + candidates = set() + # (1) Add all of the top-k items (which were passed) in as long as they pass the constraints + for row, col, seq_score in zip(best_ids, best_word_ids, sequence_scores): + row = int(row.asscalar()) + col = int(col.asscalar()) + if hypotheses[row].is_valid(col): + seq_score = float(seq_score.asscalar()) + new_item = hypotheses[row].advance(col) + cand = ConstrainedCandidate(row, col, seq_score, new_item) + candidates.add(cand) + + # For each hypothesis, we add (2) all the constraints that could follow it and + # (3) the best item (constrained or not) in that row + best_next = mx.ndarray.argmin(scores, axis=1) + for row in range(beam_size): + if inactive[row]: + continue + + hyp = hypotheses[row] + + # (2) add all the constraints that could extend this + nextones = hyp.allowed() + + # (3) add the single-best item after this (if it's valid) + col = int(best_next[row].asscalar()) + if hyp.is_valid(col): + nextones.add(col) + + # Now, create new candidates for each of these items + for col in nextones: + new_item = hyp.advance(col) + score = scores[row, col].asscalar() + cand = ConstrainedCandidate(row, col, score, new_item) + candidates.add(cand) + + # Sort the candidates. After allocating the beam across the banks, we will pick the top items + # for each bank from this list + sorted_candidates = sorted(candidates, key=attrgetter('score')) + + # The number of hypotheses in each bank + counts = [0 for x in range(num_constraints + 1)] + for cand in sorted_candidates: + counts[cand.hypothesis.num_met()] += 1 + + # Adjust allocated bank sizes if there are too few candidates in any of them + bank_sizes = get_bank_sizes(num_constraints, beam_size, counts) + + # Sort the candidates into the allocated banks + pruned_candidates = [] # type: List[ConstrainedCandidate] + for i, cand in enumerate(sorted_candidates): + bank = cand.hypothesis.num_met() + + if bank_sizes[bank] > 0: + pruned_candidates.append(cand) + bank_sizes[bank] -= 1 + + inactive[:len(pruned_candidates)] = 0 + + # Pad the beam so array assignment still works + if len(pruned_candidates) < beam_size: + inactive[len(pruned_candidates):] = 1 + pruned_candidates += [pruned_candidates[len(pruned_candidates)-1]] * (beam_size - len(pruned_candidates)) + + return (np.array([x.row for x in pruned_candidates]), + np.array([x.col for x in pruned_candidates]), + np.array([[x.score] for x in pruned_candidates]), + [x.hypothesis for x in pruned_candidates], + inactive) + + +def main(): + """ + Usage: python3 -m sockeye.lexical_constraints [--bpe BPE_MODEL] + + Reads sentences and constraints on STDIN (tab-delimited) and generates the JSON format that can be used when passing `--json-input` + to sockeye.translate. + + e.g., + + echo -e "Dies ist ein Test .\tThis is\ttest" | python3 -m sockeye.lexical_constraints + + will produce the following JSON object: + + { "text": "Dies ist ein Test .", "constraints": ["This is", "test"] } + + Make sure you apply all preprocessing (tokenization, BPE, etc.) 
to both the source and the target-side constraints. + You can then translate this object by passing it to Sockeye on STDIN as follows: + + python3 -m sockeye.translate -m /path/to/model --json-input --beam-size 20 --beam-prune 20 + + (Note the recommended Sockeye parameters). + """ + + import argparse + import sys + import json + + parser = argparse.ArgumentParser(description='Generate lexical constraint JSON format for Sockeye') + args = parser.parse_args() + + for line in sys.stdin: + line = line.rstrip() + + # Constraints are in fields 2+ + source, *constraints = line.split('\t') + + obj = { 'text': source } + if len(constraints) > 0: + obj['constraints'] = constraints + + print(json.dumps(obj, ensure_ascii=False), flush=True) + +if __name__ == '__main__': + main() diff --git a/sockeye/translate.py b/sockeye/translate.py index 38054f50d..b13175b84 100644 --- a/sockeye/translate.py +++ b/sockeye/translate.py @@ -105,63 +105,66 @@ def run_translate(args: argparse.Namespace): read_and_translate(translator=translator, output_handler=output_handler, chunk_size=args.chunk_size, - inp=args.input, - inp_factors=args.input_factors, - json_input=args.json_input) + input_file=args.input, + input_factors=args.input_factors, + input_is_json=args.json_input) -def make_inputs(inp: Optional[str], +def make_inputs(input_file: Optional[str], translator: inference.Translator, - json_input: bool, - inp_factors: Optional[List[str]] = None) -> Generator[inference.TranslatorInput, None, None]: + input_is_json: bool, + input_factors: Optional[List[str]] = None) -> Generator[inference.TranslatorInput, None, None]: """ Generates TranslatorInput instances from input. If input is None, reads from stdin. If num_input_factors > 1, the function will look for factors attached to each token, separated by '|'. If source is not None, reads from the source file. If num_source_factors > 1, num_source_factors source factor filenames are required. - :param inp: The source file (possibly None). + :param input_file: The source file (possibly None). :param translator: Translator that will translate each line of input. - :param json_input: Whether the input is in json format. - :param inp_factors: Source factor files. + :param input_is_json: Whether the input is in json format. + :param input_factors: Source factor files. :return: TranslatorInput objects. """ - if inp is None: - check_condition(inp_factors is None, "Translating from STDIN, not expecting any factor files.") + if input_file is None: + check_condition(input_factors is None, "Translating from STDIN, not expecting any factor files.") for sentence_id, line in enumerate(sys.stdin, 1): - if json_input: + if input_is_json: yield inference.make_input_from_json_string(sentence_id=sentence_id, json_string=line) else: yield inference.make_input_from_factored_string(sentence_id=sentence_id, factored_string=line, translator=translator) else: - inp_factors = [] if inp_factors is None else inp_factors - inputs = [inp] + inp_factors + input_factors = [] if input_factors is None else input_factors + inputs = [input_file] + input_factors check_condition(translator.num_source_factors == len(inputs), "Model(s) require %d factors, but %d given (through --input and --input-factors)." 
% ( translator.num_source_factors, len(inputs))) with ExitStack() as exit_stack: streams = [exit_stack.enter_context(data_io.smart_open(i)) for i in inputs] for sentence_id, inputs in enumerate(zip(*streams), 1): - yield inference.make_input_from_multiple_strings(sentence_id=sentence_id, strings=list(inputs)) + if input_is_json: + yield inference.make_input_from_json_string(sentence_id=sentence_id, json_string=inputs[0]) + else: + yield inference.make_input_from_multiple_strings(sentence_id=sentence_id, strings=list(inputs)) def read_and_translate(translator: inference.Translator, output_handler: OutputHandler, chunk_size: Optional[int], - inp: Optional[str] = None, - inp_factors: Optional[List[str]] = None, - json_input: bool = False) -> None: + input_file: Optional[str] = None, + input_factors: Optional[List[str]] = None, + input_is_json: bool = False) -> None: """ Reads from either a file or stdin and translates each line, calling the output_handler with the result. :param output_handler: Handler that will write output to a stream. :param translator: Translator that will translate each line of input. :param chunk_size: The size of the portion to read at a time from the input. - :param inp: Optional path to file which will be translated line-by-line if included, if none use stdin. - :param inp_factors: Optional list of paths to files that contain source factors. - :param json_input: Whether the input is in json format. + :param input_file: Optional path to file which will be translated line-by-line if included, if none use stdin. + :param input_factors: Optional list of paths to files that contain source factors. + :param input_is_json: Whether the input is in json format. """ batch_size = translator.batch_size if chunk_size is None: @@ -180,7 +183,7 @@ def read_and_translate(translator: inference.Translator, logger.info("Translating...") total_time, total_lines = 0.0, 0 - for chunk in grouper(make_inputs(inp, translator, json_input, inp_factors), size=chunk_size): + for chunk in grouper(make_inputs(input_file, translator, input_is_json, input_factors), size=chunk_size): chunk_time = translate(output_handler, chunk, translator) total_lines += len(chunk) total_time += chunk_time @@ -193,7 +196,8 @@ def read_and_translate(translator: inference.Translator, logger.info("Processed 0 lines.") -def translate(output_handler: OutputHandler, trans_inputs: List[inference.TranslatorInput], +def translate(output_handler: OutputHandler, + trans_inputs: List[inference.TranslatorInput], translator: inference.Translator) -> float: """ Translates each line from source_data, calling output handler after translating a batch. diff --git a/test/common.py b/test/common.py index 5af768354..ac84921a7 100644 --- a/test/common.py +++ b/test/common.py @@ -11,6 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
+import json import logging import os import random @@ -151,7 +152,8 @@ def tmp_digits_dataset(prefix: str, test_line_count: int, test_line_count_empty: int, test_max_length: int, sort_target: bool = False, seed_train: int = 13, seed_dev: int = 13, - with_source_factors: bool = False): + with_source_factors: bool = False, + with_target_constraints: bool = False): with TemporaryDirectory(prefix=prefix) as work_dir: # Simple digits files for train/dev data train_source_path = os.path.join(work_dir, "train.src") @@ -185,6 +187,29 @@ def tmp_digits_dataset(prefix: str, data['dev_source_factors'] = [dev_factor_path] data['test_source_factors'] = [test_factor_path] + if with_target_constraints: + # When using constrained decoding, rewrite the source file. Generating a mixture of + # sentences with and without constraints here is critical, since this can happen in production + # and also introduces sometimes some unanticipated interactions. + new_sources = [] + for sentno, (source, target) in enumerate(zip(open(data['test_source']), open(data['test_target']))): + target_words = target.rstrip().split() + target_len = len(target_words) + source_len = len(source.rstrip().split()) + new_source = { 'text': source.rstrip() } + # From the odd-numbered sentences that are not too long, create constraints. We do + # only odds to ensure we get batches with mixed constraints / lack of constraints. + if target_len > 0 and sentno % 2 == 0: + start_pos = 0 + end_pos = min(target_len, 3) + constraint = ' '.join(target_words[start_pos:end_pos]) + new_source['constraints'] = [constraint] + new_sources.append(json.dumps(new_source)) + + with open(data['test_source'], 'w') as out: + for json_line in new_sources: + print(json_line, file=out) + yield data diff --git a/test/integration/test_seq_copy_int.py b/test/integration/test_seq_copy_int.py index 4c2601932..f453ea7db 100644 --- a/test/integration/test_seq_copy_int.py +++ b/test/integration/test_seq_copy_int.py @@ -12,6 +12,7 @@ # permissions and limitations under the License. 
import pytest +import random import sockeye.constants as C from test.common import run_train_translate, tmp_digits_dataset @@ -31,7 +32,7 @@ " --checkpoint-frequency 10 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence " " --decode-and-evaluate 0", "--beam-size 2", - True, False, False), + True, False, False, True), # "Kitchen sink" LSTM encoder-decoder with attention ("--encoder rnn --decoder rnn --num-layers 4:2 --rnn-cell-type lstm --rnn-num-hidden 16" " --rnn-residual-connections" @@ -44,7 +45,7 @@ " --rnn-h2h-init orthogonal_stacked --batch-type sentence --decode-and-evaluate 0" " --learning-rate-decay-param-reset --weight-normalization --source-factors-num-embed 5", "--beam-size 2", - False, True, True), + False, True, True, False), # Convolutional embedding encoder + LSTM encoder-decoder with attention ("--encoder rnn-with-conv-embed --decoder rnn --conv-embed-max-filter-width 3 --conv-embed-num-filters 4:4:8" " --conv-embed-pool-stride 2 --conv-embed-num-highway-layers 1 --num-layers 1 --rnn-cell-type lstm" @@ -52,7 +53,7 @@ " --optimized-metric perplexity --max-updates 10 --checkpoint-frequency 10 --optimizer adam --batch-type sentence" " --initial-learning-rate 0.01 --decode-and-evaluate 0", "--beam-size 2", - False, False, False), + False, False, False, False), # Transformer encoder, GRU decoder, mhdot attention ("--encoder transformer --decoder rnn --num-layers 2:1 --rnn-cell-type gru --rnn-num-hidden 16 --num-embed 8:16" " --transformer-attention-heads 2 --transformer-model-size 8" @@ -62,7 +63,7 @@ " --weight-init-xavier-factor-type avg --weight-init-scale 3.0 --embed-weight-init normal --batch-type sentence" " --decode-and-evaluate 0", "--beam-size 2", - False, True, False), + False, True, False, False), # LSTM encoder, Transformer decoder ("--encoder rnn --decoder transformer --num-layers 2:2 --rnn-cell-type lstm --rnn-num-hidden 16 --num-embed 16" " --transformer-attention-heads 2 --transformer-model-size 16" @@ -70,7 +71,7 @@ " --batch-size 8 --max-updates 10 --batch-type sentence --decode-and-evaluate 0" " --checkpoint-frequency 10 --optimizer adam --initial-learning-rate 0.01", "--beam-size 3", - False, True, False), + False, True, False, False), # Full transformer ("--encoder transformer --decoder transformer" " --num-layers 3 --transformer-attention-heads 2 --transformer-model-size 16 --num-embed 16" @@ -81,7 +82,7 @@ " --batch-size 8 --max-updates 10 --batch-type sentence --decode-and-evaluate 0" " --checkpoint-frequency 10 --optimizer adam --initial-learning-rate 0.01", "--beam-size 2", - True, False, False), + True, False, False, False), # Full transformer with source factor ("--encoder transformer --decoder transformer" " --num-layers 3 --transformer-attention-heads 2 --transformer-model-size 16 --num-embed 16" @@ -91,14 +92,14 @@ " --batch-size 8 --max-updates 10 --batch-type sentence --decode-and-evaluate 0" " --checkpoint-frequency 10 --optimizer adam --initial-learning-rate 0.01 --source-factors-num-embed 4", "--beam-size 2", - True, False, True), + True, False, True, False), # 3-layer cnn ("--encoder cnn --decoder cnn " " --batch-size 16 --num-layers 3 --max-updates 10 --checkpoint-frequency 10" " --cnn-num-hidden 32 --cnn-positional-embedding-type fixed" " --optimizer adam --initial-learning-rate 0.001 --batch-type sentence --decode-and-evaluate 0", "--beam-size 2", - True, False, False), + True, False, False, False), # Vanilla LSTM like above but activating LHUC. 
In the normal case you would # start with a trained system instead of a random initialized one like here. ("--encoder rnn --decoder rnn --num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 16 --num-embed 8 --rnn-attention-type mlp" @@ -107,16 +108,17 @@ " --loss cross-entropy --optimized-metric perplexity --max-updates 10" " --checkpoint-frequency 10 --optimizer adam --initial-learning-rate 0.01 --lhuc all", "--beam-size 2", - True, False, False)] + True, False, False, False)] -@pytest.mark.parametrize("train_params, translate_params, restrict_lexicon, use_prepared_data, use_source_factors", +@pytest.mark.parametrize("train_params, translate_params, restrict_lexicon, use_prepared_data, use_source_factors, use_constrained_decoding", ENCODER_DECODER_SETTINGS) def test_seq_copy(train_params: str, translate_params: str, restrict_lexicon: bool, use_prepared_data: bool, - use_source_factors: bool): + use_source_factors: bool, + use_constrained_decoding: bool): """Task: copy short sequences of digits""" with tmp_digits_dataset(prefix="test_seq_copy", @@ -127,10 +129,12 @@ def test_seq_copy(train_params: str, test_line_count=_TEST_LINE_COUNT, test_line_count_empty=_TEST_LINE_COUNT_EMPTY, test_max_length=_TEST_MAX_LENGTH, - sort_target=False) as data: + sort_target=False, + with_source_factors=use_source_factors, + with_target_constraints=use_constrained_decoding) as data: - # Test model configuration, including the output equivalence of batch and no-batch decoding - translate_params_batch = translate_params + " --batch-size 2" + # Only one of these is supported at a time in the tests + assert not (use_source_factors and use_constrained_decoding) # When using source factors train_source_factor_paths, dev_source_factor_paths, test_source_factor_paths = None, None, None @@ -139,6 +143,12 @@ def test_seq_copy(train_params: str, dev_source_factor_paths = [data['validation_source']] test_source_factor_paths = [data['test_source']] + if use_constrained_decoding: + translate_params += " --json-input" + + # Test model configuration, including the output equivalence of batch and no-batch decoding + translate_params_batch = translate_params + " --batch-size 2" + # Ignore return values (perplexity and BLEU) for integration test run_train_translate(train_params=train_params, translate_params=translate_params, @@ -147,7 +157,7 @@ def test_seq_copy(train_params: str, train_target_path=data['target'], dev_source_path=data['validation_source'], dev_target_path=data['validation_target'], - test_source_path=data['test_source'], + test_source_path=data['test_target'], test_target_path=data['test_target'], train_source_factor_paths=train_source_factor_paths, dev_source_factor_paths=dev_source_factor_paths, diff --git a/test/unit/test_constraints.py b/test/unit/test_constraints.py new file mode 100644 index 000000000..b5541d1b2 --- /dev/null +++ b/test/unit/test_constraints.py @@ -0,0 +1,193 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +import json +from unittest.mock import Mock + +import mxnet as mx +import numpy as np +import pytest +from math import ceil + +import sockeye.constants as C +import sockeye.data_io +import sockeye.inference +from sockeye.utils import SockeyeError +from sockeye.lexical_constraints import init_batch, get_bank_sizes, topk, ConstrainedHypothesis + +BOS_ID = 2 +EOS_ID = 3 + +def mock_translator(num_source_factors: int): + t_mock = Mock(sockeye.inference.Translator) + t_mock.num_source_factors = num_source_factors + return t_mock + + +""" +Test how the banks are allocated. Given a number of constraints (C), the beam size (k) +a number of candidates for each bank [0..C], return the allocation of the k spots of the +beam to the banks. +""" +@pytest.mark.parametrize("num_constraints, beam_size, counts, expected_allocation", + [ + # no constraints: allocate all items to bin 0 + (0, 5, [5], [5]), + # 1 constraints, but no candidates, so 0 alloc + (1, 5, [5,0], [5,0]), + # 1 constraint, but 1 candidate, so 1 alloc + (1, 5, [5,1], [4,1]), + # 1 constraint, > k candidates for each (impossible, but ok), even alloc, extra goes to last + (1, 5, [10,10], [2,3]), + # 1 constraint, > k candidates for each (impossible, but ok), even alloc, extra goes to last + (1, 5, [1,10], [1,4]), + # 2 constraints, no candidates + (2, 5, [5,0,0], [5,0,0]), + # 2 constraints, some candidates + (2, 10, [5,0,7], [5,0,5]), + # 2 constraints, some candidates + (2, 10, [5,1,7], [5,1,4]), + # more constraints than beam spots: allocate to last two spots + (3, 2, [1,0,1,1], [0,0,1,1]), + # more constraints than beam spots: slots allocated empirically + (3, 2, [1,0,1,0], [1,0,1,0]), + # more constraints than beam spots: all slots to last bank + (3, 2, [4,2,2,3], [0,0,0,2]), + ]) +def test_constraints_bank_allocation(num_constraints, beam_size, counts, expected_allocation): + allocation = get_bank_sizes(num_constraints, beam_size, counts) + assert sum(allocation) == beam_size + assert allocation == expected_allocation + + +""" +Make sure the internal representation is correct. +For the internal representation, the list of phrasal constraints is concatenated, and then +a parallel array is used to mark which words are part of a phrasal constraint. +""" +@pytest.mark.parametrize("raw_constraints, internal_constraints, internal_is_sequence", + [ + # No constraints + ([], [], []), + # One single-word constraint + ([[17]], [17], [0]), + # Multiple multiple-word constraints. + ([[11, 12], [13, 14]], [11, 12, 13, 14], [True, False, True, False]), + # Multiple constraints + ([[11, 12, 13], [14], [15]], [11, 12, 13, 14, 15], [True, True, False, False, False]), + ]) +def test_constraints_repr(raw_constraints, internal_constraints, internal_is_sequence): + hyp = ConstrainedHypothesis(raw_constraints, EOS_ID) + assert hyp.constraints == internal_constraints + assert hyp.is_sequence == internal_is_sequence + + +""" +Tests many of the ConstrainedHypothesis functions. 
+""" +@pytest.mark.parametrize("raw_constraints, met, unmet", + [ + # No constraints + ([], [], []), + # Single simple unmet constraint + ([[17]], [], [17]), + # Single simple met constraint + ([[17]], [17], []), + # Met first word of a phrasal constraint, return just next word of phrasal + ([[11, 12], [13, 14]], [11], [12, 13, 14]), + # Completed phrase, have only single-word ones + ([[11, 12, 13], [14], [15]], [11, 12, 13], [14, 15]), + # Same word twice + ([[11], [11]], [], [11, 11]), + ]) +def test_constraints_logic(raw_constraints, met, unmet): + hyp = ConstrainedHypothesis(raw_constraints, EOS_ID) + # record these ones as met + for word_id in met: + hyp = hyp.advance(word_id) + + assert hyp.num_needed() == len(unmet) + assert hyp.finished() == (len(unmet) == 0) + assert hyp.is_valid(EOS_ID) == (hyp.finished() or (len(unmet) == 1 and EOS_ID in unmet)) + + +""" +Test the allowed() function, which returns the set of unmet constraints that can be generated. +When inside a phrase, this is only the next word of the phrase. Otherwise, it is all unmet constraints. +""" +@pytest.mark.parametrize("raw_constraints, met, allowed", + [ + # No constraints + ([], [], []), + # Single simple unmet constraint + ([[17]], [], [17]), + # Single simple met constraint + ([[17]], [17], []), + # Met first word of a phrasal constraint, return just next word of phrasal + ([[11, 12], [13, 14]], [11], [12]), + # Completed phrase, have only single-word ones + ([[11, 12, 13], [14], [15]], [11, 12, 13], [14, 15]), + # Same word twice, nothing met, return + ([[11], [11]], [], [11]), + # Same word twice, met, still returns once + ([[11], [11]], [11], [11]), + # Same word twice, met twice + ([[11], [11]], [11, 11], []), + # EOS, allowed + ([[42, EOS_ID]], [42], [EOS_ID]), + # EOS, not allowed + ([[42, EOS_ID]], [], [42]), + ]) +def test_constraints_allowed(raw_constraints, met, allowed): + hyp = ConstrainedHypothesis(raw_constraints, EOS_ID) + # record these ones as met + for word_id in met: + hyp = hyp.advance(word_id) + + assert hyp.allowed() == set(allowed) + assert hyp.num_met() == len(met) + assert hyp.num_needed() == hyp.size() - hyp.num_met() + + + + + +""" +Ensures that batches are initialized correctly. +Each line here is a tuple containing a list (for each sentence in the batch) of RawConstraintLists, +which are lists of list of integer IDs representing the constraints for the sentence. 
+""" +@pytest.mark.parametrize("raw_constraint_lists", + [ ([None, None, None, None]), + ([[[17]], None]), + ([None, [[17]]]), + ([[[17], [11, 12]], [[17]], None]), + ([None, [[17], [11, 12]], [[17]], None]), + ]) +def test_constraints_init_batch(raw_constraint_lists): + beam_size = 4 # arbitrary + + constraints = init_batch(raw_constraint_lists, beam_size, BOS_ID, EOS_ID) + assert len(raw_constraint_lists) * beam_size == len(constraints) + + # Iterate over sentences in the batch + for raw_constraint_list, constraint in zip(raw_constraint_lists, constraints[::beam_size]): + if raw_constraint_list is None: + assert constraint is None + else: + # The number of constraints is the sum of the length of the lists in the raw constraint list + assert constraint.size() == sum([len(phr) for phr in raw_constraint_list]) + + # No constraints are met unless the start_id happened to be at the start of a constraint + num_met = 1 if any([phr[0] == BOS_ID for phr in raw_constraint_list]) else 0 + assert constraint.num_met() == num_met diff --git a/test/unit/test_output_handler.py b/test/unit/test_output_handler.py index 39a1510e4..62cae71ac 100644 --- a/test/unit/test_output_handler.py +++ b/test/unit/test_output_handler.py @@ -18,7 +18,7 @@ import sockeye.output_handler stream_handler_tests = [(sockeye.output_handler.StringOutputHandler(io.StringIO()), - TranslatorInput(sentence_id=0, tokens=[], factors=[]), + TranslatorInput(sentence_id=0, tokens=[], factors=[], constraints=[]), TranslatorOutput(id=0, translation="ein Test", tokens=None, attention_matrix=None, score=0.), diff --git a/test/unit/test_translate.py b/test/unit/test_translate.py index 02f04f8c9..8ef281c8a 100644 --- a/test/unit/test_translate.py +++ b/test/unit/test_translate.py @@ -49,7 +49,7 @@ def test_translate_by_file(mock_file, mock_translator, mock_output_handler): mock_translator.num_source_factors = 1 mock_translator.batch_size = 1 sockeye.translate.read_and_translate(translator=mock_translator, output_handler=mock_output_handler, - chunk_size=2, inp='/dev/null', inp_factors=None) + chunk_size=2, input_file='/dev/null', input_factors=None) # Ensure translate gets called once. Input here will be a dummy mocked result, so we'll ignore it. assert mock_translator.translate.call_count == 1 diff --git a/tutorials/constraints/README.md b/tutorials/constraints/README.md new file mode 100644 index 000000000..f64e00849 --- /dev/null +++ b/tutorials/constraints/README.md @@ -0,0 +1,39 @@ +# Decoding with lexical constraints + +Lexical constraints provide a way to force the model to include certain words in the output. +Given a set of constraints, the decoder will find the best output that includes the constraints. +This file describes how to use lexical constraints; for more technical information, please see our paper: + + Fast Lexically Constrained Decoding With Dynamic Beam Allocation for Neural Machine Translation + Matt Post & David Vilar + [NAACL 2018](http://naacl2018.org/) + [PDF](https://arxiv.org/pdf/1804.06609.pdf) + +## Example + +You need a [trained model](../wmt/README.md). + +You need to be careful to apply the same preprocessing to your test data that you applied at training time, including +any [subword processing](http://github.com/rsennrich/subword-nmt), since Sockeye itself does not do this. + +Constraints must be encoded with a JSON object. 
+This JSON object can be produced with the provided script:
+
+    echo -e "This is a test .\tconstr@@ aint\tmulti@@ word constr@@ aint" \
+      | python3 -m sockeye.lexical_constraints
+
+The script prints a JSON object on a single line (shown here across multiple lines for readability) with the constraints encoded as follows:
+
+    { "text": "This is a test .",
+      "constraints": ["constr@@ aint",
+                      "multi@@ word constr@@ aint"] }
+
+You can pass the output of this to Sockeye. Make sure that you specify `--json-input` so that Sockeye knows to parse the
+input (without that flag, it will treat the JSON input as a regular sentence). We also recommend that you increase the
+beam a little bit and enable beam pruning:
+
+    echo -e "This is a test .\tconstr@@ aint\tmulti@@ word constr@@ aint" \
+      | python3 -m sockeye.lexical_constraints \
+      | python3 -m sockeye.translate -m /path/to/model --json-input --beam-size 20 --beam-prune 20 [other args]
+
+You will get a translation with the required constraints as part of the output.
diff --git a/typechecked-files b/typechecked-files
index 556dd3814..db56a0458 100644
--- a/typechecked-files
+++ b/typechecked-files
@@ -15,6 +15,7 @@ sockeye/inference.py
 sockeye/init_embedding.py
 sockeye/initializer.py
 sockeye/layers.py
+sockeye/lexical_constraints.py
 sockeye/lexicon.py
 sockeye/log.py
 sockeye/loss.py