From 92a020a25cbe75935c700ce2f29b286b31a87189 Mon Sep 17 00:00:00 2001 From: Barbara Darques Barros Date: Thu, 20 Aug 2020 20:20:06 +0200 Subject: [PATCH] Add SSRU layer and decoder (#851) - Adds layers.SSRU, which implements a Simpler Simple Recurrent Unit as described by Kim et al, 2019. - Adds ssru_transformer option to --decoder, which enables the usage of SSRUs as a replacement for the decoder-side self-attention layers. - Reduces the number of arguments for MultiHeadSelfAttention.hybrid_forward(). previous_keys and previous_values should now be input together as previous_states, a list containing two symbols. --- CHANGELOG.md | 15 ++ sockeye/__init__.py | 2 +- sockeye/arguments.py | 4 +- sockeye/constants.py | 10 +- sockeye/decoder.py | 81 ++++++----- sockeye/layers.py | 200 +++++++++++++++++++++++++- sockeye/train.py | 6 +- sockeye/transformer.py | 77 ++++++---- test/integration/test_seq_copy_int.py | 20 +-- test/system/test_seq_copy_sys.py | 13 +- 10 files changed, 344 insertions(+), 84 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b82c6d6f..b94091236 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,21 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [2.1.17] + +### Added + +- Added `layers.SSRU`, which implements a Simpler Simple Recurrent Unit as described in +Kim et al, "From Research to Production and Back: Ludicrously Fast Neural Machine Translation" WNGT 2019. + +- Added `ssru_transformer` option to `--decoder`, which enables the usage of SSRUs as a replacement for the decoder-side self-attention layers. + +### Changed + +- Reduced the number of arguments for `MultiHeadSelfAttention.hybrid_forward()`. + `previous_keys` and `previous_values` should now be input together as `previous_states`, a list containing two symbols. + + ## [2.1.16] ### Fixed diff --git a/sockeye/__init__.py b/sockeye/__init__.py index 3481e460c..6061b5e5d 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '2.1.16' +__version__ = '2.1.17' diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 65848a83b..b0545b846 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -634,7 +634,9 @@ def add_model_parameters(params): model_params.add_argument('--decoder', choices=C.DECODERS, default=C.TRANSFORMER_TYPE, - help="Type of encoder. Default: %(default)s.") + help="Type of decoder. Default: %(default)s. " + "'ssru_transformer' uses Simpler Simple Recurrent Units (Kim et al, 2019) " + "as replacement for self-attention layers.") model_params.add_argument('--num-layers', type=multiple_values(num_values=2, greater_or_equal=1), diff --git a/sockeye/constants.py b/sockeye/constants.py index 9b6bae757..921aa2f39 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -51,6 +51,9 @@ DEFAULT_OUTPUT_LAYER_PREFIX = "target_output_" LENRATIOS_OUTPUT_LAYER_PREFIX = "length_ratio_layer_" +# SSRU +SSRU_PREFIX = "ssru_" + # embedding prefixes SOURCE_EMBEDDING_PREFIX = "source_" + EMBEDDING_PREFIX SOURCE_POSITIONAL_EMBEDDING_PREFIX = "source_pos_" + EMBEDDING_PREFIX @@ -72,8 +75,10 @@ # available encoders ENCODERS = [TRANSFORMER_TYPE] -# available decoder -DECODERS = [TRANSFORMER_TYPE] +# TODO replace options list (e.g ENCODERS, DECODERS, ...) with Enum classes +# available decoders +SSRU_TRANSFORMER = SSRU_PREFIX + TRANSFORMER_TYPE +DECODERS = [TRANSFORMER_TYPE, SSRU_TRANSFORMER] # positional embeddings NO_POSITIONAL_EMBEDDING = "none" @@ -116,6 +121,7 @@ # default decoder prefixes TRANSFORMER_DECODER_PREFIX = DECODER_PREFIX + "transformer_" +TRANSFORMER_SSRU_DECODER_PREFIX = DECODER_PREFIX + SSRU_TRANSFORMER # Activation types RELU = "relu" diff --git a/sockeye/decoder.py b/sockeye/decoder.py index 222fceece..cbb026923 100644 --- a/sockeye/decoder.py +++ b/sockeye/decoder.py @@ -16,6 +16,7 @@ """ import logging from abc import abstractmethod +from itertools import islice from typing import Dict, List, Optional, Tuple, Union, Type import mxnet as mx @@ -66,7 +67,7 @@ def get_decoder(cls, config: DecoderConfig, inference_only: bool, prefix: str, d Creates decoder based on config type. :param config: Decoder config. - :param inference_ony: Create a decoder that is only used for inference. + :param inference_only: Create a decoder that is only used for inference. :param prefix: Prefix to prepend for decoder. :param dtype: Data type for weights. @@ -114,7 +115,7 @@ def get_num_hidden(self): class TransformerDecoder(Decoder, mx.gluon.HybridBlock): """ Transformer decoder as in Vaswani et al, 2017: Attention is all you need. - In training, computation scores for each position of the known target sequence are compouted in parallel, + In training, computation scores for each position of the known target sequence are computed in parallel, yielding most of the speedup. At inference time, the decoder block is evaluated again and again over a maximum length input sequence that is initially filled with zeros and grows during beam search with predicted tokens. Appropriate masking at every @@ -147,7 +148,8 @@ def __init__(self, name="bias") self.layers = mx.gluon.nn.HybridSequential() for i in range(config.num_layers): - self.layers.add(transformer.TransformerDecoderBlock(config, prefix="%d_" % i, dtype=dtype)) + self.layers.add(transformer.TransformerDecoderBlock(config, prefix="%d_" % i, dtype=dtype, + inference_only=self.inference_only)) self.final_process = transformer.TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, @@ -164,7 +166,9 @@ def state_structure(self) -> str: structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers * 2 else: structure += C.STEP_STATE + C.ENCODER_STATE + C.BIAS_STATE - structure += C.DECODER_STATE * self.config.num_layers * 2 + + total_num_states = sum(layer.num_state_tensors for layer in self.layers) + structure += C.DECODER_STATE * total_num_states return structure @@ -177,7 +181,7 @@ def init_state_from_encoder(self, At inference, this method returns the following state tuple: valid length bias, step state, [projected encoder attention keys, projected encoder attention values] * num_layers, - [self attention dummies] * num_layers. + [autoregressive state dummies] * num_layers. :param encoder_outputs: Encoder outputs. Shape: (batch, source_length, encoder_dim). :param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch,). @@ -202,14 +206,13 @@ def init_state_from_encoder(self, states = [step, encoder_outputs, source_mask] batch_size = encoder_outputs.shape[0] - # shape: (batch, heads, length, depth_per_head) - self_att_key_value_dummies = [mx.nd.zeros((batch_size, - self.config.attention_heads, - 1, - self.config.model_size // self.config.attention_heads), - ctx=encoder_outputs.context, - dtype=encoder_outputs.dtype)] * self.config.num_layers * 2 - states += self_att_key_value_dummies + dummy_autoregr_states = [mx.nd.zeros(layer.get_states_shape(batch_size), + ctx=encoder_outputs.context, + dtype=encoder_outputs.dtype) + for layer in self.layers + for _ in range(layer.num_state_tensors)] + + states += dummy_autoregr_states return states @@ -256,7 +259,7 @@ def forward(self, step_input, states): states = [steps] + states # run decoder op - target, self_attention_key_values = super().forward(step_input, states) + target, autoregr_states = super().forward(step_input, states) if is_inference: # During inference, length dimension of decoder output has size 1, squeeze it @@ -269,11 +272,11 @@ def forward(self, step_input, states): if self.inference_only: # pass in cached encoder states encoder_attention_keys_values = states[2:2 + self.config.num_layers * 2] - new_states = [step, states[1]] + encoder_attention_keys_values + self_attention_key_values + new_states = [step, states[1]] + encoder_attention_keys_values + autoregr_states else: encoder_outputs = states[1] source_mask = states[2] - new_states = [step, encoder_outputs, source_mask] + self_attention_key_values + new_states = [step, encoder_outputs, source_mask] + autoregr_states assert len(new_states) == len(states) else: @@ -281,27 +284,24 @@ def forward(self, step_input, states): return target, new_states def hybrid_forward(self, F, step_input, states): + mask = None if self.inference_only: - # No autoregressive mask needed for decoding - mask = None - steps, source_mask, *other = states - source_encoded = None # use constant pre-computed key value projections from the states enc_att_kv = other[:self.config.num_layers * 2] enc_att_kv = [enc_att_kv[i:i + 2] for i in range(0, len(enc_att_kv), 2)] - self_att_kv = other[self.config.num_layers * 2:] - self_att_kv = [self_att_kv[i:i + 2] for i in range(0, len(self_att_kv), 2)] + autoregr_states = other[self.config.num_layers * 2:] else: - mask = self.autoregressive_bias(step_input) # mask: (1, length, length) - - steps, source_encoded, source_mask, *other = states - - self_att_kv = other - self_att_kv = [self_att_kv[i:i + 2] for i in range(0, len(self_att_kv), 2)] - + if any(layer.needs_mask for layer in self.layers): + mask = self.autoregressive_bias(step_input) # mask: (1, length, length) + steps, source_encoded, source_mask, *autoregr_states = states enc_att_kv = [(None, None) for _ in range(self.config.num_layers)] + if any(layer.num_state_tensors > 1 for layer in self.layers): + # separates autoregressive states by layer + states_iter = iter(autoregr_states) + autoregr_states = [list(islice(states_iter, 0, layer.num_state_tensors)) for layer in self.layers] + # Fold the heads of source_mask (batch_size, num_heads, seq_len) -> (batch_size * num_heads, 1, seq_len) source_mask = F.expand_dims(F.reshape(source_mask, shape=(-3, -2)), axis=1) @@ -311,18 +311,21 @@ def hybrid_forward(self, F, step_input, states): if self.config.dropout_prepost > 0.0: target = F.Dropout(data=target, p=self.config.dropout_prepost) - new_self_att_kv = [] # type: List[Tuple] - for layer, (self_att_k, self_att_v), (enc_att_k, enc_att_v) in zip(self.layers, self_att_kv, enc_att_kv): - target, new_self_att_k, new_self_att_v = layer(target, - mask, - source_encoded, - source_mask, - self_att_k, self_att_v, - enc_att_k, enc_att_v) - new_self_att_kv += [new_self_att_k, new_self_att_v] + new_autoregr_states = [] + for layer, layer_autoregr_state, (enc_att_k, enc_att_v) in zip(self.layers, autoregr_states, enc_att_kv): + target, new_layer_autoregr_state = layer(target, + mask, + source_encoded, + source_mask, + layer_autoregr_state, + enc_att_k, enc_att_v) + + new_autoregr_states += [*new_layer_autoregr_state] + # NOTE: the list expansion is needed in order to handle both a tuple (of Symbols) and a Symbol as a new state + target = self.final_process(target, None) - return target, new_self_att_kv + return target, new_autoregr_states def get_num_hidden(self): return self.config.model_size diff --git a/sockeye/layers.py b/sockeye/layers.py index d0ad783b2..52d1ff3a9 100644 --- a/sockeye/layers.py +++ b/sockeye/layers.py @@ -12,7 +12,8 @@ # permissions and limitations under the License. import logging -from typing import Optional, Union, Tuple +from abc import abstractmethod +from typing import Optional, Union, Tuple, List from functools import lru_cache import mxnet as mx @@ -418,7 +419,45 @@ def _attend(self, return contexts -class MultiHeadSelfAttention(MultiHeadAttentionBase): +class AutoregressiveLayer(mx.gluon.HybridBlock): + @property + @abstractmethod + def prefix(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def num_state_tensors(self) -> int: + """ Number of state tensors returned by the layer """ + raise NotImplementedError + + @property + @abstractmethod + def needs_mask(self) -> bool: + """ Whether the layer makes use of a mask tensor or not """ + raise NotImplementedError + + @abstractmethod + def get_state_shape(self, batch_size) -> Tuple: + """ + :param batch_size: current batch size + :return: dimensions of each output state (assuming all of them have the same shape) + """ + raise NotImplementedError + + @abstractmethod + def hybrid_forward(self, F, inputs: mx.sym.Symbol, previous_states: mx.sym.Symbol, *args) -> Tuple: + """ + :param F: ndarray or Symbol + :param inputs: layer input + :param previous_states: Symbol or list of Symbols + :param args: layer-specific arguments and/or arguments to be ignored + :return: layer output and new states + """ + raise NotImplementedError + + +class MultiHeadSelfAttention(MultiHeadAttentionBase, AutoregressiveLayer): """ Multi-head self-attention. Independent linear projections of inputs serve as queries, keys, and values for the attention. @@ -443,12 +482,34 @@ def __init__(self, with self.name_scope(): self.ff_in = quantization.QuantizableDense(in_units=depth_att, units=depth_att * 3, flatten=False, use_bias=False, prefix='i2h_', dtype=dtype) + @property + def prefix(self) -> str: + return "att_self_" + + @property + def num_state_tensors(self) -> int: + """ Number of state tensors returned by the layer """ + return 2 + + @property + def needs_mask(self) -> bool: + """ Whether the layer makes use of a mask tensor or not """ + return True + + def get_state_shape(self, batch_size: int) -> Tuple: + """ + :param batch_size: current batch size + :return: dimensions of each output state (assuming all of them have the same shape) + """ + # shape: (batch, heads, length, depth_per_head) + return batch_size, self.heads, 1, self.depth_out // self.heads + def hybrid_forward(self, F, inputs: mx.sym.Symbol, + previous_states: List[mx.sym.Symbol], input_lengths: Optional[mx.sym.Symbol] = None, bias: Optional[mx.sym.Symbol] = None, - previous_keys: Optional[mx.sym.Symbol] = None, - previous_values: Optional[mx.sym.Symbol] = None): # mypy: ignore + *args): # mypy: ignore """ Computes multi-head attention on a set of inputs, serving as queries, keys, and values. If sequence lengths are provided, they will be used to mask the attention scores. @@ -459,8 +520,8 @@ def hybrid_forward(self, F, :param inputs: Input Data. Shape: (batch, max_length, input_depth). :param input_lengths: Optional lengths of inputs to mask attention scores. Shape: (batch, 1). :param bias: Optional 3d bias tensor to mask attention scores. - :param previous_keys: Optional previous input projections of keys. Shape: (batch, max_length+1, depth_att). - :param previous_values: Optional previous input projections of values. Shape: (batch, max_length+1, depth_att). + :param previous_states: Optional list with two Symbols - previous input's keys and values. + Shape: 2 * (batch, max_length+1, depth_att). :return: Symbol of shape (batch, max_length, output_depth). """ # combined: (batch, max_length, depth * 3) @@ -478,6 +539,8 @@ def hybrid_forward(self, F, values = split_heads(F, values, self.depth_per_head, self.heads) updated_keys = keys + + previous_keys, previous_values = previous_states if previous_keys is not None: updated_keys = F.concat(previous_keys, keys, dim=2) keys = _remove_first_step(F, updated_keys) @@ -732,3 +795,128 @@ def hybrid_forward(self, F, data, steps, weight): # pylint: disable=arguments-d data = data * (self.num_embed ** 0.5) return F.broadcast_add(data, pos_embedding) + + +class SSRU(AutoregressiveLayer): + """ + Simpler Simple Recurrent Unit + + Kim et al, "From Research to Production and Back: Ludicrously Fast Neural Machine Translation" WNGT 2019 + + Variant of an LSTM cell aimed at reducing computational dependency across time steps. + Formally described as: + + (1) f[t] = sigmoid(W1[t] * x[t] + b[t]) + (2) c[t] = f[t] . c[t-1] + (1 - f[t]) . W2[t] * x[t] + (3) h = ReLU(c[t]) + + where: + . represents elementwise multiplication; + x[t] is the input at time step t; + f[t] is the output of the forget gate at time step t; + c[t] is the cell state at time step t; + h is the output of the unit. + + :param model_size: number of hidden units + :param inference_only: flag used to indicate execution at inference time + :param prefix: prefix prepended to the names of internal Symbol instances + :param dtype: data type + """ + def __init__(self, + model_size: int, + inference_only: bool, + prefix: str = C.SSRU_PREFIX, + dtype: str = C.DTYPE_FP32) -> None: + super(SSRU, self).__init__(prefix=prefix) + + self.model_size = model_size + self.inference_only = inference_only + + self.cell_state_transform = self._inference_cell_state_transform \ + if inference_only else self._training_cell_state_transform + + with self.name_scope(): + self.forget_gate = quantization.QuantizableDense(in_units=model_size, + units=model_size, + activation="sigmoid", + flatten=False, + prefix="forget_gate_", + dtype=dtype) + + self.linear = quantization.QuantizableDense(in_units=model_size, + units=model_size, + use_bias=False, + flatten=False, + prefix="linear_", + dtype=dtype) + + @property + def prefix(self) -> str: + return C.SSRU_PREFIX + + @property + def num_state_tensors(self) -> int: + """ Number of state tensors returned by the layer """ + return 1 + + @property + def needs_mask(self) -> bool: + """ Whether the layer makes use of a mask tensor or not """ + return False + + def get_state_shape(self, batch_size: int) -> Tuple: + """ + :param batch_size: current batch size + :return: dimensions of each output state (assuming all of them have the same shape) + """ + if self.inference_only: + return batch_size, 1, self.model_size + else: + return batch_size, self.model_size + + @staticmethod + def _training_cell_state_transform(F, previous_cell_state, weighted_inputs, forget_rates) -> Tuple: + """Update SSRU cell at training time""" + def _time_step_update(step_input_and_forget_rate, previous_step_state) -> Tuple: + """ + Recurrently update the SSRU cell state for one time step. + + :param step_input_and_forget_rate: List = [step_input, forget_rate] + :param previous_step_state: cell state at (t-1) + :return: twice the current time step state. NOTE: The first instance will be stacked in the final + foreach output and the second will be the input to the next time_step_update iteration. + """ + step_input, forget_rate = step_input_and_forget_rate # each of shape (batch_size, model_size) + current_step_state = forget_rate * previous_step_state + step_input + return current_step_state, current_step_state + + weighted_inputs = F.transpose(weighted_inputs, axes=(1, 0, 2)) # (max_length, batch, input_depth) + forget_rates = F.transpose(forget_rates, axes=(1, 0, 2)) # (max_length, batch, input_depth) + + # (max_length, batch, input_depth), (batch, input_depth) + cell_state, last_step_state = F.contrib.foreach(_time_step_update, + [weighted_inputs, forget_rates], + previous_cell_state) + + return F.transpose(cell_state, axes=(1, 0, 2)), last_step_state + + @staticmethod + def _inference_cell_state_transform(F, previous_cell_state, weighted_inputs, forget_rates) -> Tuple: + """Update SSRU cell at inference time""" + new_step_state = forget_rates * previous_cell_state + weighted_inputs # (batch, 1, input_depth) + return new_step_state, new_step_state + + def hybrid_forward(self, F, inputs: mx.sym.Symbol, previous_states: mx.sym.Symbol, *args) -> Tuple: + """ + :param F: ndarray or Symbol + :param inputs: input data. Shape: (batch, max_length, input_depth). + :param previous_states: previous cell states. Shape: (batch, max_length, input_depth) + :return: cell output and new cell states. Both with shape (batch, max_length, input_depth). + """ + forget_rates = self.forget_gate(inputs) + weighted_inputs = (1 - forget_rates) * self.linear(inputs) + + cell_state, last_step_state = self.cell_state_transform(F, previous_states, weighted_inputs, forget_rates) + + return F.relu(cell_state), last_step_state + diff --git a/sockeye/train.py b/sockeye/train.py index 4a493df0d..bd2f7a609 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -429,7 +429,8 @@ def create_encoder_config(args: argparse.Namespace, postprocess_sequence=encoder_transformer_postprocess, max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target, - lhuc=args.lhuc is not None and (C.LHUC_ENCODER in args.lhuc or C.LHUC_ALL in args.lhuc)) + lhuc=args.lhuc is not None and (C.LHUC_ENCODER in args.lhuc or C.LHUC_ALL in args.lhuc), + decoder_type=args.decoder) encoder_num_hidden = encoder_transformer_model_size return config_encoder, encoder_num_hidden @@ -465,7 +466,8 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int, max_seq_len_source=max_seq_len_source, max_seq_len_target=max_seq_len_target, lhuc=args.lhuc is not None and (C.LHUC_DECODER in args.lhuc or C.LHUC_ALL in args.lhuc), - depth_key_value=encoder_num_hidden) + depth_key_value=encoder_num_hidden, + decoder_type=args.decoder) return config_decoder diff --git a/sockeye/transformer.py b/sockeye/transformer.py index 3193d8dc4..d764b0ec5 100644 --- a/sockeye/transformer.py +++ b/sockeye/transformer.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import List, Optional, Tuple +from typing import Optional, Tuple import mxnet as mx @@ -37,6 +37,7 @@ def __init__(self, postprocess_sequence: str, max_seq_len_source: int, max_seq_len_target: int, + decoder_type: str = C.TRANSFORMER_TYPE, lhuc: bool = False, depth_key_value: int = 0) -> None: # type: ignore super().__init__() @@ -55,6 +56,7 @@ def __init__(self, self.max_seq_len_target = max_seq_len_target self.use_lhuc = lhuc self.depth_key_value = depth_key_value + self.decoder_type = decoder_type class TransformerEncoderBlock(mx.gluon.HybridBlock): @@ -105,7 +107,7 @@ def __init__(self, def hybrid_forward(self, F, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol: # self-attention - data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), None, bias, None, None) + data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), [None, None], None, bias) data = self.post_self_attention(data_self_att, data) # feed-forward @@ -120,29 +122,41 @@ def hybrid_forward(self, F, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym. class TransformerDecoderBlock(mx.gluon.HybridBlock): """ - A transformer encoder block consists self-attention, encoder attention, and a feed-forward layer + A transformer decoder block consists of an autoregressive attention block, encoder attention, and a feed-forward layer with pre/post process blocks in between. """ def __init__(self, config: TransformerConfig, + inference_only: bool, prefix: str, dtype: str) -> None: super().__init__(prefix=prefix) + self.decoder_type = config.decoder_type + with self.name_scope(): - self.pre_self_attention = TransformerProcessBlock(sequence=config.preprocess_sequence, + if self.decoder_type == C.TRANSFORMER_TYPE: + self.autoregr_layer = layers.MultiHeadSelfAttention(depth_att=config.model_size, + heads=config.attention_heads, + depth_out=config.model_size, + dropout=config.dropout_attention, + prefix="att_self_", + dtype=dtype) + elif self.decoder_type == C.SSRU_TRANSFORMER: + self.autoregr_layer = layers.SSRU(model_size=config.model_size, + inference_only=inference_only, + dtype=dtype) + else: + raise ValueError("Invalid decoder type.") + + self.pre_autoregr_layer = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, - prefix="att_self_pre_", + prefix=self.autoregr_layer.prefix + "pre_", num_hidden=config.model_size) - self.self_attention = layers.MultiHeadSelfAttention(depth_att=config.model_size, - heads=config.attention_heads, - depth_out=config.model_size, - dropout=config.dropout_attention, - prefix="att_self_", - dtype=dtype) - self.post_self_attention = TransformerProcessBlock(sequence=config.postprocess_sequence, + + self.post_autoregr_layer = TransformerProcessBlock(sequence=config.postprocess_sequence, dropout=config.dropout_prepost, - prefix="att_self_post_", + prefix=self.autoregr_layer.prefix + "post_", num_hidden=config.model_size) self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence, @@ -180,24 +194,38 @@ def __init__(self, if config.use_lhuc: self.lhuc = layers.LHUC(config.model_size) + @property + def num_state_tensors(self) -> int: + """ Number of state tensors returned by the layer """ + return self.autoregr_layer.num_state_tensors + + @property + def needs_mask(self): + """ Whether the block makes use of a mask tensor or not """ + return self.autoregr_layer.needs_mask + + def get_states_shape(self, batch_size: int) -> Tuple: + """ + :param batch_size: current batch size + :return: dimensions of an output state (assuming all of them have the same shape) + """ + return self.autoregr_layer.get_state_shape(batch_size) + def hybrid_forward(self, F, target: mx.sym.Symbol, target_bias: mx.sym.Symbol, source: mx.sym.Symbol, source_bias: mx.sym.Symbol, - self_att_k: Optional[mx.sym.Symbol] = None, - self_att_v: Optional[mx.sym.Symbol] = None, + autoregr_states: mx.sym.Symbol, enc_att_k: Optional[mx.sym.Symbol] = None, enc_att_v: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol, - mx.sym.Symbol, mx.sym.Symbol]: - # self-attention - target_self_att, keys, values = self.self_attention(self.pre_self_attention(target, None), - None, - target_bias, - self_att_k, - self_att_v) - target = self.post_self_attention(target_self_att, target) + target_autoregr, *new_autoregr_states = self.autoregr_layer(self.pre_autoregr_layer(target, None), + autoregr_states, + None, + target_bias) + + target = self.post_autoregr_layer(target_autoregr, target) # encoder attention target_enc_att = self.enc_attention(self.pre_enc_attention(target, None), @@ -206,6 +234,7 @@ def hybrid_forward(self, F, source_bias, enc_att_k, enc_att_v) + target = self.post_enc_attention(target_enc_att, target) # feed-forward @@ -215,7 +244,7 @@ def hybrid_forward(self, F, if self.lhuc: target = self.lhuc(target) - return target, keys, values + return target, new_autoregr_states class TransformerProcessBlock(mx.gluon.nn.HybridBlock): diff --git a/test/integration/test_seq_copy_int.py b/test/integration/test_seq_copy_int.py index c66f780c0..8487b8a38 100644 --- a/test/integration/test_seq_copy_int.py +++ b/test/integration/test_seq_copy_int.py @@ -42,9 +42,9 @@ _TEST_MAX_LENGTH = 20 # tuple format: (train_params, translate_params, use_prepared_data, use_source_factors) -ENCODER_DECODER_SETTINGS = [ +ENCODER_DECODER_SETTINGS_TEMPLATE = [ # Basic transformer, nbest=2 decoding - ("--encoder transformer --decoder transformer" + ("--encoder transformer --decoder {decoder}" " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8" " --transformer-feed-forward-num-hidden 16" " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr" @@ -55,7 +55,7 @@ "--beam-size 2 --nbest-size 2", False, 0), # Basic transformer w/ prepared data & greedy decoding - ("--encoder transformer --decoder transformer" + ("--encoder transformer --decoder {decoder}" " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8" " --transformer-feed-forward-num-hidden 16" " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr" @@ -66,7 +66,7 @@ "--beam-size 1", True, 0), # Basic transformer with source factor, beam-search-stop first decoding - ("--encoder transformer --decoder transformer" + ("--encoder transformer --decoder {decoder}" " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8" " --transformer-feed-forward-num-hidden 16" " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr" @@ -78,7 +78,7 @@ "--beam-size 2 --beam-search-stop first", True, 3), # Basic transformer with LHUC - ("--encoder transformer --decoder transformer" + ("--encoder transformer --decoder {decoder}" " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8" " --transformer-feed-forward-num-hidden 16" " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr" @@ -89,7 +89,7 @@ "--beam-size 2", False, 0), # Basic transformer and length ratio prediction, and learned brevity penalty during inference - ("--encoder transformer --decoder transformer" + ("--encoder transformer --decoder {decoder}" " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8" " --transformer-feed-forward-num-hidden 16" " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr" @@ -102,7 +102,7 @@ " --brevity-penalty-type learned --brevity-penalty-weight 1.0", True, 0), # Basic transformer and absolute length prediction, and constant brevity penalty during inference - ("--encoder transformer --decoder transformer" + ("--encoder transformer --decoder {decoder}" " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8" " --transformer-feed-forward-num-hidden 16" " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr" @@ -116,6 +116,10 @@ False, 0), ] +ENCODER_DECODER_SETTINGS = [(train_params.format(decoder=decoder), *other_params) + for decoder in C.DECODERS + for (train_params, *other_params) in ENCODER_DECODER_SETTINGS_TEMPLATE] + @pytest.mark.parametrize("train_params, translate_params, use_prepared_data, n_source_factors", ENCODER_DECODER_SETTINGS) @@ -263,7 +267,7 @@ def _test_mc_dropout(model_path: str): model, _, _ = load_model(model_folder=model_path, context=[mx.cpu()], mc_dropout=True, inference_only=True, hybridize=True) # Ensure the model has some dropout turned on - config_blocks = [block for _, block in model.config.__dict__.items() if isinstance(block, Config)] + config_blocks = [block for _, block in model.config.__dict__.items() if isinstance(block, Config)] dropout_settings = {setting: val for block in config_blocks for setting, val in block.__dict__.items() if "dropout" in setting} assert any(s > 0.0 for s in dropout_settings.values()) diff --git a/test/system/test_seq_copy_sys.py b/test/system/test_seq_copy_sys.py index 149ee2244..749be8b32 100644 --- a/test/system/test_seq_copy_sys.py +++ b/test/system/test_seq_copy_sys.py @@ -147,7 +147,18 @@ def test_seq_copy(name, train_params, translate_params, use_prepared_data, perpl "--beam-size 1", True, 3, 1.03, - 0.96) + 0.96), + ("Sort:transformer:ssru_transformer:batch_word", + "--encoder transformer --decoder ssru_transformer" + " --max-seq-len 10 --batch-size 90 --update-interval 1 --batch-type word --batch-sentences-multiple-of 1" + " --max-updates 6000" + " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 32 --num-embed 32" + " --transformer-dropout-attention 0.0 --transformer-dropout-act 0.0 --transformer-dropout-prepost 0.0" + " --transformer-feed-forward-num-hidden 64" + COMMON_TRAINING_PARAMS, + "--beam-size 1", + True, 0, + 1.03, + 0.97) ]) def test_seq_sort(name, train_params, translate_params, use_prepared_data, n_source_factors, perplexity_thresh, bleu_thresh):