From f68a2174215796254d1c2cac7bce23aecc779bad Mon Sep 17 00:00:00 2001
From: Anna Currey
Date: Thu, 27 Aug 2020 09:30:01 -0400
Subject: [PATCH] Add ability to cache encoder outputs of model (#858)

---
 CHANGELOG.md        |  6 ++++++
 sockeye/__init__.py |  2 +-
 sockeye/model.py    | 49 +++++++++++++++++++++++++++++++++++++--------
 3 files changed, 48 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 912575625..ab796d5c0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [2.1.21]
+
+### Added
+
+- Added an optional ability to cache encoder outputs of the model.
+
 ## [2.1.20]
 
 ### Fixed
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index f8cfe1137..861dea7f0 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '2.1.20'
+__version__ = '2.1.21'
diff --git a/sockeye/model.py b/sockeye/model.py
index bcd2fd157..9cde10ba6 100644
--- a/sockeye/model.py
+++ b/sockeye/model.py
@@ -16,6 +16,7 @@
 import logging
 import os
 from typing import cast, Dict, Optional, Tuple, Union, List
+from functools import lru_cache
 
 import mxnet as mx
 
 from sockeye import __version__
@@ -104,6 +105,7 @@ def __init__(self,
                  config: ModelConfig,
                  inference_only: bool = False,
                  mc_dropout: bool = False,
+                 forward_pass_cache_size: int = 0,
                  prefix: str = '',
                  **kwargs) -> None:
         super().__init__(prefix=prefix, **kwargs)
@@ -111,6 +113,10 @@ def __init__(self,
         logger.info("%s", self.config)
         self.dtype = config.dtype
         self.mc_dropout = mc_dropout
+        self.forward_pass_cache_size = forward_pass_cache_size
+        self.embed_and_encode = self._embed_and_encode
+        if self.forward_pass_cache_size > 0:
+            self.embed_and_encode = self._cache_wrapper(self._embed_and_encode)
 
         with self.name_scope():
             # source & target embeddings
@@ -198,6 +204,23 @@ def encode_and_initialize(self, inputs, valid_length=None, constant_length_ratio
 
         return states, predicted_output_length
 
+    def _embed_and_encode(self, source, source_length, target, target_length):
+        """
+        Encode the input sequence, embed the target sequence, and initialize the decoder.
+        Used for training.
+
+        :param source: Source input data.
+        :param source_length: Length of source inputs.
+        :param target: Target input data.
+        :param target_length: Length of target inputs.
+        :return: Encoder outputs and lengths, target embeddings, and decoder initial states.
+        """
+        source_embed, source_embed_length = self.embedding_source(source, source_length)
+        target_embed, target_embed_length = self.embedding_target(target, target_length)
+        source_encoded, source_encoded_length = self.encoder(source_embed, source_embed_length)
+        states = self.decoder.init_state_from_encoder(source_encoded, source_encoded_length)
+        return source_encoded, source_encoded_length, target_embed, states
+
     def decode_step(self, step_input, states, vocab_slice_ids=None):
         """
         One step decoding of the translation model.
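The hunks above implement the cache as a small memoization wrapper: at construction time, self.embed_and_encode points at the plain _embed_and_encode method, and only when forward_pass_cache_size > 0 is it rebound to an lru_cache-wrapped version (built by _cache_wrapper, added further down in this patch). Since functools.lru_cache keys on the hash of its positional arguments, and MXNet NDArrays hash in effect by object identity rather than by content, the cache can only hit when the very same array objects are passed in again. A minimal standalone sketch of the pattern follows; the class and names here are hypothetical, not part of the patch:

    from functools import lru_cache

    class CachedModel:
        def __init__(self, cache_size: int = 0):
            # Default to the uncached method; rebind to a memoized wrapper on demand.
            self.compute = self._compute
            if cache_size > 0:
                self.compute = self._cache_wrapper(self._compute, cache_size)

        def _compute(self, x):
            print("computing", x)  # visible side effect to show hits vs. misses
            return 2 * x

        @staticmethod
        def _cache_wrapper(func, maxsize):
            # Same shape as the patch's _cache_wrapper: close over the bound
            # method and hand back an LRU-cached callable.
            @lru_cache(maxsize=maxsize)
            def cached(*args):
                return func(*args)
            return cached

    model = CachedModel(cache_size=4)
    model.compute(3)   # prints "computing 3"
    model.compute(3)   # served from the cache, no print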
@@ -237,11 +260,9 @@ def decode_step(self, step_input, states, vocab_slice_ids=None):
         return step_output, new_states, step_additional_outputs
 
     def forward(self, source, source_length, target, target_length):  # pylint: disable=arguments-differ
-        source_embed, source_embed_length = self.embedding_source(source, source_length)
-        target_embed, target_embed_length = self.embedding_target(target, target_length)
-        source_encoded, source_encoded_length = self.encoder(source_embed, source_embed_length)
+        source_encoded, source_encoded_length, target_embed, states = self.embed_and_encode(source, source_length,
+                                                                                            target, target_length)
 
-        states = self.decoder.init_state_from_encoder(source_encoded, source_encoded_length)
         target = self.decoder.decode_seq(target_embed, states=states)
 
         output = self.output_layer(target, None)
@@ -460,6 +481,12 @@ def length_ratio_std(self) -> float:
     def output_layer_vocab_size(self) -> int:
         return self.output_layer.vocab_size
 
+    def _cache_wrapper(self, class_func):
+        @lru_cache(maxsize=self.forward_pass_cache_size)
+        def cache_func(*args):
+            return class_func(*args)
+        return cache_func
+
 
 def load_model(model_folder: str,
                context: Union[List[mx.context.Context], mx.context.Context] = mx.cpu(),
@@ -470,7 +497,8 @@ def load_model(model_folder: str,
                mc_dropout: bool = False,
                for_disk_saving: Optional[str] = None,
                allow_missing: bool = False,
-               set_grad_req_null: bool = True) -> Tuple[SockeyeModel, List[vocab.Vocab], vocab.Vocab]:
+               set_grad_req_null: bool = True,
+               forward_pass_cache_size: int = 0) -> Tuple[SockeyeModel, List[vocab.Vocab], vocab.Vocab]:
     """
     Load a model from model_folder.
 
@@ -489,6 +517,7 @@ def load_model(model_folder: str,
                            for writing to disk as float32 with precomputed scaling factors.
     :param allow_missing: Allow missing parameters in the loaded model.
     :param set_grad_req_null: Set grad_req to null for model parameters.
+    :param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of the forward pass.
     :return: The loaded model, source vocabularies, target vocabulary.
     """
     source_vocabs = vocab.load_source_vocabs(model_folder)
@@ -527,7 +556,8 @@ def load_model(model_folder: str,
     if quantizing:
         model_config.dtype = C.DTYPE_INT8  # Ensure the scaling factor parameters are created.
-    model = SockeyeModel(model_config, inference_only=inference_only, mc_dropout=mc_dropout)
+    model = SockeyeModel(model_config, inference_only=inference_only, mc_dropout=mc_dropout,
+                         forward_pass_cache_size=forward_pass_cache_size)
     model.initialize(ctx=context)
     if model_config.dtype != C.DTYPE_INT8:
         # If model_config.dtype is int8, then the above model construction
@@ -593,7 +623,8 @@ def load_models(context: Union[List[mx.context.Context], mx.context.Context],
                 inference_only: bool = False,
                 mc_dropout: bool = False,
                 allow_missing: bool = False,
-                set_grad_req_null: bool = True) -> Tuple[List[SockeyeModel], List[vocab.Vocab], vocab.Vocab]:
+                set_grad_req_null: bool = True,
+                forward_pass_cache_size: int = 0) -> Tuple[List[SockeyeModel], List[vocab.Vocab], vocab.Vocab]:
     """
     Loads a list of models for inference.
 
@@ -606,6 +637,7 @@ def load_models(context: Union[List[mx.context.Context], mx.context.Context],
     :param mc_dropout: Turn on dropout during inference.
     :param allow_missing: Allow missing parameters in the loaded models.
     :param set_grad_req_null: Set grad_req to null for model parameters.
+    :param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of the forward pass.
     :return: List of models, source vocabularies, target vocabulary.
     """
     logger.info("Loading %d model(s) from %s ...", len(model_folders), model_folders)
@@ -628,7 +660,8 @@ def load_models(context: Union[List[mx.context.Context], mx.context.Context],
                                               inference_only=inference_only,
                                               mc_dropout=mc_dropout,
                                               allow_missing=allow_missing,
-                                              set_grad_req_null=set_grad_req_null)
+                                              set_grad_req_null=set_grad_req_null,
+                                              forward_pass_cache_size=forward_pass_cache_size)
         models.append(model)
         source_vocabs.append(src_vcbs)
         target_vocabs.append(trg_vcb)
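For illustration, a usage sketch of the new parameter, assuming a trained Sockeye model directory at ./model (a hypothetical path): with forward_pass_cache_size > 0, repeated forward passes over the same batch reuse the cached embedding and encoder results instead of recomputing them.

    import mxnet as mx
    from sockeye.model import load_model

    # './model' is a placeholder for a trained Sockeye model folder.
    # forward_pass_cache_size=1 memoizes the most recent embed-and-encode
    # result, so a second forward pass over the identical arrays skips the
    # encoder entirely.
    model, source_vocabs, target_vocab = load_model('./model',
                                                    context=mx.cpu(),
                                                    forward_pass_cache_size=1)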