diff --git a/CHANGELOG.md b/CHANGELOG.md
index f47f1c94d..980642d99 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [2.2.0]
+
+### Changed
+
+- Replaced multi-head attention with [interleaved_matmul_encdec](https://github.com/apache/incubator-mxnet/pull/16408) operators, which removes previously needed transposes and improves performance.
+
+- Beam search states and model layers now assume time-major format.
 
 ## [2.1.26]
 
diff --git a/MANIFEST.in b/MANIFEST.in
index e307a5fa7..7195a29e1 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -7,12 +7,12 @@ include pylintrc
 include .flake8
 include typechecked-files
 include test/data/config_with_missing_attributes.yaml
-include test/data/model_2.1.x/config
-include test/data/model_2.1.x/params.best
-include test/data/model_2.1.x/model_input
-include test/data/model_2.1.x/vocab*
-include test/data/model_2.1.x/version
-include test/data/model_2.1.x/README.md
+include test/data/model_2.2.x/config
+include test/data/model_2.2.x/params.best
+include test/data/model_2.2.x/model_input
+include test/data/model_2.2.x/vocab*
+include test/data/model_2.2.x/version
+include test/data/model_2.2.x/README.md
 include sockeye/git_version.py
 include *.bib
 recursive-include .github *
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index 34f6d9f04..445bba797 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '2.1.26'
+__version__ = '2.2.0'
diff --git a/sockeye/beam_search.py b/sockeye/beam_search.py
index 3ce759fc5..0737aa169 100644
--- a/sockeye/beam_search.py
+++ b/sockeye/beam_search.py
@@ -417,17 +417,17 @@ def _repeat_states(states: List, beam_size: int, state_structure: List) -> List:
     assert len(states) == len(flat_structure), "Number of states do not match the defined state structure"
     for state, state_format in zip(states, flat_structure):
         if state_format == C.STEP_STATE or state_format == C.BIAS_STATE:
+            # Steps and source_bias have batch dimension on axis 0
            repeat_axis = 0
         elif state_format == C.DECODER_STATE or state_format == C.ENCODER_STATE:
-            # TODO: Change repeat axis to 1 when interleaved multihead attention is implemented
-            repeat_axis = 0
+            # Decoder and encoder layer states have batch dimension on axis 1
+            repeat_axis = 1
         else:
             raise ValueError("Provided state format %s not recognized."
                              % state_format)
         repeated_state = state.repeat(repeats=beam_size, axis=repeat_axis)
         repeated_states.append(repeated_state)
     return repeated_states
-
 
 class SortStates(mx.gluon.HybridBlock):
     def __init__(self, state_structure, prefix):
@@ -439,10 +439,11 @@ def hybrid_forward(self, F, best_hyp_indices, *states):
         assert len(states) == len(self.flat_structure), "Number of states do not match the defined state structure"
         for state, state_format in zip(states, self.flat_structure):
             if state_format == C.STEP_STATE or state_format == C.BIAS_STATE:
+                # Steps and source_bias have batch dimension on axis 0
                 sorted_state = F.take(state, best_hyp_indices)
             elif state_format == C.DECODER_STATE:
-                # TODO: Change take axis to 1 when interleaved multihead attention is implemented
-                sorted_state = F.take(state, best_hyp_indices)
+                # Decoder and encoder layer states have batch dimension on axis 1
+                sorted_state = F.take(state, best_hyp_indices, axis=1)
             elif state_format == C.ENCODER_STATE:
                 # No need for takes on encoder layer states
                 sorted_state = state
diff --git a/sockeye/decoder.py b/sockeye/decoder.py
index cbb026923..9a28d0d1b 100644
--- a/sockeye/decoder.py
+++ b/sockeye/decoder.py
@@ -163,7 +163,7 @@ def state_structure(self) -> str:
         """
         structure = ''
         if self.inference_only:
-            structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers * 2
+            structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers
         else:
             structure += C.STEP_STATE + C.ENCODER_STATE + C.BIAS_STATE
 
@@ -197,13 +197,11 @@ def init_state_from_encoder(self,
             states = [step, source_mask]
 
             for layer in self.layers:
-                encoder_attention_keys, encoder_attention_values = \
-                    layer.enc_attention.project_and_isolate_heads(mx.nd, encoder_outputs)
-                states.append(encoder_attention_keys)
-                states.append(encoder_attention_values)
+                enc_att_kv = layer.enc_attention.ff_kv(encoder_outputs)
+                states.append(mx.nd.transpose(enc_att_kv, axes=(1, 0, 2)))
         else:
             # NO encoder projection caching
-            states = [step, encoder_outputs, source_mask]
+            states = [step, mx.nd.transpose(encoder_outputs, axes=(1, 0, 2)), source_mask]
 
         batch_size = encoder_outputs.shape[0]
         dummy_autoregr_states = [mx.nd.zeros(layer.get_states_shape(batch_size),
@@ -271,7 +269,7 @@ def forward(self, step_input, states):
 
         if self.inference_only:
             # pass in cached encoder states
-            encoder_attention_keys_values = states[2:2 + self.config.num_layers * 2]
+            encoder_attention_keys_values = states[2:2 + self.config.num_layers]
             new_states = [step, states[1]] + encoder_attention_keys_values + autoregr_states
         else:
             encoder_outputs = states[1]
@@ -288,14 +286,13 @@ def hybrid_forward(self, F, step_input, states):
         if self.inference_only:
             steps, source_mask, *other = states
             source_encoded = None  # use constant pre-computed key value projections from the states
-            enc_att_kv = other[:self.config.num_layers * 2]
-            enc_att_kv = [enc_att_kv[i:i + 2] for i in range(0, len(enc_att_kv), 2)]
-            autoregr_states = other[self.config.num_layers * 2:]
+            enc_att_kv = other[:self.config.num_layers]
+            autoregr_states = other[self.config.num_layers:]
         else:
             if any(layer.needs_mask for layer in self.layers):
                 mask = self.autoregressive_bias(step_input)  # mask: (1, length, length)
             steps, source_encoded, source_mask, *autoregr_states = states
-            enc_att_kv = [(None, None) for _ in range(self.config.num_layers)]
+            enc_att_kv = [None for _ in range(self.config.num_layers)]
 
         if any(layer.num_state_tensors > 1 for layer in self.layers):
             # separates autoregressive states by layer
@@ -307,23 +304,25 @@ def hybrid_forward(self, F, step_input, states):
 
         # target: (batch_size, length, model_size)
         target = self.pos_embedding(step_input, steps)
+        # (length, batch_size, model_size)
+        target = F.transpose(target, axes=(1, 0, 2))
 
         if self.config.dropout_prepost > 0.0:
             target = F.Dropout(data=target, p=self.config.dropout_prepost)
 
         new_autoregr_states = []
-        for layer, layer_autoregr_state, (enc_att_k, enc_att_v) in zip(self.layers, autoregr_states, enc_att_kv):
+        for layer, layer_autoregr_state, layer_enc_att_kv in zip(self.layers, autoregr_states, enc_att_kv):
             target, new_layer_autoregr_state = layer(target,
                                                      mask,
                                                      source_encoded,
                                                      source_mask,
                                                      layer_autoregr_state,
-                                                     enc_att_k, enc_att_v)
+                                                     layer_enc_att_kv)
             new_autoregr_states += [*new_layer_autoregr_state]
-            # NOTE: the list expansion is needed in order to handle both a tuple (of Symbols) and a Symbol as a new state
 
         target = self.final_process(target, None)
+        target = F.transpose(target, axes=(1, 0, 2))
 
         return target, new_autoregr_states
diff --git a/sockeye/encoder.py b/sockeye/encoder.py
index ec4ea41ea..ffb3cca09 100644
--- a/sockeye/encoder.py
+++ b/sockeye/encoder.py
@@ -330,11 +330,13 @@ def hybrid_forward(self, F, data, valid_length):
         # (batch_size * heads, 1, seq_len)
         bias = F.expand_dims(self.valid_length_mask(data, valid_length), axis=1)
 
+        data = F.transpose(data, axes=(1, 0, 2))
         for block in self.layers:
             data = block(data, bias)
 
         data = self.final_process(data, None)
+        data = F.transpose(data, axes=(1, 0, 2))
         return data, valid_length
 
     def get_num_hidden(self) -> int:
diff --git a/sockeye/layers.py b/sockeye/layers.py
index 52d1ff3a9..b76191bfb 100644
--- a/sockeye/layers.py
+++ b/sockeye/layers.py
@@ -257,38 +257,6 @@ def hybrid_forward(self, F, source_encoded, source_encoded_length):
         return F.squeeze(data)
 
 
-def split_heads(F, x: mx.sym.Symbol, depth_per_head: int, heads: int) -> mx.sym.Symbol:
-    """
-    Returns a symbol with heads as second dimension and channel depth / number of heads as last dimension.
-
-    :param x: Symbol of shape (batch, length, depth).
-    :param depth_per_head: Depth per head.
-    :param heads: Number of heads.
-    :return: Symbol of shape (batch, heads, length, depth_per_heads).
-    """
-    # (batch, length, heads, depth_per_head)
-    x = F.reshape(x, shape=(0, -1, heads, depth_per_head))
-    # (batch, heads, length, depth/heads)
-    return F.transpose(x, axes=(0, 2, 1, 3))
-
-
-def combine_heads(F, x: mx.sym.Symbol, depth_per_head: int, heads: int) -> mx.sym.Symbol:
-    """
-    Returns a symbol with both batch & length, and head & depth dimensions combined.
-
-    :param x: Symbol of shape (batch * heads, length, depth_per_head).
-    :param depth_per_head: Depth per head.
-    :param heads: Number of heads.
-    :return: Symbol of shape (batch, length, depth).
-    """
-    # (batch, heads, length, depth_per_head)
-    x = F.reshape(x, shape=(-4, -1, heads, 0, depth_per_head))
-    # (batch, length, heads, depth_per_head)
-    x = F.transpose(x, axes=(0, 2, 1, 3))
-    # (batch, length, depth)
-    return F.reshape(x, shape=(-1, 0, depth_per_head * heads))
-
-
 def broadcast_to_heads(F, x: mx.sym.Symbol, num_heads: int, ndim: int, fold_heads: bool = True) -> mx.sym.Symbol:
     """
     Broadcasts batch-major input of shape (batch, d1 ... dn-1) to (batch*heads, d1 ... dn-1).
@@ -325,9 +293,10 @@ def cast(self, dtype):
         self._dtype = dtype
         super().cast(dtype)
 
-    def hybrid_forward(self, F, queries, keys, values, lengths=None, bias=None):
-        # (n, lq, lk)
-        logits = F.batch_dot(lhs=queries, rhs=keys, transpose_b=True)
+    def hybrid_forward(self, F, queries, key_values, heads, lengths=None, bias=None):
+
+        # (n*h, lq, lk)
+        logits = F.contrib.interleaved_matmul_encdec_qk(queries, key_values, heads=heads)
 
         # TODO(fhieber): consider softmax with length argument once available.
         # TODO(fhieber: Also see https://github.com/dmlc/gluon-nlp/pull/910
@@ -347,9 +316,11 @@ def hybrid_forward(self, F, queries, keys, values, lengths=None, bias=None):
 
         probs = F.softmax(logits, axis=-1)
         probs = F.Dropout(probs, p=self.dropout) if self.dropout > 0.0 else probs
-
-        # (n, lq, lk) x (n, lk, dv) -> (n, lq, dv)
-        return F.batch_dot(lhs=probs, rhs=values)
+
+        # key_values: (lk, n, dv * 2)
+        # probs: (n*h, lq, lk)
+        # result: (n, lq, dv)
+        return F.contrib.interleaved_matmul_encdec_valatt(key_values, probs, heads=heads)
 
 
 class MultiHeadAttentionBase(mx.gluon.HybridBlock):
@@ -385,35 +356,25 @@ def __init__(self,
     def _attend(self,
                 F,
                 queries: mx.sym.Symbol,
-                keys: mx.sym.Symbol,
-                values: mx.sym.Symbol,
+                key_values: mx.sym.Symbol,
                 lengths: Optional[mx.sym.Symbol] = None,
                 bias: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol:
         """
         Returns context vectors of multi-head dot attention.
 
-        :param queries: Query tensor. Shape: (batch_size, heads, query_max_length, depth_per_head).
-        :param keys: Keys. Shape: (batch_size, heads, memory_max_length, depth_per_head).
-        :param values: Values. Shape: (batch_size, heads, memory_max_length, depth_per_head).
+        :param queries: Query tensor. Shape: (query_max_length, batch_size, depth).
+        :param key_values: Interleaved keys and values. Shape: (memory_max_length, batch_size, depth * 2).
         :param lengths: Optional lengths of keys. Shape: (batch_size,).
         :param bias: Optional 3d bias.
         :return: Context vectors. Shape: (query_max_length, batch_size, output_depth).
         """
-        # fold head dimension into batch dimension
-        # (batch*heads, length, depth/heads)
-        queries = F.reshape(queries, shape=(-3, -1, self.depth_per_head))
-        keys = F.reshape(keys, shape=(-3, -1, self.depth_per_head))
-        values = F.reshape(values, shape=(-3, -1, self.depth_per_head))
         lengths = broadcast_to_heads(F, lengths, self.heads, ndim=1, fold_heads=True) if lengths is not None else lengths
 
-        # (batch*heads, query_max_length, depth_per_head)
-        contexts = self.dot_att(queries, keys, values, lengths, bias)
+        # (query_max_length, batch, depth)
+        contexts = self.dot_att(queries, key_values, self.heads, lengths, bias)
 
-        # (batch, query_max_length, depth)
-        contexts = combine_heads(F, contexts, self.depth_per_head, self.heads)
-
-        # contexts: (batch, query_max_length, output_depth)
+        # (query_max_length, batch, output_depth)
         contexts = self.ff_out(contexts)
 
         return contexts
@@ -489,7 +450,7 @@ def prefix(self) -> str:
     @property
     def num_state_tensors(self) -> int:
         """ Number of state tensors returned by the layer """
-        return 2
+        return 1
 
     @property
     def needs_mask(self) -> bool:
@@ -501,12 +462,12 @@ def get_state_shape(self, batch_size: int) -> Tuple:
         :param batch_size: current batch size
         :return: dimensions of each output state (assuming all of them have the same shape)
         """
-        # shape: (batch, heads, length, depth_per_head)
-        return batch_size, self.heads, 1, self.depth_out // self.heads
+        # shape: (length, batch, key_depth + value_depth)
+        return 1, batch_size, self.depth_out * 2
 
     def hybrid_forward(self, F,
                        inputs: mx.sym.Symbol,
-                       previous_states: List[mx.sym.Symbol],
+                       previous_states: Optional[mx.sym.Symbol] = None,
                        input_lengths: Optional[mx.sym.Symbol] = None,
                        bias: Optional[mx.sym.Symbol] = None,
                        *args):  # mypy: ignore
@@ -524,42 +485,17 @@ def hybrid_forward(self, F,
                                 Shape: (max_length+1, batch, depth_att * 2).
         :return: Symbol of shape (max_length, batch, output_depth).
         """
-        # combined: (batch, max_length, depth * 3)
-        combined = self.ff_in(inputs)
-        # split into query, keys and values
-        # (batch, max_length, depth)
-        # pylint: disable=unbalanced-tuple-unpacking
-        queries, keys, values = F.split(combined, num_outputs=3, axis=2)
-
-        # scale by sqrt(depth_per_head)
-        queries = queries * (self.depth_per_head ** -0.5)
-        # (batch, heads, length, depth/heads)
-        queries = split_heads(F, queries, self.depth_per_head, self.heads)
-        keys = split_heads(F, keys, self.depth_per_head, self.heads)
-        values = split_heads(F, values, self.depth_per_head, self.heads)
-
-        updated_keys = keys
-        previous_keys, previous_values = previous_states
-        if previous_keys is not None:
-            updated_keys = F.concat(previous_keys, keys, dim=2)
-            keys = _remove_first_step(F, updated_keys)
+        proj = self.ff_in(inputs)
+        queries, kv_1, kv_2 = F.split(proj, num_outputs=3, axis=2)
+        states = F.concat(kv_1, kv_2, dim=2)
 
-        updated_values = values
-        if previous_values is not None:
-            updated_values = F.concat(previous_values, values, dim=2)
-            values = _remove_first_step(F, updated_values)
+        updated_states = states
+        if previous_states is not None:
+            updated_states = F.concat(previous_states, states, dim=0)
+            states = F.slice(updated_states, begin=(1, None, None), end=(None, None, None))
 
-        return self._attend(F, queries, keys, values, lengths=input_lengths, bias=bias), updated_keys, updated_values
-
-
-def _remove_first_step(F, data):
-    """
-    :param F: MXNet namespace.
-    :param data: Input data. Shape: (batch, heads, length, num_hidden).
-    :return: Output data. Shape: (batch, heads, length[1:], num_hidden
-    """
-    return F.slice(data, begin=(None, None, 1, None), end=(None, None, None, None))
+        return self._attend(F, queries, states, lengths=input_lengths, bias=bias), updated_states
 
 
 class MultiHeadAttention(MultiHeadAttentionBase):
@@ -587,56 +523,32 @@ def __init__(self,
         with self.name_scope():
             self.ff_q = quantization.QuantizableDense(in_units=depth_out, units=depth_att, flatten=False, use_bias=False, prefix='q2h_', dtype=dtype)
-            self.ff_k = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='k2h_', dtype=dtype)
-            self.ff_v = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='v2h_', dtype=dtype)
-
-    def project_and_isolate_heads(self, F, memory: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]:
-        """
-        Projects memory into keys and values, and separates attention heads dimension.
-
-        :param memory: Memory tensor. Shape: (batch, memory_max_length, input_depth).
-        :return: Symbol of shape (batch, heads, memory_max_length, depth_per_head).
-        """
-        keys = self.ff_k(memory)
-        values = self.ff_v(memory)
-        keys = split_heads(F, keys, depth_per_head=self.depth_per_head, heads=self.heads)
-        values = split_heads(F, values, depth_per_head=self.depth_per_head, heads=self.heads)
-        return keys, values
+            self.ff_kv = quantization.QuantizableDense(in_units=depth_key_value, units=2*depth_att, flatten=False, use_bias=False, prefix='kv2h_', dtype=dtype)
 
     def hybrid_forward(self, F,
                        queries: mx.sym.Symbol,
                        memory: mx.sym.Symbol,
                        memory_lengths: Optional[mx.sym.Symbol] = None,
                        bias: Optional[mx.sym.Symbol] = None,
-                       projected_memory_keys: Optional[mx.sym.Symbol] = None,
-                       projected_memory_values: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol:  # mypy: ignore
+                       projected_memory_kv: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol:  # mypy: ignore
         """
         Computes multi-head attention for queries given a memory tensor.
         If sequence lengths are provided, they will be used to mask the attention scores.
         A bias mask may also be used to mask the attention scores.
-        Returns a symbol of shape (batch, max_length, output_depth).
+        Returns a symbol of shape (max_length, batch, output_depth).
 
-        :param queries: Query tensor. Shape: (batch, query_max_length, input_depth).
-        :param memory: Memory data to attend to. Shape: (batch, memory_max_length, input_depth).
+        :param queries: Query tensor. Shape: (query_max_length, batch, input_depth).
+        :param memory: Memory data to attend to. Shape: (memory_max_length, batch, input_depth).
         :param memory_lengths: Optional lengths of memory to mask attention scores. Shape: (batch, 1).
         :param bias: Optional 3d bias tensor to mask attention scores.
-        :param projected_memory_keys: Optional previously projected memory keys.
-        :param projected_memory_values: Optional previously projected memory values.
-        :return: Symbol of shape (batch, query_seq_len, output_depth).
+        :param projected_memory_kv: Optional previously projected memory keys and values.
+        :return: Symbol of shape (query_seq_len, batch, output_depth).
         """
-        # (batch, query_max_length, depth)
-        queries = self.ff_q(queries)
-        # scale by sqrt(depth_per_head)
-        queries = queries * (self.depth_per_head ** -0.5)
-        # (batch, heads, length, depth/heads)
-        queries = split_heads(F, queries, self.depth_per_head, self.heads)
-        if projected_memory_keys is not None and projected_memory_values is not None:
-            keys, values = projected_memory_keys, projected_memory_values
-        else:
-            keys, values = self.project_and_isolate_heads(F, memory)
+        queries = self.ff_q(queries)
+        kv = projected_memory_kv if projected_memory_kv is not None else self.ff_kv(memory)
 
-        return self._attend(F, queries, keys, values, bias=bias, lengths=memory_lengths)
+        return self._attend(F, queries, kv, bias=bias, lengths=memory_lengths)
 
 
 class PlainDotAttention(mx.gluon.HybridBlock):
@@ -652,14 +564,14 @@ def hybrid_forward(self, F, queries, memory, memory_lengths):
         """
         Returns a symbol of shape (max_length, batch, output_depth).
 
-        :param queries: Symbol of shape (batch, queries_max_length, input_depth).
-        :param memory: Symbol of shape (batch, memory_max_length, input_depth).
+        :param queries: Symbol of shape (queries_max_length, batch, input_depth).
+        :param memory: Symbol of shape (memory_max_length, batch, input_depth).
         :param memory_lengths: Symbol of shape (batch, 1).
-        :return: Symbol of shape (batch, queries_max_length, output_depth).
+        :return: Symbol of shape (queries_max_length, batch, output_depth).
         """
 
-        # (batch*heads, queries_max_length, depth_per_head)
-        return self.dot_att(queries, memory, memory, memory_lengths, None)
+        # (queries_max_length, batch, output_depth)
+        return self.dot_att(queries, memory, 1, memory_lengths, None)
 
 
 class ProjectedDotAttention(mx.gluon.HybridBlock):
@@ -688,26 +600,19 @@ def hybrid_forward(self, F,
         """
        Apply projections and dot attention, then return new context vectors.
 
-        :param queries: Symbol of shape (batch, queries_max_length, input_num_hidden).
-        :param memory: Symbol of shape (batch, memory_max_length, input_num_hidden).
+        :param queries: Symbol of shape (queries_max_length, batch, input_num_hidden).
+        :param memory: Symbol of shape (memory_max_length, batch, input_num_hidden).
         :param memory_lengths: Symbol of shape (batch, 1).
-        :return: Symbol of shape (batch, queries_max_length, num_hidden).
+        :return: Symbol of shape (queries_max_length, batch, num_hidden).
         """
-        # (batch, memory_max_length, num_hidden * 2)
+        # (memory_max_length, batch, num_hidden * 2)
         combined = self.kv2h(memory)
 
-        # split into keys and values
-        # pylint: disable=unbalanced-tuple-unpacking
-        keys, values = F.split(data=combined, num_outputs=2, axis=2)
-
-        # (batch, queries_max_length, num_hidden)
+        # (queries_max_length, batch, num_hidden)
         queries = self.q2h(queries)
 
-        # scale by sqrt(num_hidden)
-        queries = queries * (self.num_hidden ** -0.5)
-
-        # (batch, queries_max_length, num_hidden)
-        contexts = self.dot_att(queries, keys, values, memory_lengths, None)
+        # (queries_max_length, batch, num_hidden)
+        contexts = self.dot_att(queries, combined, 1, memory_lengths, None)
 
         return contexts
@@ -869,10 +774,7 @@ def get_state_shape(self, batch_size: int) -> Tuple:
         :param batch_size: current batch size
         :return: dimensions of each output state (assuming all of them have the same shape)
         """
-        if self.inference_only:
-            return batch_size, 1, self.model_size
-        else:
-            return batch_size, self.model_size
+        return 1, batch_size, self.model_size
 
     @staticmethod
     def _training_cell_state_transform(F, previous_cell_state, weighted_inputs, forget_rates) -> Tuple:
@@ -890,28 +792,26 @@ def _time_step_update(step_input_and_forget_rate, previous_step_state) -> Tuple:
             current_step_state = forget_rate * previous_step_state + step_input
             return current_step_state, current_step_state
 
-        weighted_inputs = F.transpose(weighted_inputs, axes=(1, 0, 2))  # (max_length, batch, input_depth)
-        forget_rates = F.transpose(forget_rates, axes=(1, 0, 2))  # (max_length, batch, input_depth)
-
         # (max_length, batch, input_depth), (batch, input_depth)
         cell_state, last_step_state = F.contrib.foreach(_time_step_update,
                                                         [weighted_inputs, forget_rates],
-                                                        previous_cell_state)
+                                                        F.squeeze(previous_cell_state, axis=0))
 
-        return F.transpose(cell_state, axes=(1, 0, 2)), last_step_state
+        return cell_state, F.expand_dims(last_step_state, axis=0)
 
     @staticmethod
     def _inference_cell_state_transform(F, previous_cell_state, weighted_inputs, forget_rates) -> Tuple:
         """Update SSRU cell at inference time"""
-        new_step_state = forget_rates * previous_cell_state + weighted_inputs  # (batch, 1, input_depth)
+        new_step_state = forget_rates * previous_cell_state + weighted_inputs  # (1, batch, input_depth)
         return new_step_state, new_step_state
 
     def hybrid_forward(self, F, inputs: mx.sym.Symbol, previous_states: mx.sym.Symbol, *args) -> Tuple:
         """
         :param F: ndarray or Symbol
-        :param inputs: input data. Shape: (batch, max_length, input_depth).
-        :param previous_states: previous cell states. Shape: (batch, max_length, input_depth)
-        :return: cell output and new cell states. Both with shape (batch, max_length, input_depth).
+        :param inputs: input data. Shape: (max_length, batch, input_depth).
+        :param previous_states: previous cell states. Shape: (max_length, batch, input_depth).
+        :return: cell output and new cell states. Both with shape (max_length, batch, input_depth).
         """
         forget_rates = self.forget_gate(inputs)
         weighted_inputs = (1 - forget_rates) * self.linear(inputs)
diff --git a/sockeye/transformer.py b/sockeye/transformer.py
index fdc5fc799..463889e37 100644
--- a/sockeye/transformer.py
+++ b/sockeye/transformer.py
@@ -107,7 +107,7 @@ def __init__(self,
 
     def hybrid_forward(self, F, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol:
         # self-attention
-        data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), [None, None], None, bias)
+        data_self_att, _ = self.self_attention(self.pre_self_attention(data, None), None, None, bias)
         data = self.post_self_attention(data_self_att, data)
 
         # feed-forward
@@ -158,15 +158,6 @@ def __init__(self,
                                                           dropout=config.dropout_prepost,
                                                           prefix=self.autoregr_layer.prefix + "post_",
                                                           num_hidden=config.model_size)
-        # TODO (tdomhan): Remove with next major version bump.
-        # For backwards compatibility with versions prior to 2.1.17 we also store the layers under to previous
-        # attribute name. This way parameters can be loaded as either decoder.layers.0.autoregr_layer.ff_out.weight
-        # or decoder.layers.0.self_attention.ff_out.weight. Parameter deduplication makes sure parameters are stored
-        # and loaded once only.
-        if self.decoder_type == C.TRANSFORMER_TYPE:
-            self.self_attention = self.autoregr_layer
-            self.pre_self_attention = self.pre_autoregr_layer
-            self.post_self_attention = self.post_autoregr_layer
 
         self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence,
                                                          dropout=config.dropout_prepost,
@@ -226,9 +217,8 @@ def hybrid_forward(self, F,
                        source: mx.sym.Symbol,
                        source_bias: mx.sym.Symbol,
                        autoregr_states: mx.sym.Symbol,
-                       enc_att_k: Optional[mx.sym.Symbol] = None,
-                       enc_att_v: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol,
-                                                                           mx.sym.Symbol]:
+                       enc_att_kv: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol,
+                                                                            mx.sym.Symbol]:
         target_autoregr, *new_autoregr_states = self.autoregr_layer(self.pre_autoregr_layer(target, None),
                                                                     autoregr_states,
                                                                     None,
@@ -241,8 +231,7 @@ def hybrid_forward(self, F,
                                                source,
                                                None,
                                                source_bias,
-                                               enc_att_k,
-                                               enc_att_v)
+                                               enc_att_kv)
 
         target = self.post_enc_attention(target_enc_att, target)
 
diff --git a/test/data/model_2.1.x/params.best b/test/data/model_2.1.x/params.best
deleted file mode 100644
index 13e33ec10..000000000
Binary files a/test/data/model_2.1.x/params.best and /dev/null differ
diff --git a/test/data/model_2.1.x/version b/test/data/model_2.1.x/version
deleted file mode 100644
index 91dbb1711..000000000
--- a/test/data/model_2.1.x/version
+++ /dev/null
@@ -1 +0,0 @@
-2.1.16
\ No newline at end of file
diff --git a/test/data/model_2.1.x/README.md b/test/data/model_2.2.x/README.md
similarity index 100%
rename from test/data/model_2.1.x/README.md
rename to test/data/model_2.2.x/README.md
diff --git a/test/data/model_2.1.x/config b/test/data/model_2.2.x/config
similarity index 90%
rename from test/data/model_2.1.x/config
rename to test/data/model_2.2.x/config
index 2d644b413..2eb204ec1 100644
--- a/test/data/model_2.1.x/config
+++ b/test/data/model_2.2.x/config
@@ -3,9 +3,9 @@ config_data: !DataConfig
   data_statistics: !DataStatistics
     average_len_target_per_bucket:
     - null
-    - 13.44736842105263
-    - 20.568421052631578
-    - 28.053672316384176
+    - 13.418518518518514
+    - 20.319148936170205
+    - 28.016949152542363
     - null
     - null
     - null
@@ -80,8 +80,8 @@ config_data: !DataConfig
     num_sents: 1000
     num_sents_per_bucket:
     - 0
-    - 266
-    - 380
+    - 270
+    - 376
     - 354
     - 0
     - 0
@@ -91,8 +91,8 @@ config_data: !DataConfig
     - 0
    - 0
     - 0
-    num_tokens_source: 21324
-    num_tokens_target: 21324
+    num_tokens_source: 21181
+    num_tokens_target: 21181
     num_unks_source: 0
     num_unks_target: 0
     size_vocab_source: 15
@@ -103,6 +103,7 @@ config_data: !DataConfig
 config_decoder: !TransformerConfig
   act_type: relu
   attention_heads: 2
+  decoder_type: transformer
   depth_key_value: 16
   dropout_act: 0.1
   dropout_attention: 0.1
@@ -134,6 +135,7 @@ config_embed_target: !EmbeddingConfig
 config_encoder: !TransformerConfig
   act_type: relu
   attention_heads: 2
+  decoder_type: transformer
   depth_key_value: 0
   dropout_act: 0.1
   dropout_attention: 0.1
@@ -150,7 +152,7 @@ config_encoder: !TransformerConfig
   use_lhuc: false
 config_length_task: null
 dtype: float32
-intgemm_custom_lib: /Volumes/CaseSensitive/Projects/CoreMT/sockeye-github/sockeye/libintgemm.so
+intgemm_custom_lib: /workspace/sockeye/sockeye/libintgemm.so
 lhuc: false
 vocab_source_size: 15
 vocab_target_size: 15
diff --git a/test/data/model_2.1.x/model_input b/test/data/model_2.2.x/model_input
similarity index 100%
rename from test/data/model_2.1.x/model_input
rename to test/data/model_2.2.x/model_input
diff --git a/test/data/model_2.2.x/params.best b/test/data/model_2.2.x/params.best
new file mode 100644
index 000000000..6ab4f14fd
Binary files /dev/null and b/test/data/model_2.2.x/params.best differ
diff --git a/test/data/model_2.2.x/version b/test/data/model_2.2.x/version
new file mode 100644
index 000000000..e3a4f1933
--- /dev/null
+++ b/test/data/model_2.2.x/version
@@ -0,0 +1 @@
+2.2.0
\ No newline at end of file
diff --git a/test/data/model_2.1.x/vocab.src.0.json b/test/data/model_2.2.x/vocab.src.0.json
similarity index 100%
rename from test/data/model_2.1.x/vocab.src.0.json
rename to test/data/model_2.2.x/vocab.src.0.json
diff --git a/test/data/model_2.1.x/vocab.trg.0.json b/test/data/model_2.2.x/vocab.trg.0.json
similarity index 100%
rename from test/data/model_2.1.x/vocab.trg.0.json
rename to test/data/model_2.2.x/vocab.trg.0.json
diff --git a/test/integration/test_backwards_compatibility.py b/test/integration/test_backwards_compatibility.py
index 7dd15b997..d663fd09d 100644
--- a/test/integration/test_backwards_compatibility.py
+++ b/test/integration/test_backwards_compatibility.py
@@ -43,8 +43,8 @@ def test_backwards_compatibility():
         output_file = os.path.join(work_dir, "out")
         params = """{sockeye} --use-cpu --models {model} --input {input} --output {output} """.format(
             sockeye=sockeye.translate.__file__,
-            model="test/data/model_2.1.x",
-            input="test/data/model_2.1.x/model_input",
+            model="test/data/model_2.2.x",
+            input="test/data/model_2.2.x/model_input",
             output=output_file
         )
         logger.info("Translating with params %s", params)
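
Below is a minimal sketch (not part of the commit; toy sizes and MXNet >= 1.6 are assumptions) of the operator swap that motivates deleting `split_heads`/`combine_heads` above: `contrib.interleaved_matmul_encdec_qk` and `contrib.interleaved_matmul_encdec_valatt` consume time-major tensors whose per-head key/value projections are interleaved on the last axis, folding the head bookkeeping — and, judging by the removed lines, the explicit `1/sqrt(depth_per_head)` query scaling — into the kernel.

```python
# Sketch of the new attention path used by DotAttentionCell (assumes MXNet >= 1.6).
import mxnet as mx

qlen, kvlen, batch, heads, model = 3, 5, 2, 4, 8

# Time-major queries: (qlen, batch, model); fused keys/values with per-head
# interleaving on the last axis: (kvlen, batch, 2 * model), as produced by the
# ff_kv / ff_in projections in this diff.
queries = mx.nd.random.normal(shape=(qlen, batch, model))
key_values = mx.nd.random.normal(shape=(kvlen, batch, 2 * model))

# Attention logits: (batch * heads, qlen, kvlen). No split_heads/transposes needed.
logits = mx.nd.contrib.interleaved_matmul_encdec_qk(queries, key_values, heads=heads)
probs = mx.nd.softmax(logits, axis=-1)

# Contexts: (qlen, batch, model). Heads are recombined internally, replacing combine_heads.
contexts = mx.nd.contrib.interleaved_matmul_encdec_valatt(key_values, probs, heads=heads)
print(contexts.shape)  # (3, 2, 8)
```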
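In the same spirit, a small sketch of why beam_search.py now repeats and takes along axis 1: decoder and encoder layer states are time-major, (length, batch, depth), so the batch dimension that beam search tiles and reorders sits on axis 1 (toy sizes are assumptions).

```python
import mxnet as mx

beam_size, length, batch, depth = 2, 4, 3, 8
state = mx.nd.random.normal(shape=(length, batch, depth))  # time-major layer state

# _repeat_states: tile each sentence beam_size times along the batch axis.
repeated = state.repeat(repeats=beam_size, axis=1)             # (4, 6, 8)

# SortStates: reorder hypotheses after a search step via take on axis 1.
best_hyp_indices = mx.nd.array([1, 0, 3, 2, 4, 5])
sorted_state = mx.nd.take(repeated, best_hyp_indices, axis=1)  # (4, 6, 8)
```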
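Finally, a sketch of the inference-time encoder-attention cache implied by init_state_from_encoder and num_state_tensors: one fused kv projection yields a single cached tensor per layer instead of separate key and value tensors. A plain Gluon Dense stands in for quantization.QuantizableDense here, and the sizes are assumptions.

```python
import mxnet as mx

batch, src_len, model = 2, 5, 8

# Stand-in for the fused ff_kv projection (kv2h): keys and values in one matmul.
ff_kv = mx.gluon.nn.Dense(2 * model, flatten=False, use_bias=False)
ff_kv.initialize()

encoder_outputs = mx.nd.random.normal(shape=(batch, src_len, model))

# Project once per layer, then store time-major: a single (src_len, batch, 2 * model)
# state replaces the former key/value pair (num_state_tensors: 2 -> 1).
enc_att_kv = ff_kv(encoder_outputs)                   # (batch, src_len, 2 * model)
state = mx.nd.transpose(enc_att_kv, axes=(1, 0, 2))   # (src_len, batch, 2 * model)
print(state.shape)  # (5, 2, 16)
```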