From b24b2c1352e71659fd61e49f9384f255e4161e5a Mon Sep 17 00:00:00 2001 From: Michael Denkowski Date: Wed, 23 Mar 2022 03:35:17 -0500 Subject: [PATCH] Also trace SockeyeModel components when `inference_only == False` (includes CheckpointDecoder) (#1032) * Trace checkpoint decoder - Remove inference_only checks for model tracing - Checkpoint decoder always runs in eval mode * Version and changelog * Grammar * Whitespace * Rename variable --- CHANGELOG.md | 6 ++++ sockeye/__init__.py | 2 +- sockeye/checkpoint_decoder.py | 20 ++++++++++-- sockeye/model.py | 60 ++++++++++++++--------------------- 4 files changed, 49 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ccbe252c..f85c2942c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [3.1.7] + +### Changed + +- SockeyeModel components are now traced regardless of whether `inference_only` is set, including for the CheckpointDecoder during training. + ## [3.1.6] ### Changed diff --git a/sockeye/__init__.py b/sockeye/__init__.py index 4db7f5634..558bf30c2 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '3.1.6' +__version__ = '3.1.7' diff --git a/sockeye/checkpoint_decoder.py b/sockeye/checkpoint_decoder.py index 094c0e86a..42bfe8c22 100644 --- a/sockeye/checkpoint_decoder.py +++ b/sockeye/checkpoint_decoder.py @@ -85,7 +85,6 @@ def __init__(self, self.bucket_width_source = bucket_width_source self.length_penalty_alpha = length_penalty_alpha self.length_penalty_beta = length_penalty_beta - # TODO(mdenkows): Trace encoder/decoder even though inference_only=False self.model = model with ExitStack() as exit_stack: @@ -149,6 +148,12 @@ def decode_and_evaluate(self, output_name: Optional[str] = None) -> Dict[str, fl """ # 1. Translate + + # Store original mode and set to eval mode in case the model is not yet + # traced. + original_mode = self.model.training + self.model.eval() + trans_wall_time = 0.0 translations = [] # type: List[List[str]] with ExitStack() as exit_stack: @@ -172,7 +177,11 @@ def decode_and_evaluate(self, output_name: Optional[str] = None) -> Dict[str, fl avg_time = trans_wall_time / len(self.targets_sentences[0]) translations = list(zip(*translations)) # type: ignore + # Restore original model mode + self.model.train(original_mode) + # 2. Evaluate + metrics = {C.BLEU: evaluate.raw_corpus_bleu(hypotheses=translations[0], references=self.targets_sentences[0], offset=0.01), @@ -202,9 +211,16 @@ def decode_and_evaluate(self, output_name: Optional[str] = None) -> Dict[str, fl return metrics def warmup(self): - """Translate a single sentence to warm up the model""" + """ + Translate a single sentence to warm up the model. Set the model to eval + mode for tracing, translate the sentence, then set the model back to its + original mode. + """ + original_mode = self.model.training + self.model.eval() one_sentence = [inference.make_input_from_multiple_strings(0, self.inputs_sentences[0])] _ = self.translator.translate(one_sentence) + self.model.train(original_mode) def parallel_subsample(parallel_sequences: List[List[Any]], sample_size: int, seed: int) -> List[Any]: diff --git a/sockeye/model.py b/sockeye/model.py index 8b81902bc..e595cd79f 100644 --- a/sockeye/model.py +++ b/sockeye/model.py @@ -116,6 +116,8 @@ def __init__(self, vocab_size=self.config.vocab_target_size, weight=output_weight) if self.inference_only: + # Running this layer scripted with a newly initialized model can + # cause an overflow error. self.output_layer = pt.jit.script(self.output_layer) self.factor_output_layers = pt.nn.ModuleList() @@ -167,19 +169,14 @@ def encode(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None) -> :param valid_length: Optional Tensor of sequence lengths within this batch. Shape: (batch_size,) :return: Encoder outputs, encoded output lengths """ - - if self.inference_only: - if self.traced_embedding_source is None: - logger.debug("Tracing embedding_source") - self.traced_embedding_source = pt.jit.trace(self.embedding_source, inputs) - source_embed = self.traced_embedding_source(inputs) - if self.traced_encoder is None: - logger.debug("Tracing encoder") - self.traced_encoder = pt.jit.trace(self.encoder, (source_embed, valid_length)) - source_encoded, source_encoded_length = self.traced_encoder(source_embed, valid_length) - else: - source_embed = self.embedding_source(inputs) - source_encoded, source_encoded_length = self.encoder(source_embed, valid_length) + if self.traced_embedding_source is None: + logger.debug("Tracing embedding_source") + self.traced_embedding_source = pt.jit.trace(self.embedding_source, inputs) + source_embed = self.traced_embedding_source(inputs) + if self.traced_encoder is None: + logger.debug("Tracing encoder") + self.traced_encoder = pt.jit.trace(self.encoder, (source_embed, valid_length)) + source_encoded, source_encoded_length = self.traced_encoder(source_embed, valid_length) return source_encoded, source_encoded_length def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None, @@ -238,29 +235,20 @@ def decode_step(self, :return: logits, list of new model states, other target factor logits. """ - if self.inference_only: - decode_step_inputs = [step_input, states] - if vocab_slice_ids is not None: - decode_step_inputs.append(vocab_slice_ids) - if self.traced_decode_step is None: - logger.debug("Tracing decode step") - decode_step_module = _DecodeStep(self.embedding_target, - self.decoder, - self.output_layer, - self.factor_output_layers) - self.traced_decode_step = pt.jit.trace(decode_step_module, decode_step_inputs) - # the traced module returns a flat list of tensors - decode_step_outputs = self.traced_decode_step(*decode_step_inputs) - step_output, *target_factor_outputs = decode_step_outputs[:self.num_target_factors] - new_states = decode_step_outputs[self.num_target_factors:] - else: - target_embed = self.embedding_target(step_input.unsqueeze(1)) - decoder_out, new_states = self.decoder(target_embed, states) - decoder_out = decoder_out.squeeze(1) - # step_output: (batch_size, target_vocab_size or vocab_slice_ids) - step_output = self.output_layer(decoder_out, vocab_slice_ids) - target_factor_outputs = [fol(decoder_out) for fol in self.factor_output_layers] - + decode_step_inputs = [step_input, states] + if vocab_slice_ids is not None: + decode_step_inputs.append(vocab_slice_ids) + if self.traced_decode_step is None: + logger.debug("Tracing decode step") + decode_step_module = _DecodeStep(self.embedding_target, + self.decoder, + self.output_layer, + self.factor_output_layers) + self.traced_decode_step = pt.jit.trace(decode_step_module, decode_step_inputs) + # the traced module returns a flat list of tensors + decode_step_outputs = self.traced_decode_step(*decode_step_inputs) + step_output, *target_factor_outputs = decode_step_outputs[:self.num_target_factors] + new_states = decode_step_outputs[self.num_target_factors:] return step_output, new_states, target_factor_outputs def forward(self, source, source_length, target, target_length): # pylint: disable=arguments-differ