Add ability to cache encoder outputs of model (#858)
annacurrey authored Aug 27, 2020
1 parent 7a912fe commit f68a217
Showing 3 changed files with 48 additions and 9 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [2.1.21]

### Added

- Added an optional ability to cache encoder outputs of the model.

## [2.1.20]

### Fixed
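As a usage illustration (not part of the commit): the forward_pass_cache_size option is threaded from load_model through to SockeyeModel in the model.py diff below, so callers opt in when loading a model. A minimal sketch, with a hypothetical model folder and an arbitrary cache size:

    import mxnet as mx
    from sockeye.model import load_model

    # forward_pass_cache_size bounds the internal lru_cache introduced by
    # this commit; the default of 0 leaves caching disabled.
    model, source_vocabs, target_vocab = load_model(
        'path/to/trained_model',       # hypothetical model folder
        context=mx.cpu(),
        forward_pass_cache_size=64)    # arbitrary example size

Caching helps when identical (source, target) batches pass through the model repeatedly, since the embedding and encoder computations are then served from the cache instead of being recomputed.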
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '2.1.20'
__version__ = '2.1.21'
49 changes: 41 additions & 8 deletions sockeye/model.py
@@ -16,6 +16,7 @@
import logging
import os
from typing import cast, Dict, Optional, Tuple, Union, List
from functools import lru_cache

import mxnet as mx
from sockeye import __version__
@@ -104,13 +105,18 @@ def __init__(self,
config: ModelConfig,
inference_only: bool = False,
mc_dropout: bool = False,
forward_pass_cache_size: int = 0,
prefix: str = '',
**kwargs) -> None:
super().__init__(prefix=prefix, **kwargs)
self.config = copy.deepcopy(config)
logger.info("%s", self.config)
self.dtype = config.dtype
self.mc_dropout = mc_dropout
self.forward_pass_cache_size = forward_pass_cache_size
self.embed_and_encode = self._embed_and_encode
if self.forward_pass_cache_size > 0:
self.embed_and_encode = self._cache_wrapper(self._embed_and_encode)

with self.name_scope():
# source & target embeddings
@@ -198,6 +204,23 @@ def encode_and_initialize(self, inputs, valid_length=None, constant_length_ratio

return states, predicted_output_length

def _embed_and_encode(self, source, source_length, target, target_length):
"""
Encode the input sequence, embed the target sequence, and initialize the decoder.
Used for training.
:param source: Source input data.
:param source_length: Length of source inputs.
:param target: Target input data.
:param target_length: Length of target inputs.
:return: encoder outputs and lengths, target embeddings, and decoder initial states
"""
source_embed, source_embed_length = self.embedding_source(source, source_length)
target_embed, target_embed_length = self.embedding_target(target, target_length)
source_encoded, source_encoded_length = self.encoder(source_embed, source_embed_length)
states = self.decoder.init_state_from_encoder(source_encoded, source_encoded_length)
return source_encoded, source_encoded_length, target_embed, states

def decode_step(self, step_input, states, vocab_slice_ids=None):
"""
One step decoding of the translation model.
@@ -237,11 +260,9 @@ def decode_step(self, step_input, states, vocab_slice_ids=None):
return step_output, new_states, step_additional_outputs

def forward(self, source, source_length, target, target_length): # pylint: disable=arguments-differ
source_embed, source_embed_length = self.embedding_source(source, source_length)
target_embed, target_embed_length = self.embedding_target(target, target_length)
source_encoded, source_encoded_length = self.encoder(source_embed, source_embed_length)
source_encoded, source_encoded_length, target_embed, states = self.embed_and_encode(source, source_length,
target, target_length)

states = self.decoder.init_state_from_encoder(source_encoded, source_encoded_length)
target = self.decoder.decode_seq(target_embed, states=states)

output = self.output_layer(target, None)
@@ -460,6 +481,12 @@ def length_ratio_std(self) -> float:
def output_layer_vocab_size(self) -> int:
return self.output_layer.vocab_size

def _cache_wrapper(self, class_func):
@lru_cache(maxsize=self.forward_pass_cache_size)
def cache_func(*args):
return class_func(*args)
return cache_func


def load_model(model_folder: str,
context: Union[List[mx.context.Context], mx.context.Context] = mx.cpu(),
@@ -470,7 +497,8 @@ def load_model(model_folder: str,
mc_dropout: bool = False,
for_disk_saving: Optional[str] = None,
allow_missing: bool = False,
set_grad_req_null: bool = True) -> Tuple[SockeyeModel, List[vocab.Vocab], vocab.Vocab]:
set_grad_req_null: bool = True,
forward_pass_cache_size: int = 0) -> Tuple[SockeyeModel, List[vocab.Vocab], vocab.Vocab]:
"""
Load a model from model_folder.
@@ -489,6 +517,7 @@
for writing to disk as float32 with precomputed scaling factors.
:param allow_missing: Allow missing parameters in the loaded model.
:param set_grad_req_null: Set grad_req to null for model parameters.
:param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of the forward pass.
:return: The loaded model, source vocabularies, and target vocabulary.
"""
source_vocabs = vocab.load_source_vocabs(model_folder)
@@ -527,7 +556,8 @@
if quantizing:
model_config.dtype = C.DTYPE_INT8 # Ensure the scaling factor parameters are created.

model = SockeyeModel(model_config, inference_only=inference_only, mc_dropout=mc_dropout)
model = SockeyeModel(model_config, inference_only=inference_only, mc_dropout=mc_dropout,
forward_pass_cache_size=forward_pass_cache_size)
model.initialize(ctx=context)
if model_config.dtype != C.DTYPE_INT8:
# If model_config.dtype is int8, then the above model construction
@@ -593,7 +623,8 @@ def load_models(context: Union[List[mx.context.Context], mx.context.Context],
inference_only: bool = False,
mc_dropout: bool = False,
allow_missing: bool = False,
set_grad_req_null: bool = True) -> Tuple[List[SockeyeModel], List[vocab.Vocab], vocab.Vocab]:
set_grad_req_null: bool = True,
forward_pass_cache_size: int = 0) -> Tuple[List[SockeyeModel], List[vocab.Vocab], vocab.Vocab]:
"""
Loads a list of models for inference.
@@ -606,6 +637,7 @@ def load_models(context: Union[List[mx.context.Context], mx.context.Context],
:param mc_dropout: Turn on dropout during inference.
:param allow_missing: Allow missing parameters in the loaded models.
:param set_grad_req_null: Set grad_req to null for model parameters.
:param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of the forward pass.
:return: List of models, source vocabularies, and target vocabulary.
"""
logger.info("Loading %d model(s) from %s ...", len(model_folders), model_folders)
@@ -628,7 +660,8 @@ def load_models(context: Union[List[mx.context.Context], mx.context.Context],
inference_only=inference_only,
mc_dropout=mc_dropout,
allow_missing=allow_missing,
set_grad_req_null=set_grad_req_null)
set_grad_req_null=set_grad_req_null,
forward_pass_cache_size=forward_pass_cache_size)
models.append(model)
source_vocabs.append(src_vcbs)
target_vocabs.append(trg_vcb)
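For readers unfamiliar with the pattern in _cache_wrapper above: SockeyeModel points self.embed_and_encode at the raw _embed_and_encode method and, when forward_pass_cache_size > 0, swaps in a functools.lru_cache wrapper, so forward() transparently reuses encoder outputs for repeated inputs. A minimal self-contained sketch of the same idea (toy class and values, not Sockeye code; note that lru_cache keys on the call arguments, so they must be hashable):

    from functools import lru_cache

    class TinyModel:
        def __init__(self, cache_size: int = 0) -> None:
            self.cache_size = cache_size
            # Default: run the computation directly on every forward pass.
            self.embed_and_encode = self._embed_and_encode
            if cache_size > 0:
                # Opt-in: memoize, evicting least-recently-used entries.
                self.embed_and_encode = self._cache_wrapper(self._embed_and_encode)

        def _embed_and_encode(self, source):
            print("cache miss, computing for", source)  # visible on misses only
            return tuple(2 * x for x in source)

        def _cache_wrapper(self, class_func):
            @lru_cache(maxsize=self.cache_size)
            def cache_func(*args):
                return class_func(*args)
            return cache_func

    model = TinyModel(cache_size=2)
    model.embed_and_encode((1, 2, 3))  # computed, prints
    model.embed_and_encode((1, 2, 3))  # served from the cache, no print

One design consequence worth noting: because cache_func closes over the bound method, each model instance gets its own cache, and cached entries keep references to the returned outputs until they are evicted.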
