Add support for source factor models (#275)
Source factors are enabled by passing --source-factors file1 [file2 ...] (-sf), where file1, etc. are token-parallel to the source (-s).
This option can be passed either to sockeye.train or to the data preparation step, if data sharding is used.
An analogous parameter, --validation-source-factors, is used to pass factors for validation data.
The flag --source-factors-num-embed D1 [D2 ...] denotes the embedding dimensions.
These are concatenated with the source word dimension (--num-embed), which can continue to be tied to the target (--weight-tying --weight-tying-type=src_trg).
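
As a rough illustration (file names and the POS factor below are invented), "token-parallel" means each factor file has exactly one factor token per source token on every line:

```python
# Hypothetical toy data: one factor token per source token on every line.
source_lines = ["das Haus ist klein", "Hunde bellen"]
pos_lines = ["DT NN VBZ JJ", "NNS VBP"]  # e.g. POS tags used as a source factor

for src, pos in zip(source_lines, pos_lines):
    # token-parallel: same number of whitespace-delimited tokens per line
    assert len(src.split()) == len(pos.split())

# Written to files, these could then be passed to training roughly as:
#   sockeye-train -s train.de -t train.en \
#       --source-factors train.de.pos --source-factors-num-embed 8 ...
```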

At test time, the input sentence and its factors can be passed via multiple parallel files (--input and --input-factors) or through stdin with token-level annotations, separated by |. Alternatively, a string-serialized JSON object can be sent to the CLI through stdin; it must have a top-level key 'text' and can optionally have a key 'factors' of type List[str].
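
For concreteness, a minimal sketch (token and factor values invented) of how these two test-time input formats could be built before piping them to the translate CLI:

```python
import json

tokens = ["der", "Hund", "bellt"]
factors = [["DT", "NN", "VB"], ["l", "u", "l"]]  # two factor streams, token-parallel

# 1) Plain stdin input: factors are attached to each token with '|'
stdin_line = " ".join("|".join([tok] + [f[i] for f in factors])
                      for i, tok in enumerate(tokens))
# -> "der|DT|l Hund|NN|u bellt|VB|l"

# 2) JSON stdin input: top-level key 'text', optional key 'factors' (List[str])
json_line = json.dumps({"text": " ".join(tokens),
                        "factors": [" ".join(f) for f in factors]})
# -> {"text": "der Hund bellt", "factors": ["DT NN VB", "l u l"]}
```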

* Cleanup of vocab functions

* Simplified vocab logic a bit. Removed pickle functionality since it has long been deprecated

* Refactor so that the first factor corresponds to the source surface form (e.g. configs set num_factors to at least 1 by default)

* Fixed a TODO. Slightly reworded the changelog

* Reworked inference interface. Added a bunch of TranslatorInput factory functions (including json)

* Removed max_seq_len_{source,target} from ModelConfig

* Separate data statistics relevant for inference from data information relevant only for training.

* Bumped Major Version to 1.17.0

* Do not throw exceptions while translating (#294)

* Remove bias parameters in Transformer attention layers as they bring no benefit. (#296)
mjpost authored and fhieber committed Feb 19, 2018
1 parent 8263615 commit 5dcb60d
Showing 29 changed files with 1,478 additions and 722 deletions.
22 changes: 22 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,28 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.17.0]
### Added
- Source factors, as described in

Linguistic Input Features Improve Neural Machine Translation (Sennrich \& Haddow, WMT 2016)
[PDF](http://www.aclweb.org/anthology/W16-2209.pdf) [bibtex](http://www.aclweb.org/anthology/W16-2209.bib)

Additional source factors are enabled by passing `--source-factors file1 [file2 ...]` (`-sf`), where file1, etc. are
token-parallel to the source (`-s`).
An analogous parameter, `--validation-source-factors`, is used to pass factors for validation data.
The flag `--source-factors-num-embed D1 [D2 ...]` denotes the embedding dimensions and is required if source factor
files are given. Factor embeddings are concatenated to the source embeddings dimension (`--num-embed`).

At test time, the input sentence and its factors can be passed in via STDIN or command-line arguments.
- For STDIN, the input and factors should be in a token-based factored format, e.g.,
`word1|factor1|factor2|... w2|f1|f2|... ...`.
- You can also use file arguments, which mirrors training: `--input` takes the path to a file containing the source,
and `--input-factors` a list of files containing token-parallel factors.
At test time, an exception is raised if the number of expected factors does not
match the factors passed along with the input.

- Removed bias parameters from multi-head attention layers of the transformer.

## [1.16.6]
### Changed
23 changes: 11 additions & 12 deletions README.md
@@ -11,7 +11,7 @@ It implements state-of-the-art encoder-decoder architectures, such as

If you use Sockeye, please cite:

Felix Hieber, Tobias Domhan, Michael Denkowski, David Vilar, Artem Sokolov, Ann Clifton and Matt Post (2017):
[Sockeye: A Toolkit for Neural Machine Translation](https://arxiv.org/abs/1712.05690). In eprint arXiv:cs-CL/1712.05690.

```
@@ -36,7 +36,7 @@ If you are interested in collaborating or have any questions, please submit a pu
You can also send questions to *sockeye-dev-at-amazon-dot-com*.

Recent developments and changes are tracked in our [changelog](https://github.com/awslabs/sockeye/blob/master/CHANGELOG.md).

## Dependencies

Sockeye requires:
@@ -117,10 +117,9 @@ In general you can install all optional dependencies from the Sockeye source fol

### Running sockeye

After installation, command line tools such as *sockeye-train, sockeye-translate, sockeye-average* and *sockeye-embeddings* are available.
Alternatively, if the sockeye directory is on your `$PYTHONPATH` you can run the modules directly.
For example *sockeye-train* can also be invoked as
```bash
> python -m sockeye.train <args>
```
@@ -129,8 +128,8 @@ directly. For example *sockeye-train* can also be invoked as

### Train

In order to train your first Neural Machine Translation model you will need two sets of parallel files: one for training
and one for validation. The latter will be used for computing various metrics during training.
Each set should consist of two files: one with source sentences and one with target sentences (translations).
Both files should have the same number of lines, each line containing a single
sentence. Each sentence should be a whitespace delimited list of tokens.
@@ -145,20 +144,20 @@ Say you wanted to train a RNN German-to-English translation model, then you woul
--output <model_dir>
```

After training the directory *<model_dir>* will contain all model artifacts such as parameters and model
configuration. The default setting is to train a 1-layer LSTM model with attention.


### Translate

Input data for translation should be in the same format as the training data (tokenization, preprocessing scheme).
You can translate as follows:

```bash
> python -m sockeye.translate --models <model_dir> --use-cpu
```

This will take the best set of parameters found during training and then translate strings from STDIN and
write translations to STDOUT.

For more detailed examples check out our user documentation.
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.16.6'
__version__ = '1.17.0'
41 changes: 39 additions & 2 deletions sockeye/arguments.py
@@ -15,8 +15,8 @@
Defines commandline arguments for the main CLIs with reasonable defaults.
"""
import argparse
import sys
import os
import sys
from typing import Callable, Optional

from sockeye.lr_scheduler import LearningRateSchedulerFixedStep
@@ -256,6 +256,12 @@ def add_training_data_args(params, required=False):
required=required,
type=regular_file(),
help='Source side of parallel training data.')
params.add_argument('--source-factors', '-sf',
required=False,
nargs='+',
type=regular_file(),
default=[],
help='File(s) containing additional token-parallel source side factors. Default: %(default)s.')
params.add_argument(C.TRAINING_ARG_TARGET, '-t',
required=required,
type=regular_file(),
@@ -267,6 +273,13 @@ def add_validation_data_params(params):
required=True,
type=regular_file(),
help='Source side of validation data.')
params.add_argument('--validation-source-factors', '-vsf',
required=False,
nargs='+',
type=regular_file(),
default=[],
help='File(s) containing additional token-parallel validation source side factors. '
'Default: %(default)s.')
params.add_argument('--validation-target', '-vt',
required=True,
type=regular_file(),
@@ -426,7 +439,7 @@ def add_model_parameters(params):
model_params.add_argument('--allow-missing-params',
action="store_true",
default=False,
help="Allow misssing parameters when initializing model parameters from file. "
help="Allow missing parameters when initializing model parameters from file. "
"Default: %(default)s.")

model_params.add_argument('--encoder',
@@ -576,6 +589,13 @@ def add_model_parameters(params):
default=(512, 512),
help='Embedding size for source and target tokens. '
'Use "x:x" to specify separate values for src&tgt. Default: %(default)s.')
model_params.add_argument('--source-factors-num-embed',
type=int,
nargs='+',
default=[],
help='Embedding size for additional source factors. '
'You must provide as many dimensions as '
'(validation) source factor files. Default: %(default)s.')

# attention arguments
model_params.add_argument('--rnn-attention-type',
@@ -924,6 +944,23 @@ def add_inference_args(params):
help='Input file to translate. One sentence per line. '
'If not given, will read from stdin.')

decode_params.add_argument(C.INFERENCE_ARG_INPUT_FACTORS_LONG, C.INFERENCE_ARG_INPUT_FACTORS_SHORT,
required=False,
nargs='+',
type=regular_file(),
default=None,
help='List of input files containing additional source factors, '
'each token-parallel to the source. Default: %(default)s.')

decode_params.add_argument('--json-input',
action='store_true',
default=False,
help="If given, the CLI expects string-serialized json objects as input."
"Requires at least the input text field, for example: "
"{'text': 'some input string'} "
"Optionally, a list of factors can be provided: "
"{'text': 'some input string', 'factors': ['C C C', 'X X X']}.")

decode_params.add_argument(C.INFERENCE_ARG_OUTPUT_LONG, C.INFERENCE_ARG_OUTPUT_SHORT,
default=None,
help='Output file to write translations to. '
90 changes: 56 additions & 34 deletions sockeye/checkpoint_decoder.py
@@ -18,11 +18,13 @@
import os
import random
import time
from typing import Dict, Optional
from contextlib import ExitStack
from typing import Any, Dict, Iterable, Optional, List

import mxnet as mx

import sockeye.output_handler
import sockeye.translate
from . import evaluate
from . import constants as C
from . import data_io
@@ -37,7 +39,7 @@ class CheckpointDecoder:
Decodes a (random sample of a) dataset using parameters at given checkpoint and computes BLEU against references.
:param context: MXNet context to bind the model to.
:param inputs: Path to file containing input sentences.
:param inputs: Path(s) to file containing input sentences (and their factors).
:param references: Path to file containing references.
:param model: Model to load.
:param max_input_len: Maximum input length.
@@ -55,7 +57,7 @@ class CheckpointDecoder:

def __init__(self,
context: mx.context.Context,
inputs: str,
inputs: List[str],
references: str,
model: str,
max_input_len: Optional[int] = None,
@@ -79,30 +81,33 @@ def __init__(self,
self.length_penalty_beta = length_penalty_beta
self.softmax_temperature = softmax_temperature
self.model = model
with data_io.smart_open(inputs) as inputs_fin, data_io.smart_open(references) as references_fin:
input_sentences = inputs_fin.readlines()

with ExitStack() as exit_stack:
inputs_fins = [exit_stack.enter_context(data_io.smart_open(f)) for f in inputs]
references_fin = exit_stack.enter_context(data_io.smart_open(references))

inputs_sentences = [f.readlines() for f in inputs_fins]
target_sentences = references_fin.readlines()
utils.check_condition(len(input_sentences) == len(target_sentences), "Number of sentence pairs do not match")

utils.check_condition(all(len(l) == len(target_sentences) for l in inputs_sentences),
"Sentences differ in length")

if sample_size <= 0:
sample_size = len(input_sentences)
if sample_size < len(input_sentences):
# custom random number generator to guarantee the same samples across runs in order to be able to
# compare metrics across independent runs
random_gen = random.Random(random_seed)
self.input_sentences, self.target_sentences = zip(
*random_gen.sample(list(zip(input_sentences, target_sentences)),
sample_size))
sample_size = len(inputs_sentences[0])
if sample_size < len(inputs_sentences[0]):
self.target_sentences, *self.inputs_sentences = parallel_subsample(
[target_sentences] + inputs_sentences, sample_size, random_seed)
else:
self.input_sentences, self.target_sentences = input_sentences, target_sentences
self.inputs_sentences, self.target_sentences = inputs_sentences, target_sentences

logger.info("Created CheckpointDecoder(max_input_len=%d, beam_size=%d, model=%s, num_sentences=%d)",
max_input_len if max_input_len is not None else -1,
beam_size, model, len(self.input_sentences))
for i, factor in enumerate(self.inputs_sentences):
write_to_file(factor, os.path.join(self.model, C.DECODE_IN_NAME % i))
write_to_file(self.target_sentences, os.path.join(self.model, C.DECODE_REF_NAME))

with data_io.smart_open(os.path.join(self.model, C.DECODE_REF_NAME), 'w') as trg_out, \
data_io.smart_open(os.path.join(self.model, C.DECODE_IN_NAME), 'w') as src_out:
[trg_out.write(s) for s in self.target_sentences]
[src_out.write(s) for s in self.input_sentences]
self.inputs_sentences = list(zip(*self.inputs_sentences)) # type: List[List[str]]

logger.info("Created CheckpointDecoder(max_input_len=%d, beam_size=%d, model=%s, num_sentences=%d)",
max_input_len if max_input_len is not None else -1, beam_size, model, len(self.target_sentences))

def decode_and_evaluate(self,
checkpoint: Optional[int] = None,
@@ -114,33 +119,36 @@ def decode_and_evaluate(self,
:param output_name: Filename to write translations to. Defaults to /dev/null.
:return: Mapping of metric names to scores.
"""
models, vocab_source, vocab_target = inference.load_models(self.context,
self.max_input_len,
self.beam_size,
self.batch_size,
[self.model],
[checkpoint],
softmax_temperature=self.softmax_temperature,
max_output_length_num_stds=self.max_output_length_num_stds)
models, source_vocabs, target_vocab = inference.load_models(
self.context,
self.max_input_len,
self.beam_size,
self.batch_size,
[self.model],
[checkpoint],
softmax_temperature=self.softmax_temperature,
max_output_length_num_stds=self.max_output_length_num_stds)
translator = inference.Translator(self.context,
self.ensemble_mode,
self.bucket_width_source,
inference.LengthPenalty(self.length_penalty_alpha, self.length_penalty_beta),
models,
vocab_source,
vocab_target)
source_vocabs,
target_vocab)
trans_wall_time = 0.0
translations = []
with data_io.smart_open(output_name, 'w') as output:
handler = sockeye.output_handler.StringOutputHandler(output)
tic = time.time()
trans_inputs = [translator.make_input(i, line) for i, line in enumerate(self.input_sentences)]
trans_inputs = [] # type: List[inference.TranslatorInput]
for i, inputs in enumerate(self.inputs_sentences):
trans_inputs.append(sockeye.inference.make_input_from_multiple_strings(i, inputs))
trans_outputs = translator.translate(trans_inputs)
trans_wall_time = time.time() - tic
for trans_input, trans_output in zip(trans_inputs, trans_outputs):
handler.handle(trans_input, trans_output)
translations.append(trans_output.translation)
avg_time = trans_wall_time / len(self.input_sentences)
avg_time = trans_wall_time / len(self.target_sentences)

# TODO(fhieber): eventually add more metrics (METEOR etc.)
return {C.BLEU_VAL: evaluate.raw_corpus_bleu(hypotheses=translations,
@@ -149,3 +157,17 @@
C.CHRF_VAL: evaluate.raw_corpus_chrf(hypotheses=translations,
references=self.target_sentences),
C.AVG_TIME: avg_time}


def parallel_subsample(parallel_sequences: List[List[Any]], sample_size: int, seed: int) -> List[Any]:
# custom random number generator to guarantee the same samples across runs in order to be able to
# compare metrics across independent runs
random_gen = random.Random(seed)
parallel_sample = list(zip(*random_gen.sample(list(zip(*parallel_sequences)), sample_size)))
return parallel_sample


def write_to_file(data: List[str], fname: str):
with data_io.smart_open(fname, 'w') as f:
for x in data:
print(x.rstrip(), file=f)
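
As a standalone illustration of the parallel subsampling above (not from the commit itself): zipping the streams before sampling keeps references, source, and factors aligned, and the fixed seed yields the same sample across independent runs.

```python
import random

def parallel_subsample_sketch(streams, sample_size, seed):
    rng = random.Random(seed)  # fixed seed -> same sample across independent runs
    return list(zip(*rng.sample(list(zip(*streams)), sample_size)))

references = ["r1", "r2", "r3", "r4"]
sources = ["s1", "s2", "s3", "s4"]
factors = ["f1", "f2", "f3", "f4"]

refs, srcs, facs = parallel_subsample_sketch([references, sources, factors], 2, seed=7)
# corresponding positions still line up, e.g. ("r3", "s3", "f3")
assert all(r[1:] == s[1:] == f[1:] for r, s, f in zip(refs, srcs, facs))
```
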
19 changes: 15 additions & 4 deletions sockeye/constants.py
@@ -184,18 +184,24 @@
CHUNK_SIZE_NO_BATCHING = 1
CHUNK_SIZE_PER_BATCH_SEGMENT = 500

# Inference Input JSON constants
JSON_TEXT_KEY = "text"
JSON_FACTORS_KEY = "factors"
JSON_ENCODING = "utf-8"

VERSION_NAME = "version"
CONFIG_NAME = "config"
LOG_NAME = "log"
JSON_SUFFIX = ".json"
VOCAB_SRC_NAME = "vocab.src"
VOCAB_TRG_NAME = "vocab.trg"
VOCAB_SRC_PREFIX = "vocab.src"
VOCAB_SRC_NAME = VOCAB_SRC_PREFIX + ".%d" + JSON_SUFFIX
VOCAB_TRG_NAME = "vocab.trg" + JSON_SUFFIX
VOCAB_ENCODING = "utf-8"
PARAMS_PREFIX = "params."
PARAMS_NAME = PARAMS_PREFIX + "%05d"
PARAMS_BEST_NAME = "params.best"
DECODE_OUT_NAME = "decode.output.%05d"
DECODE_IN_NAME = "decode.source"
DECODE_IN_NAME = "decode.source.%d"
DECODE_REF_NAME = "decode.target"
SYMBOL_NAME = "symbol" + JSON_SUFFIX
METRICS_NAME = "metrics"
@@ -232,9 +238,13 @@
INFERENCE_ARG_INPUT_SHORT = "-i"
INFERENCE_ARG_OUTPUT_LONG = "--output"
INFERENCE_ARG_OUTPUT_SHORT = "-o"
INFERENCE_ARG_INPUT_FACTORS_LONG = "--input-factors"
INFERENCE_ARG_INPUT_FACTORS_SHORT = "-if"
TRAIN_ARGS_MONITOR_BLEU = "--decode-and-evaluate"
TRAIN_ARGS_CHECKPOINT_FREQUENCY = "--checkpoint-frequency"

# Used to delimit factors on STDIN for inference
DEFAULT_FACTOR_DELIMITER = '|'

# data layout strings
BATCH_MAJOR = "NTC"
@@ -334,7 +344,8 @@
SHARD_NAME = "shard.%05d"
SHARD_SOURCE = SHARD_NAME + ".source"
SHARD_TARGET = SHARD_NAME + ".target"
DATA_INFO = "data.info"
DATA_CONFIG = "data.config"
PREPARED_DATA_VERSION_FILE = "data.version"
PREPARED_DATA_VERSION = 1
PREPARED_DATA_VERSION = 2
