Add SSRU layer and decoder (#851)
- Adds layers.SSRU, which implements a Simpler Simple Recurrent Unit as described by Kim et al., 2019.
- Adds ssru_transformer option to --decoder, which enables the usage of SSRUs as a replacement for the decoder-side self-attention layers.
- Reduces the number of arguments for MultiHeadSelfAttention.hybrid_forward().
previous_keys and previous_values should now be input together as previous_states, a list containing two symbols.
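For orientation: the SSRU of Kim et al. (2019) simplifies the SRU by dropping the reset gate and applying a ReLU directly to the cell state, so each decoding step needs only a forget gate and a linear input transform. The snippet below is a minimal NumPy sketch of that recurrence, intended as an illustration only; it is not the `layers.SSRU` block added by this commit, and the weight and function names are assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def ssru_step(x_t, c_prev, W, W_f, b_f):
    """One SSRU step (sketch): a forget gate blends the previous cell state
    with a linear transform of the input; the output is a ReLU of the cell state."""
    f_t = sigmoid(x_t @ W_f + b_f)                 # forget gate
    c_t = f_t * c_prev + (1.0 - f_t) * (x_t @ W)   # cell state update
    h_t = np.maximum(c_t, 0.0)                     # ReLU in place of tanh
    return h_t, c_t

# Toy usage: model size 4, three decoding steps.
rng = np.random.default_rng(0)
d = 4
W, W_f, b_f = rng.normal(size=(d, d)), rng.normal(size=(d, d)), np.zeros(d)
c = np.zeros(d)
for x_t in rng.normal(size=(3, d)):
    h_t, c = ssru_step(x_t, c, W, W_f, b_f)
```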
barbaradarques authored Aug 20, 2020
1 parent 2cbad61 commit 92a020a
Showing 10 changed files with 344 additions and 84 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,21 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [2.1.17]

### Added

- Added `layers.SSRU`, which implements a Simpler Simple Recurrent Unit as described in
  Kim et al., "From Research to Production and Back: Ludicrously Fast Neural Machine Translation", WNGT 2019.

- Added `ssru_transformer` option to `--decoder`, which enables the usage of SSRUs as a replacement for the decoder-side self-attention layers.

### Changed

- Reduced the number of arguments for `MultiHeadSelfAttention.hybrid_forward()`.
`previous_keys` and `previous_values` should now be input together as `previous_states`, a list containing two symbols.


## [2.1.16]

### Fixed
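The _Changed_ entry above only alters how cached attention tensors are passed around: `previous_keys` and `previous_values` now travel as one two-element `previous_states` list. Below is a hedged, self-contained illustration of that packing convention with a toy stand-in for the attention call; apart from the `previous_states` grouping taken from the changelog, every name and the signature itself are hypothetical.

```python
from typing import List, Tuple

def attention_step(inputs: List[float],
                   previous_states: List[List[float]]) -> Tuple[List[float], List[List[float]]]:
    """Toy stand-in for a self-attention step (hypothetical signature):
    cached keys and values arrive packed together instead of as two arguments."""
    previous_keys, previous_values = previous_states   # unpack the two cached tensors
    new_keys = previous_keys + inputs                   # pretend to append new projections
    new_values = previous_values + inputs
    return inputs, [new_keys, new_values]               # output plus repacked states

# Caller side: pack the two cached tensors into one list rather than passing them separately.
output, new_states = attention_step([1.0, 2.0], previous_states=[[0.0], [0.0]])
```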
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '2.1.16'
__version__ = '2.1.17'
4 changes: 3 additions & 1 deletion sockeye/arguments.py
@@ -634,7 +634,9 @@ def add_model_parameters(params):
model_params.add_argument('--decoder',
choices=C.DECODERS,
default=C.TRANSFORMER_TYPE,
help="Type of encoder. Default: %(default)s.")
help="Type of decoder. Default: %(default)s. "
"'ssru_transformer' uses Simpler Simple Recurrent Units (Kim et al, 2019) "
"as replacement for self-attention layers.")

model_params.add_argument('--num-layers',
type=multiple_values(num_values=2, greater_or_equal=1),
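With this option registered, switching the decoder to SSRUs at training time is a matter of passing the new choice, e.g. roughly `python -m sockeye.train ... --decoder ssru_transformer` (the remaining arguments depend on your setup); the default stays the standard transformer decoder.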
10 changes: 8 additions & 2 deletions sockeye/constants.py
@@ -51,6 +51,9 @@
DEFAULT_OUTPUT_LAYER_PREFIX = "target_output_"
LENRATIOS_OUTPUT_LAYER_PREFIX = "length_ratio_layer_"

# SSRU
SSRU_PREFIX = "ssru_"

# embedding prefixes
SOURCE_EMBEDDING_PREFIX = "source_" + EMBEDDING_PREFIX
SOURCE_POSITIONAL_EMBEDDING_PREFIX = "source_pos_" + EMBEDDING_PREFIX
@@ -72,8 +75,10 @@
# available encoders
ENCODERS = [TRANSFORMER_TYPE]

# available decoder
DECODERS = [TRANSFORMER_TYPE]
# TODO replace options list (e.g ENCODERS, DECODERS, ...) with Enum classes
# available decoders
SSRU_TRANSFORMER = SSRU_PREFIX + TRANSFORMER_TYPE
DECODERS = [TRANSFORMER_TYPE, SSRU_TRANSFORMER]

# positional embeddings
NO_POSITIONAL_EMBEDDING = "none"
@@ -116,6 +121,7 @@

# default decoder prefixes
TRANSFORMER_DECODER_PREFIX = DECODER_PREFIX + "transformer_"
TRANSFORMER_SSRU_DECODER_PREFIX = DECODER_PREFIX + SSRU_TRANSFORMER

# Activation types
RELU = "relu"
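The TODO added above suggests replacing the plain option lists (`ENCODERS`, `DECODERS`, ...) with Enum classes. One possible shape for that refactor is sketched below; this is a hypothetical illustration, not code from this commit, and it assumes the value of `TRANSFORMER_TYPE` is `"transformer"`.

```python
from enum import Enum

class DecoderType(str, Enum):
    """Hypothetical Enum replacement for the DECODERS list."""
    TRANSFORMER = "transformer"            # assumed value of TRANSFORMER_TYPE
    SSRU_TRANSFORMER = "ssru_transformer"  # SSRU_PREFIX + TRANSFORMER_TYPE

# argparse choices could then be derived from the enum members:
DECODER_CHOICES = [d.value for d in DecoderType]
assert "ssru_transformer" in DECODER_CHOICES
```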
81 changes: 42 additions & 39 deletions sockeye/decoder.py
@@ -16,6 +16,7 @@
"""
import logging
from abc import abstractmethod
from itertools import islice
from typing import Dict, List, Optional, Tuple, Union, Type

import mxnet as mx
@@ -66,7 +67,7 @@ def get_decoder(cls, config: DecoderConfig, inference_only: bool, prefix: str, d
Creates decoder based on config type.
:param config: Decoder config.
:param inference_ony: Create a decoder that is only used for inference.
:param inference_only: Create a decoder that is only used for inference.
:param prefix: Prefix to prepend for decoder.
:param dtype: Data type for weights.
@@ -114,7 +115,7 @@ def get_num_hidden(self):
class TransformerDecoder(Decoder, mx.gluon.HybridBlock):
"""
Transformer decoder as in Vaswani et al, 2017: Attention is all you need.
In training, computation scores for each position of the known target sequence are compouted in parallel,
In training, computation scores for each position of the known target sequence are computed in parallel,
yielding most of the speedup.
At inference time, the decoder block is evaluated again and again over a maximum length input sequence that is
initially filled with zeros and grows during beam search with predicted tokens. Appropriate masking at every
@@ -147,7 +148,8 @@ def __init__(self,
name="bias")
self.layers = mx.gluon.nn.HybridSequential()
for i in range(config.num_layers):
self.layers.add(transformer.TransformerDecoderBlock(config, prefix="%d_" % i, dtype=dtype))
self.layers.add(transformer.TransformerDecoderBlock(config, prefix="%d_" % i, dtype=dtype,
inference_only=self.inference_only))

self.final_process = transformer.TransformerProcessBlock(sequence=config.preprocess_sequence,
dropout=config.dropout_prepost,
@@ -164,7 +166,9 @@ def state_structure(self) -> str:
structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers * 2
else:
structure += C.STEP_STATE + C.ENCODER_STATE + C.BIAS_STATE
structure += C.DECODER_STATE * self.config.num_layers * 2

total_num_states = sum(layer.num_state_tensors for layer in self.layers)
structure += C.DECODER_STATE * total_num_states

return structure

@@ -177,7 +181,7 @@ def init_state_from_encoder(self,
At inference, this method returns the following state tuple:
valid length bias, step state,
[projected encoder attention keys, projected encoder attention values] * num_layers,
[self attention dummies] * num_layers.
[autoregressive state dummies] * num_layers.
:param encoder_outputs: Encoder outputs. Shape: (batch, source_length, encoder_dim).
:param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch,).
@@ -202,14 +206,13 @@
states = [step, encoder_outputs, source_mask]

batch_size = encoder_outputs.shape[0]
# shape: (batch, heads, length, depth_per_head)
self_att_key_value_dummies = [mx.nd.zeros((batch_size,
self.config.attention_heads,
1,
self.config.model_size // self.config.attention_heads),
ctx=encoder_outputs.context,
dtype=encoder_outputs.dtype)] * self.config.num_layers * 2
states += self_att_key_value_dummies
dummy_autoregr_states = [mx.nd.zeros(layer.get_states_shape(batch_size),
ctx=encoder_outputs.context,
dtype=encoder_outputs.dtype)
for layer in self.layers
for _ in range(layer.num_state_tensors)]

states += dummy_autoregr_states

return states

@@ -256,7 +259,7 @@ def forward(self, step_input, states):
states = [steps] + states

# run decoder op
target, self_attention_key_values = super().forward(step_input, states)
target, autoregr_states = super().forward(step_input, states)

if is_inference:
# During inference, length dimension of decoder output has size 1, squeeze it
@@ -269,39 +272,36 @@
if self.inference_only:
# pass in cached encoder states
encoder_attention_keys_values = states[2:2 + self.config.num_layers * 2]
new_states = [step, states[1]] + encoder_attention_keys_values + self_attention_key_values
new_states = [step, states[1]] + encoder_attention_keys_values + autoregr_states
else:
encoder_outputs = states[1]
source_mask = states[2]
new_states = [step, encoder_outputs, source_mask] + self_attention_key_values
new_states = [step, encoder_outputs, source_mask] + autoregr_states

assert len(new_states) == len(states)
else:
new_states = None # we don't care about states in training
return target, new_states

def hybrid_forward(self, F, step_input, states):
mask = None
if self.inference_only:
# No autoregressive mask needed for decoding
mask = None

steps, source_mask, *other = states

source_encoded = None # use constant pre-computed key value projections from the states
enc_att_kv = other[:self.config.num_layers * 2]
enc_att_kv = [enc_att_kv[i:i + 2] for i in range(0, len(enc_att_kv), 2)]
self_att_kv = other[self.config.num_layers * 2:]
self_att_kv = [self_att_kv[i:i + 2] for i in range(0, len(self_att_kv), 2)]
autoregr_states = other[self.config.num_layers * 2:]
else:
mask = self.autoregressive_bias(step_input) # mask: (1, length, length)

steps, source_encoded, source_mask, *other = states

self_att_kv = other
self_att_kv = [self_att_kv[i:i + 2] for i in range(0, len(self_att_kv), 2)]

if any(layer.needs_mask for layer in self.layers):
mask = self.autoregressive_bias(step_input) # mask: (1, length, length)
steps, source_encoded, source_mask, *autoregr_states = states
enc_att_kv = [(None, None) for _ in range(self.config.num_layers)]

if any(layer.num_state_tensors > 1 for layer in self.layers):
# separates autoregressive states by layer
states_iter = iter(autoregr_states)
autoregr_states = [list(islice(states_iter, 0, layer.num_state_tensors)) for layer in self.layers]

# Fold the heads of source_mask (batch_size, num_heads, seq_len) -> (batch_size * num_heads, 1, seq_len)
source_mask = F.expand_dims(F.reshape(source_mask, shape=(-3, -2)), axis=1)

@@ -311,18 +311,21 @@ def hybrid_forward(self, F, step_input, states):
if self.config.dropout_prepost > 0.0:
target = F.Dropout(data=target, p=self.config.dropout_prepost)

new_self_att_kv = [] # type: List[Tuple]
for layer, (self_att_k, self_att_v), (enc_att_k, enc_att_v) in zip(self.layers, self_att_kv, enc_att_kv):
target, new_self_att_k, new_self_att_v = layer(target,
mask,
source_encoded,
source_mask,
self_att_k, self_att_v,
enc_att_k, enc_att_v)
new_self_att_kv += [new_self_att_k, new_self_att_v]
new_autoregr_states = []
for layer, layer_autoregr_state, (enc_att_k, enc_att_v) in zip(self.layers, autoregr_states, enc_att_kv):
target, new_layer_autoregr_state = layer(target,
mask,
source_encoded,
source_mask,
layer_autoregr_state,
enc_att_k, enc_att_v)

new_autoregr_states += [*new_layer_autoregr_state]
# NOTE: the list expansion is needed in order to handle both a tuple (of Symbols) and a Symbol as a new state

target = self.final_process(target, None)

return target, new_self_att_kv
return target, new_autoregr_states

def get_num_hidden(self):
return self.config.model_size