From 90144057efe4a73b26fb60a5c40bd180420d95bb Mon Sep 17 00:00:00 2001 From: Brenton Chu Date: Fri, 2 Oct 2020 07:18:33 -0700 Subject: [PATCH] Interleaved Multi-head Attention Operators (#884) Replaced batched dot product in multi-head attention with interleaved_matmul attention operators to improve performance. Also changes the batch-major data to time-major format while in the model to comply with the new operator requirements. --- CHANGELOG.md | 7 + MANIFEST.in | 12 +- sockeye/__init__.py | 2 +- sockeye/beam_search.py | 11 +- sockeye/decoder.py | 27 ++- sockeye/encoder.py | 2 + sockeye/layers.py | 212 +++++------------- sockeye/transformer.py | 19 +- test/data/model_2.1.x/params.best | Bin 33643 -> 0 bytes test/data/model_2.1.x/version | 1 - .../{model_2.1.x => model_2.2.x}/README.md | 0 test/data/{model_2.1.x => model_2.2.x}/config | 18 +- .../{model_2.1.x => model_2.2.x}/model_input | 0 test/data/model_2.2.x/params.best | Bin 0 -> 33554 bytes test/data/model_2.2.x/version | 1 + .../vocab.src.0.json | 0 .../vocab.trg.0.json | 0 .../test_backwards_compatibility.py | 4 +- 18 files changed, 108 insertions(+), 208 deletions(-) delete mode 100644 test/data/model_2.1.x/params.best delete mode 100644 test/data/model_2.1.x/version rename test/data/{model_2.1.x => model_2.2.x}/README.md (100%) rename test/data/{model_2.1.x => model_2.2.x}/config (90%) rename test/data/{model_2.1.x => model_2.2.x}/model_input (100%) create mode 100644 test/data/model_2.2.x/params.best create mode 100644 test/data/model_2.2.x/version rename test/data/{model_2.1.x => model_2.2.x}/vocab.src.0.json (100%) rename test/data/{model_2.1.x => model_2.2.x}/vocab.trg.0.json (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index f47f1c94d..980642d99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [2.2.0] + +### Changed + +- Replaced multi-head attention with [interleaved_matmul_encdec](https://github.com/apache/incubator-mxnet/pull/16408) operators, which removes previously needed transposes and improves performance. + +- Beam search states and model layers now assume time-major format. ## [2.1.26] diff --git a/MANIFEST.in b/MANIFEST.in index e307a5fa7..7195a29e1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -7,12 +7,12 @@ include pylintrc include .flake8 include typechecked-files include test/data/config_with_missing_attributes.yaml -include test/data/model_2.1.x/config -include test/data/model_2.1.x/params.best -include test/data/model_2.1.x/model_input -include test/data/model_2.1.x/vocab* -include test/data/model_2.1.x/version -include test/data/model_2.1.x/README.md +include test/data/model_2.2.x/config +include test/data/model_2.2.x/params.best +include test/data/model_2.2.x/model_input +include test/data/model_2.2.x/vocab* +include test/data/model_2.2.x/version +include test/data/model_2.2.x/README.md include sockeye/git_version.py include *.bib recursive-include .github * diff --git a/sockeye/__init__.py b/sockeye/__init__.py index 34f6d9f04..445bba797 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '2.1.26' +__version__ = '2.2.0' diff --git a/sockeye/beam_search.py b/sockeye/beam_search.py index 3ce759fc5..0737aa169 100644 --- a/sockeye/beam_search.py +++ b/sockeye/beam_search.py @@ -417,17 +417,17 @@ def _repeat_states(states: List, beam_size: int, state_structure: List) -> List: assert len(states) == len(flat_structure), "Number of states do not match the defined state structure" for state, state_format in zip(states, flat_structure): if state_format == C.STEP_STATE or state_format == C.BIAS_STATE: + # Steps and source_bias have batch dimension on axis 0 repeat_axis = 0 elif state_format == C.DECODER_STATE or state_format == C.ENCODER_STATE: - # TODO: Change repeat axis to 1 when interleaved multihead attention is implemented - repeat_axis = 0 + # Decoder and encoder layer states have batch dimension on axis 1 + repeat_axis = 1 else: raise ValueError("Provided state format %s not recognized." % state_format) repeated_state = state.repeat(repeats=beam_size, axis=repeat_axis) repeated_states.append(repeated_state) return repeated_states - class SortStates(mx.gluon.HybridBlock): def __init__(self, state_structure, prefix): @@ -439,10 +439,11 @@ def hybrid_forward(self, F, best_hyp_indices, *states): assert len(states) == len(self.flat_structure), "Number of states do not match the defined state structure" for state, state_format in zip(states, self.flat_structure): if state_format == C.STEP_STATE or state_format == C.BIAS_STATE: + # Steps and source_bias have batch dimension on axis 0 sorted_state = F.take(state, best_hyp_indices) elif state_format == C.DECODER_STATE: - # TODO: Change take axis to 1 when interleaved multihead attention is implemented - sorted_state = F.take(state, best_hyp_indices) + # Decoder and encoder layer states have batch dimension on axis 1 + sorted_state = F.take(state, best_hyp_indices, axis=1) elif state_format == C.ENCODER_STATE: # No need for takes on encoder layer states sorted_state = state diff --git a/sockeye/decoder.py b/sockeye/decoder.py index cbb026923..9a28d0d1b 100644 --- a/sockeye/decoder.py +++ b/sockeye/decoder.py @@ -163,7 +163,7 @@ def state_structure(self) -> str: """ structure = '' if self.inference_only: - structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers * 2 + structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers else: structure += C.STEP_STATE + C.ENCODER_STATE + C.BIAS_STATE @@ -197,13 +197,11 @@ def init_state_from_encoder(self, states = [step, source_mask] for layer in self.layers: - encoder_attention_keys, encoder_attention_values = \ - layer.enc_attention.project_and_isolate_heads(mx.nd, encoder_outputs) - states.append(encoder_attention_keys) - states.append(encoder_attention_values) + enc_att_kv = layer.enc_attention.ff_kv(encoder_outputs) + states.append(mx.nd.transpose(enc_att_kv, axes=(1, 0, 2))) else: # NO encoder projection caching - states = [step, encoder_outputs, source_mask] + states = [step, mx.nd.transpose(encoder_outputs, axes=(1, 0, 2)), source_mask] batch_size = encoder_outputs.shape[0] dummy_autoregr_states = [mx.nd.zeros(layer.get_states_shape(batch_size), @@ -271,7 +269,7 @@ def forward(self, step_input, states): if self.inference_only: # pass in cached encoder states - encoder_attention_keys_values = states[2:2 + self.config.num_layers * 2] + encoder_attention_keys_values = states[2:2 + self.config.num_layers] new_states = [step, states[1]] + encoder_attention_keys_values + autoregr_states else: encoder_outputs = states[1] @@ -288,14 +286,13 @@ def hybrid_forward(self, F, step_input, states): if self.inference_only: steps, source_mask, *other = states source_encoded = None # use constant pre-computed key value projections from the states - enc_att_kv = other[:self.config.num_layers * 2] - enc_att_kv = [enc_att_kv[i:i + 2] for i in range(0, len(enc_att_kv), 2)] - autoregr_states = other[self.config.num_layers * 2:] + enc_att_kv = other[:self.config.num_layers] + autoregr_states = other[self.config.num_layers:] else: if any(layer.needs_mask for layer in self.layers): mask = self.autoregressive_bias(step_input) # mask: (1, length, length) steps, source_encoded, source_mask, *autoregr_states = states - enc_att_kv = [(None, None) for _ in range(self.config.num_layers)] + enc_att_kv = [None for _ in range(self.config.num_layers)] if any(layer.num_state_tensors > 1 for layer in self.layers): # separates autoregressive states by layer @@ -307,23 +304,25 @@ def hybrid_forward(self, F, step_input, states): # target: (batch_size, length, model_size) target = self.pos_embedding(step_input, steps) + # (length, batch_size, model_size) + target = F.transpose(target, axes=(1, 0, 2)) if self.config.dropout_prepost > 0.0: target = F.Dropout(data=target, p=self.config.dropout_prepost) new_autoregr_states = [] - for layer, layer_autoregr_state, (enc_att_k, enc_att_v) in zip(self.layers, autoregr_states, enc_att_kv): + for layer, layer_autoregr_state, layer_enc_att_kv in zip(self.layers, autoregr_states, enc_att_kv): target, new_layer_autoregr_state = layer(target, mask, source_encoded, source_mask, layer_autoregr_state, - enc_att_k, enc_att_v) + layer_enc_att_kv) new_autoregr_states += [*new_layer_autoregr_state] - # NOTE: the list expansion is needed in order to handle both a tuple (of Symbols) and a Symbol as a new state target = self.final_process(target, None) + target = F.transpose(target, axes=(1, 0, 2)) return target, new_autoregr_states diff --git a/sockeye/encoder.py b/sockeye/encoder.py index ec4ea41ea..ffb3cca09 100644 --- a/sockeye/encoder.py +++ b/sockeye/encoder.py @@ -330,11 +330,13 @@ def hybrid_forward(self, F, data, valid_length): # (batch_size * heads, 1, seq_len) bias = F.expand_dims(self.valid_length_mask(data, valid_length), axis=1) + data = F.transpose(data, axes=(1, 0, 2)) for block in self.layers: data = block(data, bias) data = self.final_process(data, None) + data = F.transpose(data, axes=(1, 0, 2)) return data, valid_length def get_num_hidden(self) -> int: diff --git a/sockeye/layers.py b/sockeye/layers.py index 52d1ff3a9..b76191bfb 100644 --- a/sockeye/layers.py +++ b/sockeye/layers.py @@ -257,38 +257,6 @@ def hybrid_forward(self, F, source_encoded, source_encoded_length): return F.squeeze(data) -def split_heads(F, x: mx.sym.Symbol, depth_per_head: int, heads: int) -> mx.sym.Symbol: - """ - Returns a symbol with heads as second dimension and channel depth / number of heads as last dimension. - - :param x: Symbol of shape (batch, length, depth). - :param depth_per_head: Depth per head. - :param heads: Number of heads. - :return: Symbol of shape (batch, heads, length, depth_per_heads). - """ - # (batch, length, heads, depth_per_head) - x = F.reshape(x, shape=(0, -1, heads, depth_per_head)) - # (batch, heads, length, depth/heads) - return F.transpose(x, axes=(0, 2, 1, 3)) - - -def combine_heads(F, x: mx.sym.Symbol, depth_per_head: int, heads: int) -> mx.sym.Symbol: - """ - Returns a symbol with both batch & length, and head & depth dimensions combined. - - :param x: Symbol of shape (batch * heads, length, depth_per_head). - :param depth_per_head: Depth per head. - :param heads: Number of heads. - :return: Symbol of shape (batch, length, depth). - """ - # (batch, heads, length, depth_per_head) - x = F.reshape(x, shape=(-4, -1, heads, 0, depth_per_head)) - # (batch, length, heads, depth_per_head) - x = F.transpose(x, axes=(0, 2, 1, 3)) - # (batch, length, depth) - return F.reshape(x, shape=(-1, 0, depth_per_head * heads)) - - def broadcast_to_heads(F, x: mx.sym.Symbol, num_heads: int, ndim: int, fold_heads: bool = True) -> mx.sym.Symbol: """ Broadcasts batch-major input of shape (batch, d1 ... dn-1) to (batch*heads, d1 ... dn-1). @@ -325,9 +293,10 @@ def cast(self, dtype): self._dtype = dtype super().cast(dtype) - def hybrid_forward(self, F, queries, keys, values, lengths=None, bias=None): - # (n, lq, lk) - logits = F.batch_dot(lhs=queries, rhs=keys, transpose_b=True) + def hybrid_forward(self, F, queries, key_values, heads, lengths=None, bias=None): + + # (n*h, lq, lk) + logits = F.contrib.interleaved_matmul_encdec_qk(queries, key_values, heads=heads) # TODO(fhieber): consider softmax with length argument once available. # TODO(fhieber: Also see https://github.com/dmlc/gluon-nlp/pull/910 @@ -347,9 +316,11 @@ def hybrid_forward(self, F, queries, keys, values, lengths=None, bias=None): probs = F.softmax(logits, axis=-1) probs = F.Dropout(probs, p=self.dropout) if self.dropout > 0.0 else probs - - # (n, lq, lk) x (n, lk, dv) -> (n, lq, dv) - return F.batch_dot(lhs=probs, rhs=values) + + # key_values: (lk, n, dv * 2) + # probs: (n*h, lq, lk) + # result: (n, lq, dv) + return F.contrib.interleaved_matmul_encdec_valatt(key_values, probs, heads=heads) class MultiHeadAttentionBase(mx.gluon.HybridBlock): @@ -385,35 +356,25 @@ def __init__(self, def _attend(self, F, queries: mx.sym.Symbol, - keys: mx.sym.Symbol, - values: mx.sym.Symbol, + key_values: mx.sym.Symbol, lengths: Optional[mx.sym.Symbol] = None, bias: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: """ Returns context vectors of multi-head dot attention. - :param queries: Query tensor. Shape: (batch_size, heads, query_max_length, depth_per_head). - :param keys: Keys. Shape: (batch_size, heads, memory_max_length, depth_per_head). - :param values: Values. Shape: (batch_size, heads, memory_max_length, depth_per_head). + :param queries: Query tensor. Shape: (query_max_length, batch_size, depth). + :param key_values: Keys. Shape: (memory_max_length, batch_size, depth * 2). :param lengths: Optional lengths of keys. Shape: (batch_size,). :param bias: Optional 3d bias. :return: Context vectors. Shape: (batch_size, query_max_length, output_depth). """ - # fold head dimension into batch dimension - # (batch*heads, length, depth/heads) - queries = F.reshape(queries, shape=(-3, -1, self.depth_per_head)) - keys = F.reshape(keys, shape=(-3, -1, self.depth_per_head)) - values = F.reshape(values, shape=(-3, -1, self.depth_per_head)) lengths = broadcast_to_heads(F, lengths, self.heads, ndim=1, fold_heads=True) if lengths is not None else lengths - # (batch*heads, query_max_length, depth_per_head) - contexts = self.dot_att(queries, keys, values, lengths, bias) + # (query_max_length, batch, depth) + contexts = self.dot_att(queries, key_values, self.heads, lengths, bias) - # (batch, query_max_length, depth) - contexts = combine_heads(F, contexts, self.depth_per_head, self.heads) - - # contexts: (batch, query_max_length, output_depth) + # (query_max_length, batch, output_depth) contexts = self.ff_out(contexts) return contexts @@ -489,7 +450,7 @@ def prefix(self) -> str: @property def num_state_tensors(self) -> int: """ Number of state tensors returned by the layer """ - return 2 + return 1 @property def needs_mask(self) -> bool: @@ -501,12 +462,12 @@ def get_state_shape(self, batch_size: int) -> Tuple: :param batch_size: current batch size :return: dimensions of each output state (assuming all of them have the same shape) """ - # shape: (batch, heads, length, depth_per_head) - return batch_size, self.heads, 1, self.depth_out // self.heads + # shape: (length, batch, key_depth + value_depth) + return 1, batch_size, self.depth_out * 2 def hybrid_forward(self, F, inputs: mx.sym.Symbol, - previous_states: List[mx.sym.Symbol], + previous_states: Optional[mx.sym.Symbol] = None, input_lengths: Optional[mx.sym.Symbol] = None, bias: Optional[mx.sym.Symbol] = None, *args): # mypy: ignore @@ -524,42 +485,17 @@ def hybrid_forward(self, F, Shape: 2 * (batch, max_length+1, depth_att). :return: Symbol of shape (batch, max_length, output_depth). """ - # combined: (batch, max_length, depth * 3) - combined = self.ff_in(inputs) - # split into query, keys and values - # (batch, max_length, depth) - # pylint: disable=unbalanced-tuple-unpacking - queries, keys, values = F.split(combined, num_outputs=3, axis=2) - - # scale by sqrt(depth_per_head) - queries = queries * (self.depth_per_head ** -0.5) - # (batch, heads, length, depth/heads) - queries = split_heads(F, queries, self.depth_per_head, self.heads) - keys = split_heads(F, keys, self.depth_per_head, self.heads) - values = split_heads(F, values, self.depth_per_head, self.heads) - - updated_keys = keys - previous_keys, previous_values = previous_states - if previous_keys is not None: - updated_keys = F.concat(previous_keys, keys, dim=2) - keys = _remove_first_step(F, updated_keys) + proj = self.ff_in(inputs) + queries, kv_1, kv_2 = F.split(proj, num_outputs=3, axis=2) + states = F.concat(kv_1, kv_2, dim=2) - updated_values = values - if previous_values is not None: - updated_values = F.concat(previous_values, values, dim=2) - values = _remove_first_step(F, updated_values) + updated_states = states + if previous_states is not None: + updated_states = F.concat(previous_states, states, dim=0) + states = F.slice(updated_states, begin=(1, None, None), end=(None, None, None)) - return self._attend(F, queries, keys, values, lengths=input_lengths, bias=bias), updated_keys, updated_values - - -def _remove_first_step(F, data): - """ - :param F: MXNet namespace. - :param data: Input data. Shape: (batch, heads, length, num_hidden). - :return: Output data. Shape: (batch, heads, length[1:], num_hidden - """ - return F.slice(data, begin=(None, None, 1, None), end=(None, None, None, None)) + return self._attend(F, queries, states, lengths=input_lengths, bias=bias), updated_states class MultiHeadAttention(MultiHeadAttentionBase): @@ -587,56 +523,32 @@ def __init__(self, with self.name_scope(): self.ff_q = quantization.QuantizableDense(in_units=depth_out, units=depth_att, flatten=False, use_bias=False, prefix='q2h_', dtype=dtype) - self.ff_k = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='k2h_', dtype=dtype) - self.ff_v = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='v2h_', dtype=dtype) - - def project_and_isolate_heads(self, F, memory: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]: - """ - Projects memory into keys and values, and separates attention heads dimension. - - :param memory: Memory tensor. Shape: (batch, memory_max_length, input_depth). - :return: Symbol of shape (batch, heads, memory_max_length, depth_per_head). - """ - keys = self.ff_k(memory) - values = self.ff_v(memory) - keys = split_heads(F, keys, depth_per_head=self.depth_per_head, heads=self.heads) - values = split_heads(F, values, depth_per_head=self.depth_per_head, heads=self.heads) - return keys, values + self.ff_kv = quantization.QuantizableDense(in_units=depth_key_value, units=2*depth_att, flatten=False, use_bias=False, prefix='kv2h_', dtype=dtype) def hybrid_forward(self, F, queries: mx.sym.Symbol, memory: mx.sym.Symbol, memory_lengths: Optional[mx.sym.Symbol] = None, bias: Optional[mx.sym.Symbol] = None, - projected_memory_keys: Optional[mx.sym.Symbol] = None, - projected_memory_values: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: # mypy: ignore + projected_memory_kv: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: # mypy: ignore """ Computes multi-head attention for queries given a memory tensor. If sequence lengths are provided, they will be used to mask the attention scores. A bias mask may also be used to mask the attention scores. - Returns a symbol of shape (batch, max_length, output_depth). + Returns a symbol of shape (max_length, batch, output_depth). - :param queries: Query tensor. Shape: (batch, query_max_length, input_depth). - :param memory: Memory data to attend to. Shape: (batch, memory_max_length, input_depth). + :param queries: Query tensor. Shape: (query_max_length, batch, input_depth). + :param memory: Memory data to attend to. Shape: (memory_max_length, batch, input_depth). :param memory_lengths: Optional lengths of memory to mask attention scores. Shape: (batch, 1). :param bias: Optional 3d bias tensor to mask attention scores. - :param projected_memory_keys: Optional previously projected memory keys. - :param projected_memory_values: Optional previously projected memory values. - :return: Symbol of shape (batch, query_seq_len, output_depth). + :param projected_memory_kv: Optional previously projected memory keys and values. + :return: Symbol of shape (query_seq_len, batch, output_depth). """ - # (batch, query_max_length, depth) - queries = self.ff_q(queries) - # scale by sqrt(depth_per_head) - queries = queries * (self.depth_per_head ** -0.5) - # (batch, heads, length, depth/heads) - queries = split_heads(F, queries, self.depth_per_head, self.heads) - if projected_memory_keys is not None and projected_memory_values is not None: - keys, values = projected_memory_keys, projected_memory_values - else: - keys, values = self.project_and_isolate_heads(F, memory) + queries = self.ff_q(queries) + kv = projected_memory_kv if projected_memory_kv is not None else self.ff_kv(memory) - return self._attend(F, queries, keys, values, bias=bias, lengths=memory_lengths) + return self._attend(F, queries, kv, bias=bias, lengths=memory_lengths) class PlainDotAttention(mx.gluon.HybridBlock): @@ -652,14 +564,14 @@ def hybrid_forward(self, F, queries, memory, memory_lengths): """ Returns a symbol of shape (batch, max_length, output_depth). - :param queries: Symbol of shape (batch, queries_max_length, input_depth). - :param memory: Symbol of shape (batch, memory_max_length, input_depth). + :param queries: Symbol of shape (queries_max_length, batch, input_depth). + :param memory: Symbol of shape (memory_max_length, batch, input_depth). :param memory_lengths: Symbol of shape (batch, 1). - :return: Symbol of shape (batch, queries_max_length, output_depth). + :return: Symbol of shape (queries_max_length, batch, output_depth). """ - # (batch*heads, queries_max_length, depth_per_head) - return self.dot_att(queries, memory, memory, memory_lengths, None) + # (queries_max_length, batch, output_depth) + return self.dot_att(queries, memory, 1, memory_lengths, None) class ProjectedDotAttention(mx.gluon.HybridBlock): @@ -688,26 +600,19 @@ def hybrid_forward(self, F, """ Apply project, apply dot attention and return new context vectors. - :param queries: Symbol of shape (batch, queries_max_length, input_num_hidden). - :param memory: Symbol of shape (batch, memory_max_length, input_num_hidden). + :param queries: Symbol of shape (queries_max_length, batch, input_num_hidden). + :param memory: Symbol of shape (memory_max_length, batch, input_num_hidden). :param memory_lengths: Symbol of shape (batch, 1). - :return: Symbol of shape (batch, queries_max_length, num_hidden). + :return: Symbol of shape (queries_max_length, batch, num_hidden). """ - # (batch, memory_max_length, num_hidden * 2) + # (memory_max_length, batch, num_hidden * 2) combined = self.kv2h(memory) - # split into keys and values - # pylint: disable=unbalanced-tuple-unpacking - keys, values = F.split(data=combined, num_outputs=2, axis=2) - - # (batch, queries_max_length, num_hidden) + # (queries_max_length, batch, num_hidden) queries = self.q2h(queries) - # scale by sqrt(num_hidden) - queries = queries * (self.num_hidden ** -0.5) - - # (batch, queries_max_length, num_hidden) - contexts = self.dot_att(queries, keys, values, memory_lengths, None) + # (queries_max_length, batch, num_hidden) + contexts = self.dot_att(queries, combined, 1, memory_lengths, None) return contexts @@ -869,10 +774,7 @@ def get_state_shape(self, batch_size: int) -> Tuple: :param batch_size: current batch size :return: dimensions of each output state (assuming all of them have the same shape) """ - if self.inference_only: - return batch_size, 1, self.model_size - else: - return batch_size, self.model_size + return 1, batch_size, self.model_size @staticmethod def _training_cell_state_transform(F, previous_cell_state, weighted_inputs, forget_rates) -> Tuple: @@ -890,28 +792,26 @@ def _time_step_update(step_input_and_forget_rate, previous_step_state) -> Tuple: current_step_state = forget_rate * previous_step_state + step_input return current_step_state, current_step_state - weighted_inputs = F.transpose(weighted_inputs, axes=(1, 0, 2)) # (max_length, batch, input_depth) - forget_rates = F.transpose(forget_rates, axes=(1, 0, 2)) # (max_length, batch, input_depth) # (max_length, batch, input_depth), (batch, input_depth) cell_state, last_step_state = F.contrib.foreach(_time_step_update, [weighted_inputs, forget_rates], - previous_cell_state) + F.squeeze(previous_cell_state, axis=0)) - return F.transpose(cell_state, axes=(1, 0, 2)), last_step_state + return cell_state, F.expand_dims(last_step_state, axis=0) @staticmethod def _inference_cell_state_transform(F, previous_cell_state, weighted_inputs, forget_rates) -> Tuple: """Update SSRU cell at inference time""" - new_step_state = forget_rates * previous_cell_state + weighted_inputs # (batch, 1, input_depth) + new_step_state = forget_rates * previous_cell_state + weighted_inputs # (1, batch, input_depth) return new_step_state, new_step_state def hybrid_forward(self, F, inputs: mx.sym.Symbol, previous_states: mx.sym.Symbol, *args) -> Tuple: """ :param F: ndarray or Symbol - :param inputs: input data. Shape: (batch, max_length, input_depth). - :param previous_states: previous cell states. Shape: (batch, max_length, input_depth) - :return: cell output and new cell states. Both with shape (batch, max_length, input_depth). + :param inputs: input data. Shape: (max_length, batch, input_depth). + :param previous_states: previous cell states. Shape: (max_length, batch, input_depth) + :return: cell output and new cell states. Both with shape (max_length, batch, input_depth). """ forget_rates = self.forget_gate(inputs) weighted_inputs = (1 - forget_rates) * self.linear(inputs) diff --git a/sockeye/transformer.py b/sockeye/transformer.py index fdc5fc799..463889e37 100644 --- a/sockeye/transformer.py +++ b/sockeye/transformer.py @@ -107,7 +107,7 @@ def __init__(self, def hybrid_forward(self, F, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol: # self-attention - data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), [None, None], None, bias) + data_self_att, _ = self.self_attention(self.pre_self_attention(data, None), None, None, bias) data = self.post_self_attention(data_self_att, data) # feed-forward @@ -158,15 +158,6 @@ def __init__(self, dropout=config.dropout_prepost, prefix=self.autoregr_layer.prefix + "post_", num_hidden=config.model_size) - # TODO (tdomhan): Remove with next major version bump. - # For backwards compatibility with versions prior to 2.1.17 we also store the layers under to previous - # attribute name. This way parameters can be loaded as either decoder.layers.0.autoregr_layer.ff_out.weight - # or decoder.layers.0.self_attention.ff_out.weight. Parameter deduplication makes sure parameters are stored - # and loaded once only. - if self.decoder_type == C.TRANSFORMER_TYPE: - self.self_attention = self.autoregr_layer - self.pre_self_attention = self.pre_autoregr_layer - self.post_self_attention = self.post_autoregr_layer self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, @@ -226,9 +217,8 @@ def hybrid_forward(self, F, source: mx.sym.Symbol, source_bias: mx.sym.Symbol, autoregr_states: mx.sym.Symbol, - enc_att_k: Optional[mx.sym.Symbol] = None, - enc_att_v: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol, - mx.sym.Symbol]: + enc_att_kv: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol, + mx.sym.Symbol]: target_autoregr, *new_autoregr_states = self.autoregr_layer(self.pre_autoregr_layer(target, None), autoregr_states, None, @@ -241,8 +231,7 @@ def hybrid_forward(self, F, source, None, source_bias, - enc_att_k, - enc_att_v) + enc_att_kv) target = self.post_enc_attention(target_enc_att, target) diff --git a/test/data/model_2.1.x/params.best b/test/data/model_2.1.x/params.best deleted file mode 100644 index 13e33ec10ca00e297ca39b032edf0d260bb827cf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 33643 zcmeFZc~p&W+y7rGQ-hKzNh!&YLZ!5?^EjepOd=#?NT?8Lo_0c!Mir9Kgp4IL?ejd7 z%w#AsW{OOep@>t5@=zw7y~_4Pc@pT9pot#z$k?e*IG+Sfj>V}74K<>mgbAN_y- zb>&CO_kaD`Nj!S|`+w^He*CXUocem^(lHCVoG20+(^c50k+HIamKEimeUwQ2Ds>Xv zHwHdj@5JIO&EePxPgYZrz;-EbU}@twGJ9QjrZ@f!+wf5Z)-Bm5i(jYBj`5*!MIg2Bq5^&Z)v?N7-Mdw;g=Xd1I68`+e_G0e`|mo4d$M(cXG$g1q5D{Mn*Bc_Gzr7YB6(M`IFX;kmE{s3t36jHn#kBS0<$NlBt?U z(0Uxmb}t;rZh41TnVzU3t~#0I==0wsG09Q5TN)@_ThfE=nQ@9$HGQKe7Ik9X4;*3L zw70O@+^*$=r*EMLKT24KYo<^cp&`>)SwK7FEoqbLW;$U1ZOOB%YXwQ|6sp~=Gc&b4 z$(%c~nM^i_O?E@JTd7FaI&}%%HyBxVm(eV2V=NtVb04{DQB0pd-$)17rCJ}zR1jWo zZladwa@p^tS*-1bI{W%%I5S;0h?ytulwH^qNUOL$OT1>yI@j$aJsfkW;<^W<`VphM zCa#wG*%k@|tdwYO=x;JLprF{_Mf2 z^HgrPJh^@_lk8PKD@pM@Lb7vq3o1=b!kC|v>GTyzti$9u3;%SQ~WVi1FDQ+hVD zY~974PF1Btp1RPSmm^u#w|z7pIC-1@&dOQy6|Zl+peR} zzQkp+UE9X7Zklu00Nn|)akpdXGG7a(Inav^G*D%=!JEluo=lRQ73qQUqmsUE^}?1u z%V@gOE;cJGja8$EPK< zEWSvh7juI9v)kg!eX2eFXL%IuqO9#gEF$bS2q%R((1WkGx(Q;0QY z7E6v1lM~PAWcSxJ??HjEtXE5!Yf5_A*N(F^+iC)vE7-B_@)j)5&z`kSJjwp+eZJ(s zxzGRo)PM63iE?s(KXSe1A&lD(0shy?!j1A|-aP|B4;&?HdTz2_d@XOBjZ=5@ny~?Q zmqeiOREniPrFiU*6x+q)k9ZVcT7}xqeet``G`PHOBZ1l+;R(4Z@$9J}S$xgV+Dzjv z`w%(<^InDE_b@5G`YgqQ-%_j>*B!9$hmaHv~t) zrROHe!fScfqvja#LX$mk=&Cn*wuWQsH7Qp9kYb3)ei0A#Fn=x;Ka*Wx8w?Z%LEG8s zWaE~NLjAoQ$!NG~Z7d{1t?w|IP5u(xSRIC1C&c&sCPfpGeJ>tyJrbEikS}`1siTkH zc(}dSpV+U?5awr8Nt6oPtxInBqYvh?F;iya%ax(nGggY`tx~)rUV|4RZ$D}%`n;Kr z8{QAbnKesb3?!0+v(5;Y7Qd2w(9n?3KQ>tIr^Sn|+v2;5VBEM^il1Ag=ppiTBA;d- zj^Uj(@nEhc&KeR4PDgS`g4Inyv!{aKe9chu^;a+UI@yUYUN8nbJq*MZCQ|JBUW$$4 zHM}bFW8`b#oo*WFY@LiIm5G2oOGt~(Tfy2-Qz+kTExC)HJo~jDzdJ}5^Nakk@@oVJ zyprO4k(Y^lyPYzRT9+^T+d5i)V=nLbB;|iwM+^T*`JnOnZ1|AnxIN|s_|`ll-W$}2 zmzBA&<(s?2S1DN%m$9B#kT=k9WF5ZQ5P=KsOR<~yo$V3d|Ng>md`Pl3*V(fcjq|DK z%i4&Jhc;Qi$4+=t5G1+P*#+k~PUCISrEvbr8ay*J0`J_C;^Yn~#)!PWUL8#DHJeu) z4M3Yc*WngZ1kbrciT6)u!SGD9#An7T%zU8EGupPo`J-MaR~L?1m!%jWt|4EM_uUiRABnvXv%|ocwNS4$osz49bBv!9%0);MiFUj$2M3W)=RzrnMQ8Hz5=FH~rtjk@_UEO6HFJ^}=z= zSt(|9KmhhEbxssm>Y9tTm zTdfeyCIq8wmK2XRi)%;Zf7Yzznf(HB=zC-AH)0ezt=a&=lQYQeomGNwLA&Jl$Omxn z(H6LOLE!F8bh-IKINw%^dQDQ)5qYbaE^PT>DF#_+Vfz;g44fAZzBG%>yL4TMjg}XJ zcsR~0nN4|L2VU6A9P_#dp`)1;w}>{RRkQ(JEp4&j$y$8ykVK0>htCTdh|N0HZ#oM8 zqwO&~A|9%>&X6>hS3;nMhA?kU06!EokB3>T=W5;bvAm1O|B1kvucf$F6J^JAUstJ#=-l!eV4PVx>qW8SE}@O%Ab;_&vDuyC)QkYMbE z`(Cf%vo1-w+g3HyzrF!oFNl8dnRpFF8|W}x74z#1c$t(#QL_+>RYqoP0# zUNsbsB+kLpMRpvMr2L(%3-*)WfUF<_H`GfpPvpmmcC7HF3-7aa5c?xpgzazg;H~pr zQZ~2?X=xuW6i3gOb^p_a2VV>0XRGD#uBk7EW<=nu$Kv%9`ETpHVDOqCKIL8%n=oc6 z-eE@|@xmigEGQF8qftVmWIi{t3}k&G19?yB7g+4F9+S63px#5#UWt6A^&>XhS;`-$ z7qMr}%h4g}BpkGOL5@22AQ5vcgvBl^x%}W%a-iIYFV}wsGA{yS z?@IBU$iHKQx#gb;d|pl$zOvXG4;z+(Mz0nUu(>bkdJ2V+pKWn$y&?A2na`_bmq7mb zT6~WYXjmh~%_6^_?t&Jk#7nXV@U(Mlaq3$NZc%N-VXros_RLN=BlW|plU?y~*i^oz zyac|<*5C@G2rLkN&_t1+Rvv&?Kdff!HAeB=g!QO)xDvX3>>ygFbxEAsRAKd&t#Dd3 z7>n+WNB1+`n<;aSg7B>Iw@@J3q=hleU)yu}SRHgqm@6s|(m zoN!!yUW%VYADX3*EDSspjAQn0;FB6ck@k8BW0cguf4Dh0*XAyii+-jhV;RU?GnrlS zSg6iiiP_QNxT8XfheiHY?R=cp5P@D8!mBL8@kUWSd>^L)f5uyp9|oR6WN2sHyMWP0 z59hIjsy=Xf(+W)X49CM%ii<>k&!ZLS^Uw>oEsNmm?ZxlK;Uz3x+Y7?y3uMf!wZaD;hDkiJWrxz=zYq@pzdOHAMc6<}YFwjoAJ}$~*Uw;@M}7urj+ZG_SKE zt0FfDMcpEJ_LE`!&iAP@y~bnY4q1-jBgGh`Sd5*;7+~%(JAU0v9`h8Xd}gi`XRZAJ z%J=%i?5%br^F)vk`OJVT$uHtZ#}tu2J6(uYPj?Jf3&+i;rMN-l^@sWKXLc^27!}Hg zi9YRF+eet(Z6F-kKbhQm94^>3rbFSl5U%wLAbsWuA^F)-ywx0r^9n^)8qHJ&5f`*{mQ@7PIK40FRT+r!Z7km!R&{`A?apy}+!Lwn5OnQ0N2?In-8cZS2J zpDyI$)NMjqFK>SN_AeIlUKKSo>uA%ng;?YnhCgzocwLNvZm9I&-WNjntD1>?O-49+ z?^D3>3MQ~ia}jy;e7kVa!I{5mbYPNDeVlYcnO)zp0IS7VG5CNOJBoa3i|nl6Hks{D(H%K24UE^SBiZ_-fgoI zA8=|FFWfVTXEuc3osKSOy4)OgjQ1d&)RTqm)vmbJ=Prb8x5t^Ar?IoIT(H-lP_#=I zb2;(e_%-t>xU@@n*>@FwVrVeVHtddHQ!F9a#gp_tlp-8kv;sF=9f|bh9Gqb2&1$0O zAwCSnQ2UQaCV?iMzAOyu@c zL($r24Nkn9%3?~K@lZl2w#Q3xx5yi)?`2`pzsTS6u(`K4@ONDz@NapTbao)WCOm|j zTdi4qg%7Sj)DtJqu!E4IU@}KxpYYmnGh3Xx5<`naF+p0y{A!&r)Fu=MM2mUg7cm#J zCY+a=;-%q7=-O+mv4r))1EG_k=hbksY|jDVo0lSwx#9USDqOzDB+Tg_g0ZWl_&|)o9eTZ|6ADZq zt5+|acGwB~x9a2MBa6V;ES_YQpAzoQh`>1*#uwJi#JJZz`O4u-eQnGn#~I5f6L>Yf;OP-g_o>2(;9VKCt+&+5WLXg4klBQN&een!P;sc zdo$64Zva!wdZa7HLKCoGN-(~3m13w^2QX;c2oJWI@!KP(;eqQD(C5Zb%r#gEgPc>y z(a17EmKn&uhmGZLW!kt`YY3kgV1wDqgVB7h7(a-7lg&u3=I+id&(24?4FubZhvSE- zUeJAU8adXN3VMM8pE~wEQ#{@oHAforit~WmEP~PBNs6&z4iYyejBh*{#$_5S(Ce)w z>hCkgJ?q!NJI`I@*a>=__+b^=_4`IAY zZykn?8iUWJCU|_;I^cfkQzzaP4&+ua<8cMgG0Y5PfL_p8Fs-3$1vUCVKy+#+)B z(Ime1w}ogwf{{$J!}%U&|=%UW!|7!Q0&(ZaJ2EU;`?IGo&{MRsLf7hK&u`1yX} zShiZA-EU9l7N+Kyq#A^K%)}g2%)!EaB6!2`q42jl(3zeuD2)umztw@}*$m>B&6}WT zY7lzNR>2ou<8VXqMtE~Nk8HVlTWF}Y$2s-OFvmTZ4bYj(+oML|oxDKQ8z#k#@1z*L z_yA0`x5Sc~b+{}{36~g45GQW}KYoaWtf&=w4cvgvvN3oj?>O7E-j%B!G{HaX0`Z!G zSUVDPK-ud+oPEIqWvwgFqw_E5`BuQw_cue^m3*@H*FC{EMheT%J%=?<9y7TFH-5Ov z7`ICTarr5y0J$(; zs$<6dbo%q2+AH|s)}dHl9e^I1Qgjyiiz`=ii>#%%L0-VgrMID7egbycx(&v@JxOdT z>jhoED)!{XCzf6@l9z7tdj?XxRe79^hchxjRkHrCK&{eEyi2SWz>SWKa$DmJpLD(ei6Uq?aPM@8id=L{BhsU2viaI zZ;zMb!L5Y-t&aG{^A+2l63YKpN9^eFfXt5l${sJtfjLQA;QfH9c+)W%+y|7Ay3_B4 z`h*DnX3-Mw>j(1n19fm{us`1W7=iU-4&HbBBz~ybnvW;TL6>?%+_7n>{$vM) z4v`TytywtYrpNogG2yO>!ThXlFz?z|8#{wPK6n>_b|Sy8(wDDQTgAVMJtaYBI(!@J zfS)~6A!ZaKIWIm6iqY9{-t0aL?G(WuJPzRpm9@~WlRsV)bB6qvV*G6o%3cqT;CI`u z3R+W)VUTPFuK$__G1e6%I{1qaZW4rD_q5U%@1^|0{4hTMXCHiaL5y?i#p@^XDylTWUC)W+>gMu zFT@(kzx#pz>izw{`hce$hGO))KTssM6xVbrg3pWepxtZ*$jdE;E;e>h+hmB|k6u70 zi7Hmi-wV%w8iUc))i8095_DhKPX4FYH>qNR<&0&twEqs`v-6bjN&mD=vO|~kZhc4& zTlxvk2hHhXsV)Vdp|Ci=h(4WAM+48kCN@pGge-eg`k&YTS3mPV=l`F_|Ju*Af7?QK z74;UL_R1jF(>KeStbNGm=zPg&8wEaOv%Dbbb`aua7l@X#J~6e)qblkd!rjjM;LiFo zvLIKHIaRD-S3YFOzQhitWuHofO}}rK>sqT~`=JopFVCM`^2=l+Cdfm`ft2!7TV^nK zJrx+ddo1zvo+fdQ8^FeG`cP)seLJz9^wGNOEM-}Aa*;qy$FVV~ZX~eGUNTH!H_32k zvYy?Kkt6v71nq8~bl93OAwX#xeca!Olr*MLt42lAX>qJjI`lNHKI$OLx%-29=57}{ zD$-@aft}d38z0%~iw9^%#ZI>BJzD48CQ#T~O}zthVN<8$RDEW?aLsi%b+g8@T^V16 zk&PZyFC$Xo@S`8Sx^My8H%i8w)nlN_<2HG9UK@^Dslg|OM0oyZF|q5nhwjM(8hP?= zd3t%faK7tWneBwFG-0X($&I!m_d-rm%}c;)Q#Ih#@CMp(B!`Za9;MeFHn8nqQdlHT zrgMGMgkd-JX`1qM5>Cd@FUmGlt*ew6q*%z(&Yd8>FTG&lAt|$+nMlLbD(Kn8$7x^J z-{EJ zwf&ONc*&B^AMg$yjk-&GPO3t|l@V}Z_hPd2Qk5{zE1YOgh=qYs$z)&#knm75I6gR( z&ZK=wAH4>0)0D&P%tE^TpgkMdnkD`WJq`EX~?2J zrV~q#U9?c+3Jn;1jO6?r%_2-|p`*Jzm*2!_(EL}lRsIptYrAFD-k3zU`Ya$ePru8u z{Zz>Gpg@WSXK2tjDY5uqO}{xEFONu>1drnCsb1hlh;peYQ%SX9BRwSKvsn(?ZoZOE zleyDZbLC|x3Z{?&{i1}B^gw!}avwPuI*hoR>CxMvx%A@OvE+N16>&8yrIp4fscO+H zc5t1HUBBo+Umx#D+aIV1gU194JFT9S8&t{5ws%y?3Su7%hwiK<+dmsX$huNi<`&QH zT>C*Gbz9l$n_A?JWjRD1ETrxO8)PdMnhL`Ar;@==Yspd5y>xSBzMwEHK*%#VN=oLt zg72%m-TW0wM-_qf&rdAsC!rS)SW&IIR9<7F%+DF^rd|6-Ls#!j zY|8j(n6_w>%qveq{0(#XnvvJ(H;11>eduR8{m6bc{rzY%v&Un)JmmmPm}QAJ5&d!9 zm*Z^OH*?niqY5q$Pi1qr`mjH}6L`36I-mVKjxWAr%Q}juz>4Z-cEf83dVClPJz7Ta zSKjjkwJTAW=KWo;S{uR&gMD!Oy=Y+SWo-4IHumoI4SsfW3Fchsf+H`>pmuqj>~zQ| zD4Si#j)mTW(HAzuV~0ab_j@P4I68@WT$QkQ-$$}@k&*1mn>8?ea|XAH2*$#B4BXdS zLelR3TwXU0lAnJh7dpNP!w#>+4`NO8mf36eV903Z7w3%Ehv(43mg@{^4>P^;emw5x zdv@yBA#epFxH#o^`LWWCOfKG$MEgF5!r-6iFyDb+G)|^Pv*v=~!k=vD_UW+u-PUrt z;}C1il0&QcBP(tgEa!5q%6yfn5}Rb0NW#3ALZ8l?!D3Vi9A7m7H!AOkEob9ML()uU zW$pt;t2Llt?g!rUMk(5itY%^ByK#BtHg-?DnnfzQvt9XL>0xUHRQK9Z=HCr)#cV~G zU)_U!PI00AAN9p9MN&9g-o0YyoLf9IIv&4l>w*t;717^EALlCcr?F*S!2DJs3|SY3 z)FqfVbQ#4?f9s2%Z*^zVhNY|}(TSuNlwpWflC77V5Ct;(8zh#;~7+B_i0?8vYG#m1r^;~#>^&R({ ztGz1VK`wGQLo=RiTY3h1c~*jRZH%mb&klSNAa9d0Q&raI{BAzpt3O`#sfAtIA*^?d zDpy@Gfvf!~;C{*G-15{Un&GaAeF_TTh*A}HX^6w`wu3PyHy%G$?w1W%FoiGkSdQ~5 z?cmO&Y~rC9rgez2UmFGse~bp$1Q|@=H-8Vogov3SxY@z! zImg(Whm-}K)n;X_u{=Th1`Ho@hNV}fa?J!WXDb~qA@p! zII73-hMmE@q)QPiU3E@)tXIiS{xN2KRQJL1^jm~hhOm{V%5pHvO}PRG($9fx z7=?^zSH5|h6Y5Xzfzhd5G1sXZMqFcHq2@(eVzg-9*wOfS+X;T8t2H+&G$2>IN?Gyp zNYZbmBDmeqKmq~h8v6E9NTY2F& zS&`cl0xoTXteF;kdf5m(UF?FcqmF`HTN?Z1dzu+IS(3}@n}wT)GP(b_&wSu66Pz47 z07sek#odOvFyEvnK2({)rvE7754@Ynt!7o8zG?#bkVc_kvMJn~V1$bUUcj4E1#FL_ zB41&s%9^_j$33?OlM5fUnN;c`uGcJj;^07=%N0j(M)eK8>O~7nGk?WCeU4+FzqYb9 zHKxq9aw?9`SwJ@g4#LmfcC-7pir~ffS?ofd8o!bFnw}VXfTUHM@hSbv@L0MjCYRl0 z>vM0j-+6_wKZLSwH85ialFo0jYmB8hNf&7|rDk(32N>ulyfkozE47G}9%j!GZ z@DCJ(&pl9i@Ghv-?LxO^M?sQqG?%+z$wQt;@liohT(fy18~rL-oL|{ZxZ-;lXP|^> zKg`f^L@GvCOQBPBGCWUQ0S%g4NqpiaW>RySPmH+1p5BPUbE`F}T#w_zhSM`>>8?)X z^@m7dpUGLwr)hkPTz3rANCdvpfYtr($MTB}>F}HZ?AM8%yzlCCG^o(v-M>DDa+NHY zcVrthd|nUsKR?47>c*Z#+~ubCXY)5VjoGH7nrLSDfM{>}0Q=`mg^*p!#L+vMo?Nzu zG}U?2KH0rk+oE;!RY*F$X=XuBj^D)2nZ;6F-65bjsgxdWenB@co&nSR3z***c^0uO zq5SomR?^#IIB4u2Arp2Sq%fhT>_gCQ$%h_?$h4Sb`uikT!@j_l=PR4&iWx3SB#Qv2(*WD#XD(hLw}+7xfe9}_!&BVO0uk~$GmdG z0#&HAU%*sc-wW%8B$3^i?V>;#ITEd40L8j_;-`8}`zUM>lBL{@J9jFpqYN4k2Ca*TI-maqLe@04?$g zCAF?yq4;MfdPC7s81`-?wLZ}z#@O2IYQJ{Eb-Iv+S)G|%O98c=eO?GraUzf6i&@;b zT()Cvp`a4Hlbk)81=p8%W-D{gku^qjvX;6_q)|ci3%bTIedlagx?jRZbY4X*Jf0BK zXIirBFLx577gi*vbS?YI+UTRWp0IV$bu#G@B@IjDW%;?UNY2+tVbHN9<+nxxVfS3h z=esPV6+S;`r_))or;XXH{%k0%e;P@YraO_3bCr`cS8f5-18Pz9?`_gE0t_{)o!w9 z;5+hX>o4-ie><6cc?O%D*@r(GvRL@O;4bZVbrD&ks6c8*?Ud|TuR{$tuVX^4J@bz} z$o5&EVvX9jgc!A_V$9rASaY#2_B~Nc!kPwR%QCTE7PM-kwT5s@0kK`xw^r!J187+(7EhrO^M2DJzYg#kv$q;L9a{ zcqRNMo%w#js{ImK)}kjfe7{q+zW+(Kvt~1^S!Rg-H}64X2ykoTPBfUll~`+Pl4Rim z>;Ld2S@v{2=p-Is5_vgX)#lEOl;v3Cr}@O_x*~fy_y9|C;;eTMd7R^Vo#j}^mX}2( z(*bvCXhhm$=I2%f8gVC>h1n3e-`MMatcT3m5eBug@fftwRWFjM&4|0VI*>q)1p^!*>}A(yiDvtmzQwsHP8=o+pF zrH{sup(Q)YKWu4Z8hy41_gnK>nBo(5rBj*g%3wpGpx0JG?!-=#{H|DXxcX1o|F#}7 z^ucbj#MX->zH0@O0Z(D)mwB*YT8=PLTkO|<)F;P;g=B@I3-kN1ou?|BF`M9I64N-A z?H*fDTJYqqWLNk-5_7T`QpP0V76m6kJusX;jol?YS-Tcmx8DgH2zRkTTEa=t;Kj^8kQR=oVrRxlmuDx5fWZ-!E>5!sqz^q@e zq7C5708hOC*%2PCDj}I;(q+%T^a1y$xwP3#3!X*yr)GBw*>%YdlD+vVeG#I>&R*~# za?x2rTblzM@iS)<`=$I)RUzA>?8IhF^AKKLdk3-?gTON+nk77Qo?Xk`7vF1;2|8$`A#hI2rJ%ji;Qkf zp$ESj@!hNPSl#5`>}b+yI`F_jNO6fL3f+FNlLt2Owr4-e-`vb8ANtu-c+j>BwAvJC z!O8(paXy;9D@!B+yC;#EVI9JLvmZq1(?(ggyRYn>;V|4{Z45HuE!Ar%A?+rnbTqPY zRII~2`JT#(f31b9OAmn8h&`*)t&8hC=34V{F!M;pkm86)Z9^a8s6c`Uo2oJ%wDwoq_E8UN(u z;PP*QFj_uYXcIa^*@jSwqpmHE+I@(0sgTk)wKnXD<^!R7k7n8Sft%x@aS53uid+?jHf36GpzNKaPZA(GuSP=2O24o-Y){WWzBXbjb2Mgu4FcKO>- zo{+$bJ9Xh&7AnNW^A2qBq+sLrn~;<(vV`<{N#e1e^pWEUm~N|0@9cg~b)Bw2RG>Tr zPPxj|xUHn>!{GA1!S|_OM>8E#W{Mj!>nTY;Bxw}{ey}E&jUaM(W11?~#JwUxV@Apj zpLJuqdnvQ69r?^<`W|*Sgt7WDO4x1G>HpCV+4i2rJ%XE={8&w{aC9d>S9+BxTJ>Yj zvGuG`=*O)$6|$G}v^n+f%oE=$@wqFM_{QdLoSyH)12y~b|E(Ruyf1=NT@)>UYfQ>Q zFY~SkwzK!t4GBTx% zde~c#*UN7*Bo#1lUou&cEZJ%AHCw!=}kYm&{ns;pCU1|;^IL>FGn z1tq&d5W0&~r$#weH{lBN@7@iDByFVNm`_aH2hb%Y?}VNs=EEmTB~S`eBSu+o>Ep|F z?8(A(dhy6>dS2HNw4`I%`ATEKEmRv8mp&C|Cfcaq?K9+|G?C%_B4QSL6;>BXh}>5y zn?7hCP0iQ>**kmCwBRVV_+Bc!Rq9R`2JDk%Y4w0z6BXdBM>!P-1Bus=LDaYCEm6_C z4^v8ykyWQN1mj(nbVp?e-5jYRYn=Uut{txpw`q2$lkWAw-pTefHNPWpcTR2H=M67kG?k7>kQ; zMzf01y~rEqqok$H9lgUW>HCFV%(KJ}E|%Is-Q{bLKJgWN>hFc88@hvT-7<^|?n#3j z+#t}UH|wF04m&Jjp}fJ9ZQEWBQJ<}0QkoG;4NhaHmst?=>jV3eeUxo~G>`R52!&6r z8ZzY{msx|NGG<=w#MW85keLrBv*#s!p!!27`)D%^HuW+mmtt4Iiw;F%wBRURt{X3r zYqzE8IfYD^@`Cy&E?^xpGssAfAx9?fp?#+AfNF=qtlqJVo-z7P4x}uFf;F*ZRYVz8 zw~3c6ZL**yiAb}jnuCY=BC%uP17TA>LT8euE$&6-KxcF$NtV3-fo4Io`*!i7=IWq_| z_^t^LFBr00aE~b4>}C76m&+2a{3-7;^PX_lP={Ug+eZ(7KfziSM?$}Qughx2F95&S z{b6Q9EP4AZgFYtriTw1-@J;y<8Fxt#?QjgSe>jA^R{uoDZWze5OdsVdsv0|EGWRxBgQ==U{T*RDeRq650Fb=Y^LeyU{HY1)({9nQU_NcS-V} z(XyUT=dg;xbCRFG`;r3hXxYs^4Kj_?9x}sEEwcaEp})Pp)GuUlb5&j+!hmXyZakDrxs5P_*o1Bi~%9BU1b!8sLjy?51#*4E|&0$Hp z2J^pgia$#033oOIQNLO%R$ycTE8kU+W#*c+b*3RVk3Yai#@``Dn;ys}*9AjKKp;G9 zxgq@L{QuS8TJqo2t^fVhfAjEn4bQ*yJ%7_*{G;RfN5}Jzj^`g8&p$eze{?+m=y?9o z@%*FX`A5g|kB;Xb9nU{Ho_};a|LA!B(eeDFI-Y-YJpbr;{{O4v zQ8||ncXjW=>N)+fXYY+*XBG&nwp^iix}-uE_gDx^DT9|bPoc&C0kmbOL-v4W&|k#` z%!~@*<;DZ>pFYIj`Y4_DACUE&b+UZ3O@id9*tflU;gDp7eP@>WGfwceES6pPVJc=tpC#o`k!<>>V|WOy^EfWtkbYEw~&SJr2^^>m{Vy z`L4|JgFbW(dO!;s{fMrMGkcmp8u~oUU=@LKU@6XgnQN60#g|8@{7)lDGP*-LwPq1D zA7#=*(n5RRT*TguEhiBl8`(Rr4)X2jPnH02-*+Tv&e~C;QWv{Sv}K*X-4b7 zqvsrmUfzc;kKV;19v+}p>33-8Mo)-5a1{2WhR*clW@=?0yZqHlyn>A zPY?L#(}iPaz@GMDSk)d0j-OW&xvX``7TRaUv;DGWQX@Sq zd)G9cXun$yN5y@<)JAlWslCnFD7D_~Qk*ICwJoFj{gc3b6VPrW?ufJh&g}k%-E7cp zC#L3_T7DO{2opQ^VCK9|V%uL8F5lCDNRvvMBp*btzE_mxCN+{LonO%DcVozc&+qAT zldf#v$8JLVm5)M8)D`Av_FB@-xDW=fs3kugc0kIU7^1&)G26<+VCBY9(0`T#vyz;G zdZiPvx+I@2?zIw(tNSpQ9gpb?mlinLb17tRbfZ7tZG&dBV0hTAo@$RcOVwRRgR+V` zR5$O3=gu$av%_-#YaNezkQ&pzw~Lmq=}RuJdJVfXYa#7U5+C&7In3YQmk*fTQs#7G zGEXn*%s2XcW_J@Q{1|fzw*4XGZGF1nZR*WDOk3Fe9-HH6-_`#brG%{t_26) zWwO-V-XJJRS=qo7;#|Bu*K>;^HwrCb?12<`8?S-sQx=Y!qgKKCE>40uk9RTai*-Es ztbn}|elUXx>HJVhk*vmXKixj)5!0KhNtZd~l82Ay(_)PZe_@na({3`d`&a<+HZT3JWK4zB;$h|T@@1@?;?c3EN% z*{q$v>4~4JtmmHo*n8-A@E_S7A1~?zY~d#AbDGO~iF1BRS9_5*-|jH-SvpOv2bg^L zF#D`w#P9r;!_)Ppc&_v)JKwDjIHad@nur?No%pS z?yAiFjWeBd;|ooGrw4;@E3^Fa98;}gS@8rjo+r7)nv*`0GXtK(jh9Qv!XZ}tl#hfB zzbwv)qb)e^^ya7LB;bw4I`Y1HA+I`|4^^6_uzo=woVHhm2Tb3A`^}2k11BZ^(;^iO zo)6;ZJDhm)V@k%{nt^i+b+}gi4>mTpn(3dIfxVnVnSRu9ws>(r{`x^QOmtd`zP*ib z>tzjCKm9CwP&tst_xOMVzHH_S->)&b{LMV>aaVTJzl6H?YleUmb*#d64?k$Aitk3A zX5;LR!Lsf5S(2oQdft+-b4Evjsm$RuGH*UEa18#D8-PapV=&?31X;05CfwHk0ZL%&A9*(ETMhSnahm*m8_7PpK7~HdGok3yB!2aL7(2G(H(7Hi zoLAgw0N>xUc#Pdl?(*;=e9|!ECoI%3vd<(mkX&OYj*rIbgO0doUnCAw$)=%Sig73Mjka41szMEkMi{=%=tX2NP>@j=Lr|krMNf^&nPfq6J zrbggf-knzsh-7*Fy7Eq!C!)1)9^DtS0J;Y}W~&Z`lsmdU0j4y|C>qfnh6>bZmF)&GK(TH9d8sGF>;(H+vK?xGfJ6L3k#4%{$Rmq&yl7Cn1H z4o**kRl!YkiE9iWFzXdN>g~V}Ui9P>roJFsHD5u^wIghA_cY>j?GA}gu4A_XCQ{R( zN^G2KJ<A;E_^Du?=H=T=(%_{#CHy_#mD?cM=PVIZm#a7YJ2p$AP`yfQ}6}xocf4wVZMW zz68}l=~Pj3+Bb&I6yx54QBPpFf+;5p40u!YCRS;ii2TJo9wn4eo6Na<;8Ht$r}vF4 z{}{ryNL;Z=HJp2m*2mncn({k}YtZJRH&;2e4IbU`W`BmwW|1Wp!YWNeHtp>svhhs} z-bsqWTE)*$H%}M$%I7eJ&F)-w;U3$6F9zQBY=Ep?!|~FO>3qQHqqv&if}c-k;F&d! zLSLQH%vIF~O>6S`_uisozrKUz{YZi4CHd@qn=)JWXsC6XVhGPWsmC>LJ!geQBk`7e z99gSezpf9JH#7`)^7diRv4HbiV#gR-y8i^pxvNPpy?n>{&GNc(|G=ts2a19$uzS*|%j; zeb>X5r|sm+%K-TQI{WgloYwCD=2EFd12UvkNE%S;+4p*=$W%f}G^sQX6d_S5O`=)S zpizbhMe4cNouNX5DRam%XIAFv_nhzLoa=qx^F80={XPHe>w5P3-1lC4ueGkd*S+q~ zI-A)@wN@UZYY&{KtDf$m+xO|=e5Yx&evKKl4i9ybs5GyINGNk09JGFTq&} z6;4f6m$PeJ$P4#$#clVV@}?}@g)b`iQg^zKZamn)8zL9SRo(4}56^5smHdzBHTky& zjiW9+x69=M)%p7^@`QD_Yg?i@tAoDOC*(2q*7!Eq%nEJLydFb)n?h0djSdQK2XGE& zyot-EI?{6{gLwDobIUDd$YISL!VR4JXz1Q}YBEn5m0cCkPuh~HG)NOEUnxu;W5cOh zoTm=QCFqnuIc`#fBG}!m<5k*LP(|@hF5s~gl~N%X);obzh2;y*JiS83{_rOH*4A9; zs2rZy&`D&u^nM=q5_yrYWXRZS1rWc;5@nY#Zt2wkK~`cBmzt(XZn_=k+~$7gNXuDt zdCdeNEmFk9Zh%{_9YA!b4*?Fy?Sa{klF zcE@aN8K2K--SZ~{uAUWi*=*ylt;6wxzn$-6xe$|H}jVcyJ22a_K5@aE>IF_7Pm2uuj{^QHIR_>V%gk z*71zQRd{jDw>Tp!SuXx{jrqq9x^&R4OnP#;91XHx%S9Jn{4<~MclnCjqud~Hw<<_! zEr7*i)nTB63=Die3?g&&gdAYX;A`Ly?#X&kWObZqPOyeG`-XvX?HCYGlLA4M%0Je9 zzRIiRUPpfubT!HG9o$^`h$et8SL9onBny~38ci@IDyqhKiA`9a| zGCvV}W^NZG{)k~c<$LLdYkE+>B!P^5mc+sbMWCp}WxDOl0j^t{q0TBbXnObogQv(c z+biyLoLo223z|(exS`BFs*f0-6~^q0WG<6lqIyb3EU$PFHt!iM3b&EsDKJj3P@d@3V?QiC|tB4rM;gP_j)APN$V}lQPtyX~K0NBTDGO zlO`CXkq?I3ir6}pQYI=@!7jQ?U}K$k)}Ox=%fjAUvRTI)nSQiCT=B|*VYN1}%;f_y zk`Q4wrP)mFo;ZI1D}};WB23Kt0I?rEu_1H86=+#0%HNq_#%B2QS=RVchz;9MeP}(^ z?+$})I!~xpLIcxwm_N6csFi%J&f$0m&^9@wv)YIDCCgQC_S!}#G0!! zr+(K-MEXG1Z4iJ{t8z)kd1W@aQ-pnf_kc>LE27%qTV#=C4GLy>uwehqtiZ61-b)l? zqB53f5ibijDmtmBLlM5&I0ikMpK%wJmZ9Y}Z(MjWS#Wc=586o$gVUolP?s8zj*2Rh zswl>ydhReW{Si<<(n`3ECrXeQqfSRS$~UT=5pHI8yH2(9?y~vXd8F*=RrY127?U30 zi7&;&;qy%qA&>hWHq1bY-_K$-W=N@^eS8DEDC8TyRTIEUD@9RbelLkt^n!)kv>IQo z7R8%2A-Fp5Hg_l`02Y{;!kBAzEV6YG*VZqc)f-=Dnv#|9dQ1SoqH5ND*o=J;52MZR z4w4c%vfCoCUb2?c>RFtue1M2~2Idb>OprE3lIi!%R=bN`e#9k4+J7Kd!c z_5Q8oyR##Ez223KA0mQFq(#8Ub2lyXmS=_YCd0b&Mc7?E7CU@hVArzaBuvwWOEY~( z+TRWe__%*i5ivg$!B+JSWNka13iqM8VIOA;kM0-n znw<>L6PxharDT>C-NLOixkdbLsIzt%G3@L3f~94NL~6PP6!xqk8$Yy?b<1~iH{Kh8 zujzOCRn{GOPh-(`h%%WZSBTYPcXA)h(&*!5mDFL;WRN|Qgy(D@;TDdC z93&jClVKu5Vbhn{>~5zm>Xsd%l80m&kITkuI@vT|$fcGV$8?aj#ZRc}o@-?OUvYPXi)Gc8HDN-2&`qKL`-Twq}y=_lc<|cu4w2%jRWx zB~|Y{neVFMaN>48y&@yS?4A2F(Vkc`?7Am@lDbJUZVC51y*vOXjMFG*E=Q(*eUAg~ zYciwF&xqPHTTXDh3iEe7BlD6Iaewh`k~z^2Ze6oreGzgjc+_s9SIMyd^EkX|QbQsF zQfPE#IJ1fv2SbH?iPxgES(d{oUf;W2M92n*4+>r*X<5jiqXgfZ^q`p5FjleJitID~ zM2D<8OSC0caH7-aLDcON4SGCZ5?C+6Jl_O^+LurGeP1y*QNa|>U5tjW9VZ%g+R<`WfiYWHT zVGKFP-eTE)Gq~1u3!8`NLc+{K(jer>$nSQ?>8D0Qps^-dR5p=F{ZMBy&WFj%2gm6v zYcEpwy@toTp~qa^<$+bk;JFfgcIV!0bf2dL>xXP1QbBWxtf&Edp*IDVq*ao0ikEqL ziI<6S)n;n7Igje-# zL)g9C6jCuY;+NwD_jMMmkhWuwY{Z~$i#~K&x^s;?CL?}agc^G&7k_jjGY)nG>7l0N z{2FasY;VJ!f2!v#I(3QSMhmJqMua`-yGNSsggMQD1L5fQDB}L|E|IIK;#PiY!y^y! z*-WXS?DXl~=pgJ{R6bn5D?PD|7-tob=A$nfPOW&ws>X6GSZg@W&rHEG!9>*Qi|5v4 zzY{b8AA_^asF>AkkT~?1B&Y45US+eXa76`MB4(3;qig9+hgWEIEtG4`9Zivg+8=Sj3SG7;O_fYk9nC^dccN&@ClYnys5xnhWogah z!RT=q@hX)d4le5C*%)WiCEtXv4%d=;>ou~CN%znF&w0O1ua3t(SL<14%CPxgYU?*J=Ko5Fi;$v zX3S-;EI*JB6&`G>>1lM8Rbwi5hLF42!dPOHhbFh0h5V62Y3wl(h_@X}%#Jj24iC~u zjnqam?W+wkX1~*Z;SA)NM<`db39)e~t}H74BQ^Mk6-0 zTZYMA$;5@#F(|5%PYTq}3cRmcvfQn`r0U>U*r^u5EnNG9wk2CYw2$!l`t^gg8IRzc z=oU6>?GHh8VJ_CHjlv1ZI&?$JEG#(@j}Bp{XkJ(%R(?D}9CHbYSNFmfA_ny2b9qoM zt0LNym5J#Of(p$-4yfGoq~5iFjxw|3Tt{z3r&h%N?l+vn$w8URMa)<_9oMzTk^?8^ftsW<+tS|` zh7TS~(=G^ez9mK2w9u8)g=<@6m~DHPvCT@i-xNwQP18W zm?Z4MQ7f^Hl$AZC#nP=*Yu`3fiYoNtewkn1w;TDjj}zazjA+i81AV7M@JRS-iyJeR zfTLm&Gc+qCiCXQP&AoJPQ6(dZxt@6Ka2fm5VZe%}jpZ9k_NN8cUs2idgdVuqN!8EO5BO$P?}smh^)!vz`E=h)%gCMoy=Ry z>}K|BY`r=ewM)~nvNQva_il#w507$6Q?9b$s0A!%)Fw9DO$Ya7%*F&?8MK&Z1|9i( zaOtW>=G^lg&Ry?M7O+5GQ&KbiP1x_ja?b^<`#ciPWcvchrn5Qn%IvkwRZ{!;mLS8q z4i#>e5sBixtbAAkq-YxwF>`lFS*1y@&uu*n#_KBkzHT01;kTx zp?~cYvc#Ul+FC!fQQF1{YLps}dGBE4+i~_#H;ovMZifM0b=*#moz(V?JBDO-vhX`5 z{EH<+U}njCSf;ZHJ@*Rv>)V99gej^xuPlpQYVt#ofH|!9VjkoQ8rh^NigZ}@e5x}| z6s-eGL3)Z>Bkj)PYz76xWYt(m**g)JNEveCTZR&D>Jp5JkpU^+0QzXy$i@_>csR4` zJ;(2{$2^yNSmAn(*|j*K*}jLQc1$-37IIoS2dU6A^K$UI_90s0GQ06P&l2KAL%_FI ziG+;13EV9mx<|qocLyr5C2cQ>kKau8FsdF#T%5pEh6&%BwPNYJi(2%XsRh1Ynhh>* z^&0u=ugJXGO76_MEC^lnfR5}vZBgFkDR^``h?v|JVUE)J0-71e7VKON8`D0~uGv4h z<7>9l%#?N(_U#K-KdJ-Wz76J^%~=LDhWEI%^Zwk63Oy2cON$pnwIK4Y1j}*e_=&ml zjmigha<%t$xKC={FM-8&<-@oxScw%}v}CDw;z-;D zMOb#s6qnVS)9D%psCLzoGPP4%{3-6nSK#fZtQB_{Xtawi3u+ejp6_5vl*8$99!EfwlrZd~V zEs@n)UE|d6B(bI%EhtQigQV@UjZRVs1zDLmZgmD5QPII>SzloD!qnNZ1%tS?)#h+O z{{`8npT!z3RI(23E-q7A4ic6KnB#+Ts1ZGYtD9PIiQ-sxc1kAtre|{>w)m2s{gPzD z?M*mf=04~QKZ&KkDIW!v`|!zAM=!1tR*Mp*knn_4cW-(#&0gwxlg6^*2t9T*BJ)=VFGakRL`vxHjw|!o0qwS`6(-WXC#`KtkjZkXE#=y#{UASZJRovEYiP;e&>GnDYHY>>h9<4r#C5cPf zWnU$Hx+xa6-QUXk4rxIDiKPO~)1~aCc_rJ@(SqNz)7VIwApGVY%vO!HvGy`c3yfMq^l2S2^VeY?Y+qm9u|f>goHdtO@bKD#DZ(?zmg?M zWP^`z{rdWec6L~P0(}25h3T7r;%$7=Kt#9Klac!;8UXgp_|+d zQ^zMZ-|5Qq2{1}W0t&xN!RaTPG5LWaloVuvYltaSdM&_bvfJ4$u_3JW-3Vr|@fBvS ze~z47I&Li<&(5g1!Ts!5oN*(VWO~*?)|^9RszWU=uHFcgGb(sfmqdf~7a@0VRW_#_ zU4!S|XtK-!hqz#weDZpM4H@mM3D>H}k%9U5*}8tEpzqOwn#C(vQtM7?c-xBTnuq*Z zZ}N9>?wssdOvqA%)+FJWm(?4$hUc(DZo6^GI#Y(H7&|XNk}mKH1pV%Mc5sLSyvXe% zom%}tX1$QBzh>bt$05&y?t-<+#3+YN(a)67Q%MjG}s!q6TBUg*+7#`aCV$6s!19*DmZ4dMSfz9dm4(!h%|Sw z+BBQ_^(C<>tJ8=w-ONnpFqYKlz|6+wLgDbC%(3V^o1$HZW(S11*V_Otqcz!@OIz^n z`00#j1hP%?^@9AiNM`7Bi(UPq4Q_dz%(_&RpI7>k@CBRMPD_9C>S-J)cz+h9w=Cke z3_S=<#!rc-NennQ?}CNRk2&7ZVCFd82=9#cA$sR)X?*S?)RbI9G`$tTzu1hOSMtNj zYetZk6^Ud~t~vC)e?Vr=-AX6n_J<9xM`2D)4rlTJnb(C*qBvg@LJn8) z^fOQKJnqe;uS=&BzweW{rdLn7(szd3$(&gz|5P6@dq%QT>vn;m%K_mxfEKei5GQT+ z(M<9F393}62JJ)UgRSEjylEeXlh(f>hQ?P=NG%WRaznVLx=yUM)n?YK`m^OZH^`Zi z(Zcw&hVkE)v1!f2;avS--d?FBR_ZW+Oua%3 zq63&}W*M*nBChv9b(?fvZ4JQ_$Q)Ophf8_nKjs&DV=)9 zx<}80db5)(^iv0ki+BWKHFwGKk8*rlz7yN0AP?JC+-CuITY!H+3+ncMX5Us7!PieK zIpwYrXzi72oP5!i)jC`S%Yo9McJvHNBs(H`R7W?w9l_RK5axz9?qaP+x3RCQ+EMOB zJ{P)oB+gHG%G3U`n#CyP6XzcqByqeqdz{;a9-rUfh-hmljjV>A+(lgPiQ_2uKo>T> zS4Iu#7;?(}KB=Vn+>mK;yj9KeY?t9uCaJ#+;~vfif9Ls(NUAdh$pOr~^FDd6B@35+ z^pcmNJE{M`VzTw34lL_AMB-n>As68bF57po7Yib(Qkk%CvVs9DR9?+e*G?c6!v4eu zlJjA}M`tqin>C{MM&|4y%`S)+V($c1GEU()w5@!FlTA8dpkhBzQV@dy7kjAW?IA2G zL4rBYnFy!O2;cvwi$L!@U)o}Ak2>$tNa5&C>Y=62tvsrVmzzJ3#as8VSGqHxDDpLW z2IgR1sSg;Ci4FO^hso<1(O|H-2aoV?(9D1(__>3yu{wjef2zN=1u?#@xHr?3%V#kI z3sA$K@>b1griCRXq^9@<8R|5Yi3W~?{;Mq6*TNYjQmPFNkKIGn4do>1P}M)y_pb3% zfbyD6;QR6oL^c%heWNS*AC=4bn-(77k8LdBZyF;7*~**2>U1LDx5NAn*u{6AwTqw7 zvxmQBbSeLzu0JyN6ke*5q^bdhB=uoDRT&Y*&G_Mj55FCxO#xRaIVh~Z8P!f_kM7{6 z7d6to$M?{hF`^cetc5w3^K1W&=jD&S=Ks1*#+Juo=l&EdJ--9VCOZabSNSK>+S zZXV&`-ofF1A;AX!bIvU|WMzPNdm%9lS`+2SUzYg*X z{)y51JIwF;Enof@opgU){&ziw|DCvS=;h0e{^mjad6R!iE%WmT`@csr{-u$O4SpX< z=RdA#xnHnHpj+t55Kr&0u)iog`d6iYSN9K#^78&?i~2+6Pi6lLi~29I|7uZxf&Pj` z{Q~Gui~7Ur|AP5lKcTDr9}Z*u^IrdV(BHND!}R}x{Pzj}&0WU+-E4ostoUhNzm5p_ zuMw+%BK~m5zi9Fg4*3i6|LTx`5&vHu^6#Yn)gk|Png4Xizmxh`hx|k4*BtU6rGHm< Oz|Uv=Z_)qvIsYH33b%#; diff --git a/test/data/model_2.1.x/version b/test/data/model_2.1.x/version deleted file mode 100644 index 91dbb1711..000000000 --- a/test/data/model_2.1.x/version +++ /dev/null @@ -1 +0,0 @@ -2.1.16 \ No newline at end of file diff --git a/test/data/model_2.1.x/README.md b/test/data/model_2.2.x/README.md similarity index 100% rename from test/data/model_2.1.x/README.md rename to test/data/model_2.2.x/README.md diff --git a/test/data/model_2.1.x/config b/test/data/model_2.2.x/config similarity index 90% rename from test/data/model_2.1.x/config rename to test/data/model_2.2.x/config index 2d644b413..2eb204ec1 100644 --- a/test/data/model_2.1.x/config +++ b/test/data/model_2.2.x/config @@ -3,9 +3,9 @@ config_data: !DataConfig data_statistics: !DataStatistics average_len_target_per_bucket: - null - - 13.44736842105263 - - 20.568421052631578 - - 28.053672316384176 + - 13.418518518518514 + - 20.319148936170205 + - 28.016949152542363 - null - null - null @@ -80,8 +80,8 @@ config_data: !DataConfig num_sents: 1000 num_sents_per_bucket: - 0 - - 266 - - 380 + - 270 + - 376 - 354 - 0 - 0 @@ -91,8 +91,8 @@ config_data: !DataConfig - 0 - 0 - 0 - num_tokens_source: 21324 - num_tokens_target: 21324 + num_tokens_source: 21181 + num_tokens_target: 21181 num_unks_source: 0 num_unks_target: 0 size_vocab_source: 15 @@ -103,6 +103,7 @@ config_data: !DataConfig config_decoder: !TransformerConfig act_type: relu attention_heads: 2 + decoder_type: transformer depth_key_value: 16 dropout_act: 0.1 dropout_attention: 0.1 @@ -134,6 +135,7 @@ config_embed_target: !EmbeddingConfig config_encoder: !TransformerConfig act_type: relu attention_heads: 2 + decoder_type: transformer depth_key_value: 0 dropout_act: 0.1 dropout_attention: 0.1 @@ -150,7 +152,7 @@ config_encoder: !TransformerConfig use_lhuc: false config_length_task: null dtype: float32 -intgemm_custom_lib: /Volumes/CaseSensitive/Projects/CoreMT/sockeye-github/sockeye/libintgemm.so +intgemm_custom_lib: /workspace/sockeye/sockeye/libintgemm.so lhuc: false vocab_source_size: 15 vocab_target_size: 15 diff --git a/test/data/model_2.1.x/model_input b/test/data/model_2.2.x/model_input similarity index 100% rename from test/data/model_2.1.x/model_input rename to test/data/model_2.2.x/model_input diff --git a/test/data/model_2.2.x/params.best b/test/data/model_2.2.x/params.best new file mode 100644 index 0000000000000000000000000000000000000000..6ab4f14fd5e61bb39542eb254cc876795a6f3fd7 GIT binary patch literal 33554 zcmeF3c{Env-}jBFh%|^0NrhB~G`P;*A4wWS6iI|oLdML~MKVXEB100&&_E<{?Y%W< z9;h^^G)sf#N#k?wwSK>6-S=8g_wUc=pRcvfI;PjTuIrq=Kl?Mi_Xu&Z|M4~cUq7Ay zo;UDc-%9X@!oPku{9ot)j?CL_3SWi|fg8&^$#~U6WL4!BQG>@Mq%lJgxxF8cR`=c} zk81Vk>eGqDCLGfbv%RUu6c0LiZ8V*;W*Ie;3BV_X+wsux5p+_|8#4NwGl>?C5c~+( zWm;lzKrnur7%OdfO{Gh3l2$1<8nS&ey#<@-D%Gv@T6QRP=8));Y`e(0poVj#CbYUz zhVDmmu}Eq*w%7lOq|OA=Zv%1k?2v8L!GiQejxYC#;?NQc@w2gwhXOAVyBlT4C~2ZgEjJNencl zipxvsS;yVf`m!}$8(>Pe)#!Om)hHVRklxr?o`gZ|s;0WyIeT3kmnVHZu4KHeJh#+!^?W=zZ9Os6--%;B0`^HHYPN;F=6 z47olcLojI07Ln_XU1a^#0$S#4Oz$RC(T>|IsMVDcx=K+XD*Wt3UbxMn4izeNqc8=n zxWSAn)1SuK^icUi!3@=nL@>*M z-kq06>^-iMi;H5>>*0sV16Gf0Z>I@%3EyxUdrymmYqIIJ+tX>^O+)&Lq<=1* zG`NzEIW(N+!28ydeJUjB*HpS^$9Cc|&_P@eu0n+`hMEo6<>JvDVp7>!wTlUve!@6UFr9Df@x0eF?;p;1r$EI1t`uceycdHqxy;;hQP*mjF z7Rt2by{#b$J^D1`l?FY(w}AQtmD1TqDyZefgtuXM%3w`;*tUeK>dvPhlMAV-uLd2}sV2DSv54etm_3t`wVAspq||9THs$z0Z>9tyYCoW*I=#PQvW6VPRDDYtS& zlECF+rKy3#WL7_|5bwX>2R@(TVEaWOocu0?XrBGZAH!qAn2`TW77?HEIBq=t+_VHG zCMR+|H_HWv_^PQsmy2(N=!(p$-68QxET|mfU-PRFrt$16{@9|hm0CxJfX|lU5U4p9 zcNT{s%MC@GQ_*RGbp1Ee=F4Fa2o=<5(QJ!Q9%BfL2IZAYq&HIUbaO+qMHytycNP)p1;NOJ1pZM zZqP{BS1}8gO-jIaHRULK_Em1=2uaTV;zYrhU!!Pmt{ro8HUf$62v|2w2!me<;W>W~ zFYvsP_;vhJV+U4GTL{xmZpAR78NF41!I_4RcdZKf}n3>JUHJL!VvzM74olt+hquwl&i|r3j;vEn&1Od zJ|p$DswlY7g6pe`6m$$)2@XrG*yohv_|$oSII0s5*E)o-@P`mmd0tEN7G5%HIcuv8 z2eZOUxPwYzpA|aD@25RC@o0)5&}Kc9+!@Y_KBwVRwZ0&BD-O!e2_c*xLm@mLQV<76 zs=`>HWh{JNaTj|%mB*F~OwbC~wcHKGT){0VIjk@&hxW)$!}WIFa49|x&NCrw9}oin z*})=e5`;YzvMy&KB;9?1mkv_F2WOch~vyuzynVZS=Os*5me7uD&tyYJQ zx03YD6K7boG!~YZ2_fx^5afBj!OV=cSIM(cHc`+rFa%PU&%&preNaP49+$qkMX-6O zF1*UIhnoHCDckD^)9%N>y&XdE_$-8X{22C$Qo(b-WYFVgArN$YIGB3P#XmlVq4@G5 zZpNe20;U=VGG5^@Go_Nc=q-mbuNa`|LeS^uP$$oyYw`hsw=Vjd9nyN&f&0%2;cs?m z`1BPJJ+lEX{^0^cOFxC6$|Q_*Ls1Kj4qJA$9HdEJ|pC`W+KD}7L&VgPpQH{z&;MW}Px()_6tFs=VzzNP_HT>M4G6;w~ruwyxB^FPFySmME_-bd?%S{7^wGmLFC4}x5LO9R!@?D{D#lsTPJG*htH!~QNn}#*_ z9YS+-?sK|NWVzXSJ80>p@ys&Vn{Aq{1NFzl!9tOL?|nij=6Suw47@=w6mpYlahRbc z#KmXfD=J6Pj+M{22!4%s^bcqIBOO_+NidTiq6MuoJpU&iY+*{cRl^)N_RzDH9rNtA{H0*+Zp4G5#HV4q3nW#kmw~a@qR6u)Eiv zEjue@u4(e1b!j84YUkI%hx|R{ZJ@Q@FsQyYfwiavlCi^_phvz8AA52I9hH+r`z}o6 zYPLE+V}k{Q93gutk^yD$jX>+-VPlUFDtUemZ^!DNtYk`Q_xStA^IwByAj&_IExNghnj5WwYqSP$ zZNH0-ak6OEWCQNGz==(s6+x8}BG?GwN9?vT7;=;2L9?5;S3G~x^e$a)FJ$+2HPDA| zyudo=Fy3eK7}eS|H982y0^{NSEq?6tHhlK8I5y

BoVKdK{l6?G@O0(v@YHY`e09gEjVAoBbk#(^uvU+I29TkScg@vo&UhHD#-`b47 ziu_^Sc-xc7NY2lmvp5;czCQJXn#m)XPHhGbIOhxR@5DhrzXrzhV@IdVm(6}S z8S*vU*t~>D_}<=uw-(A^_%Q`7-sQwSyl{zn&Yuerl0T@~;Baiz=EIMXI50WGKRcc` zd~45M?01IlbYHe~Su~6Z>ca9JL-E5A1}NpME7zwg4|*|@kY#e2&Wu@sf7GvsVdZhK z;gk^G^J{3CWG**ue-s!MZ)6Le#sC@Bjg6$`ahTo=bmFro*UGOmZ;L#zXjKWdI5rbs zDe;EADRHo^O$b#y-*LkUte(b$FGRD`CUJ1Np$8AlQN(}d&PLxS_;3j^gP_=%lDplG zH2bs?ZcJJSxju1FMTD@L=L_$ygTQWINcV_m!Iu1Uv3`Qr1dPIQP8>2?7Qp#i-lv79 zF48%5&q(}=essihEp%JO!NC?GDDr&Y$Y01J1>oCvAseJ5gr*Q~K zwB+CjxLg*`{^mz(Z=V7iiaV&Ggg>+9$As&L0h}42jg79@qFrj6xz1k*qR#uUB`GKH zgPT*x^bKy1bubn#)bKt$@5AiLi-ol}r!XZw54P7v2qV3J<1Kq8V~=O{sIDP}lYc&m zK0kq&-=nuUW!p}&PS+JaX2yc=etr$+`NpOTc%;26i&1c3B|G9_udg_0T+_ozKUbnN zi_^Isqx{&p&R;b8)i6*Td5iQvbb$t+SomHpgiE{+bXiV;`L##0XV>R5|DrhX+bs!m zC8y#1k*m?&N15C{YkT(WxiuBUXu*O*vh>n6XSl-qicx!b-;w7(y%jRYmGNxx?ip;E zdn|Nak%kY(hPYU74SE}y&FRlO%DKk(ArC(jn16T*m4j7aF(DS77Yktw&%17xW@<;) zv--mEtmJ7lT>Bvd)4gWkZFAQmiQ&22-VLiDEwBs6W?F*nW-HqCY$c5P69X2z_*jmg z8^3HH;FaG5tYtus9ny({r*6~cGkhnpup0W6c}(Nh~3{rNr=4*JW3@&|-fQi9Q}t^zJ;?R;jrI0j4u z{b7FBcADC35BsxY;9Hgu3V43P@M0R9@(cYf4qMT=k-e0Ohrh*P!sQX{BG-+trkT>L zwm=BpKLQroSm5Y}DC8iyo9mspnYwNFhL~e9kS%PWp*QRx#w-TZQusLVBOi;IBF2tS zhqHP$B;eu(Xr`lJPs{>5;zAtqDBQz+_0@;Hx~AZvE`*lclhkpQEd(gX!01gvSkLoT zZd35%=k8!b(z*IwK45rF3BsxtW9xSbD8R6sYxazVfZR`5#6N50kWTuv+Xh5$qv2N^ zKlk`p`bLT$n4OCOr}{TYM}Hk$exwX%Zdl{AADdA0xk@gn>>RdMpMzcJ#={-8hjh{k zYgl(C8g4}KF$mAg)-H$Cx;XG%=)&7J4-oc`g&Z+E+%hN`ZFZ~TzW-EUUGs*}ed5u; zWxt`5E?I$JaWoW#@wT7muO!|PSp}^Fr%k1JewiznepZ8yDfalOVk*k|Qo}v#6SC1# z4`{IYM#w(-lfK~=!{y*;unp$#2XAkjC7sy~huILf@G8EYyb9iZ8wV*X9I>=|I@*zS zkjvR3#RL_GOlq_*C^<{B^s0rBVHpiyHV9!c&*%57WfPyj#r{8k;#i{G4R97ar)$w3mIj(TZNREjsEP(Q{(U7rT2zPiN+k+QYCLSChmbqNFg)O}!*qWgEp!_BZh_) zZ!%sBjF-i+zs2!(ksD!N`xAPs#1u3>Er9Jklc4>FC!V$_7gfJF#+lCEP5b7rWgGEy zD7&k{`$FcRoEHU8RtX`7&jU>OoQUtF8?(--R}7{Mdq zG&qHS>Vi;<$HmP*lKj*?gAF@+l03t zkw9nj>p#znJf2Y1b1Jly_Xj@asX)^m6KK(m!-q@DP=47ZZk6j=c1k%8S~hTGXy+1U zGJOW*42y(9V?GAuW3afuc=q(54*pvn=;(;Yge1hm-||3?X5-np8U46naU`r=E(ec& z=fK8eiMX$^5+z^lhGJ>O(2We99DmHB2H24z~0T(Cmxg$OX6!k{H z^7geL`m_$#4*G>hyx^em_GbL~d^HOAb(0H86=JUwkFfv!dsHmjl~tYAhfF~Pc#RW+ zKhL`y5r9%m9Q;1B3g%sZgWv4IaD9CW_PbGw+_&H1q?M+zZPp5S>^muz9PGhv_vk^W zbOiJ(3*iwTLr$@tLH7=eh5bHD;Frfk{BpS&6n;y^H+v2uyD9fLHL+NBT-}(4s*hzO zRM)YpPdd1PGmJr3qha^C20U@;e7Ih@ z1z-EsfK=Z-;ree1VQBw4cCmaZyMH8*tx%f)=O%@Nj|A^K@iFkGsV<;?RSG{h7lY;L z6?p0!OPD?*8_N%FLW5F!xi5th5P36>>1M5FXZ=Fhr19e+y*~_g|BMGYp8tB!3-+ZU z`nNn{pU*Q|nis?VmPh=da0e~V_(Jcwmt%*VWc*5PFs#&4 z9+!l$D_RjOc$_-uM1{f4ck$4}$KYc!7qIE=BZ>_kqY8s?7~FXo4;DPX`(y|UkYCTfJo=3|^(FZ0Ol$byvmK`z zP*nc-9VeBt7oRe|O=BeD*`0gQY@e(OSV)AyMLuSze!~0T6JltuUOemie1TJ0q>smo zY#{i{4xDP*hEk$Fa&gllVQ}Fm^7xgIwL8T!r=Lpju$}jFdieXt^XTO?I6i{EKYhE% zy@hgEjI)KUoAPnA#Ywb&_!q9~wJCVMS%42(3YqfGI5z#oD7c>%3h}q&A>c8eL-|)d z@V{z*|F1fr;@fL@$*XaYgXcr;x;o4rF2ZkaN@CAeLvBb{B=SFZ1i!T(3YHF6ae>wX zy!*m9eEqpB+V=1;lFIzU{io}PR_2;I-%rB}>ynU^d@?6G_)xTR->5l(79!NqIUg(i zN=VoIV7L#SZ1VI6lI%0#`|8*tT(FULI( z=WsKGDfrrt{nW)S6&=|!7+d!S5tXY$;Lr*KHpKA(9X;kJ6*QeD_tR`qR^fY*_`bZ> z01CL!&H@FvT^1-CD3Dt(1SodW6Fl?KDAcdL*+)h7MFp5q*ipiYA{2u8^N|wvY>6d&o9pO>)2BBqw$2CUGl` z5gfMvMfO=tCDIaxbiv&2R$141#AtD>D8OG1TL&7FQ^iUoc$pMEf7p#XAibTdTB?fr zTZW_DAVAHzz!}X;CeNhW@dKI1TwQu5+NSD9?w!yhijsp_`jjHBD?|yEMe2d`;C&*` zd-GfM&+QduOr3-`=Z++@MOOtu+OqPMDCMy)@&4XQ9^bL#>eSB9X%Boa(&^R~#c!TS zHm|QDrV*N==&KI|XFf&}^xcxgu1^FMRV22+gA_j6N*q5w5G@Ehh}H~R zPqLPmlj~P2h(y#pqTcs}EXcAJ1?~MP8Zqn)LGzaU4?AJ-uWIN2%TBxrtffc&N8=zk zj|~m;(MeGfRy1qJrPe1z&MHNevs?$`Ps_9wT8iVr8X_MS5G)rp`tbeixQR`Hys5*+>SGWsK4pF!%UK%SNS}JNaH(>1t(wVCo zqA8zJAmY$?7G57hz8zagb6gAXy%H-beRCD{cybh`@3@LrmeC{S1MygS$12viU@d#AF##T=)I##D-OTC2Ha48}(K{V-L?f#Jb*LAh_RJmh_^5Lt zo8g*m+m-gPgU3%&JMS%&7zSe%%}zYkK#%EcwlXM&#SrdoP_>YIf zsT1gSw|(fx&ul!fxq=A3^pQnw=`2h7Hzc;pG9BYq+?|q(SLFD?_@2S^a4Eybs+Yn~ zZ!PL-@{F$OPp7Z09-)b@ru5?MI#g)thYu((0-L4HZ06_*Y{_0-wpHAMt<9#Q^xihS z>*XOl#6jIPY!KfT7(~K`r#=O9b~X7}xCDph_28|U3FzykXT;H0ljN=vm+~~LD=p+xL9xusZ9jb zZQ6j&R;aLB5mGR$xQ4}@(q^gehoPBLBgu_vnz)s_N5kdS+3okw(B!xxaP-~=1IPe| zHoH)TH$_}pJEe2ScH@JU26WE-UY2=2lg8@*!rux8XhGW@(%5?uja!~UV?Ru%Psh}- zSsTj1YN{_>wo)NmH)cbvTM$uNF4jNn`3v3L@m zd?y33gu}4tfELcskYPCw$5WM_>C{H$Ij%D7#7_@qk%8|eq{i|ldQel3o0=QYN~=@c zq-AdOX-2y6C~ldi|Boc7B-9hNWQ4#qs|yRs?<|T8g~WKuDN;maB2}T?RZX_7tBS0 z@hv1&y@snlv6Cz;Hl{N5$Fa}#WL*1G8n1pcz$HXtTy^X^3EHDYJ||5f2S<$|?qQf{ z-AE+9*$=s|XOD`UdN-rwxJLA1Z9e*uI3MS(bHW<}ry%juja*wq3ioW*7v$;?i;ULG z;+!qU==9I8qH#y|lTW7#xp2)uVz5k`E_%>R9OlQ8%4YRN}rB$q`oZ zeL;r!%kEZe*3aR$+3~c-vPkqyJr)ZTv~Z|IrYL`|HF=^bjc%U3Mx{E_QBs^PdA-^Q zX*3m~;(>mytFIc(j*}sU2Up`KhxbygNn1o)f(+5zd@(#!yhW7g^;KjX{)FgG$YKpn zi`nj*;w)A@nD%IuA@A5>T(PGd&VMu+U#V$^z#W$%eI-UqRgPh;k0-cI7j@~xL!%(^ z<}!4?qm&((e1JLS9ih%Qq)_wrk2o*L82{EhLCBC}*dWXr{ZO}JJx!)~)3|i3yJ;KQ zrD=y3KaL>76uuE1bA4ns#WI&ZkqFZsEwPCS29kj|Yvoh1VYuZyVR0!KB~5 zz#Gr1(a%R#;Ki>Mm{|1>tQ&t188wKpchnB_S_$dhWzIqhrm{t&CQuLc!*oJp6vcD5 z!ZB%Y$Wt!@XLoBn*QEj9V2WxwjKn&SJ!?81)V&=a{;-g`EqI6?UHuKyM($%$ zhYlm5;sp9+@DX|@PZBqOu|a2a)M-W45t4OX4-!n|>B3QI=tb!-DkW6IH%163H{~~( zvA2~i4>D(#rw1{?pPRH;)|?*hdqU4_a6q-Y)aVMKf!S2|%b@glE(HD9jBB0qIE(dT z$=wW5>+*_cB*1nsWQQ%F-vaN5TG1;}gi0XV_VuXfUcDLpxG{!moGzqS&+bB9H+PZ` z*G)w~zHTPF2RW1B^VK+21%0I9{T|&?or{ZqjK)4z31~rv3pub(BJTbE&U5U+PWpExU*=$aB881XvR!X3E31@z|Vx-Cd&J-cIQC z+HT@9&Z#wG<|?iNy%HqM{ev1eTac(Sajf1fD|lAdEqM1m0rlHS5{V#jGQO-@G*s-d zD1oo}h*hSxCf3{(rLQtUh7w)m{zMQ3_{`__7$+jzBqRLEG#M@IH{~8NC*-&}3k_Oy zQFL?a1#-JV0j(5mB)@_r$=S>}NK`+}!U^5FcukNM&a-@mpJvtyd?u=i#D9z@dh5f{-OC@) zejQa}k=KWw5CI9L6IxooD-!Wp;pETuIFVJM75W&ejAQ)W(azmFNNwmE(oobSdgOWt zy(+Kdl(by1^0z}|)R}s;?&D={&w~!p1nU%$*C|_0@8>e&Y9w)S?wj!eiCr)~FdVwu#6W9FFCI7dd#hu^3&9`Me_Vslm;9&ek5TKS zn+A!&%;F8eI?v;cH-D2$y)JP1{d0jZCxgy9B*7lJXw$sht?c>9*)+d#j6gy1Uh6fh zm&Y4O3j6o<|5H9Bzj+XTFf^acbKik&hilOJ-FzK9YY&#nGzl)%fel@zgp0CAOXGh^{Pti7hGtJLzj6clQF+ zrmcjA`Z2z5{VPtoHj^$tl8sGsW#Rsl-FQaueq0)S6@O}Vz;{aI1)kwSqCGM-xSI@u zjK}d{fAKszK6f{D8)uJ{GP`h8mNPxv5rW%7kK^X6{b*4^Bo;5Zi{sy_;b(JSa4oTW z=$Ofyu%faBKH%9xr4!WHv0!?57rO3dNj(-d(4}UN(EN{&1gUn4?2OkbtXJYi%HF?h zZFpHt<+oqPt)a)ru;3ze@~9b3u4+b;lVoUUZX3R}T93%q>eC6I$05D)?X)7Gj9Lx7 zjUBEj;xV~P$a2lC^yx5poYXOrMMi~^Nf9Gh_of51F4YL^_GsWRmp$bDy{UqhJ$86$ z^(C^q_&a%f?*ej^>c%Ha4{>zl5#(&l*Wb^srG1A=skG$@I(H_cL(f0OO+6`C<&QJH zN$t4de@c+@$fqQImJSsgwhSAN_QRt(m(hqhMR?w-JRDYao=nXMqkZ`kXrTt+>*J4M z^Q?Wy`oac!QB#72q=(S=s)vwUyE~0e%*V~2#^cTV6R2XsW@=-%2wRwEk>?7Ih+AAs zk1YH^E)>4Q3$Engt)p5={eV8M%rVB6xs%8)kE@g%>%j6sdAM0)g=5hm;p_f$p~0-CYJxlFgz-xCv{_*^57#r_*;8QuK@Ybdr(ulUurO zC)FMvfTujy6*(@IAt&93;ij|#tZsjczT|6YCr1orhxi_p_K4B{G4E^QV!<*~<=H-$ zsqDnoU}{@^jXK{pVPEW`+4s*ebTMwGYG1VJ#bLwQ^*=}HKK*A@M&F8lEO<&4`_=wO z|M|c3EdM_L_q^Yg~3=a;seqFxNeMNQ#aZF4|a(8jo+q zw)?%Xf71#)w(bp1iP}zXomIw<=2dX&-xuQyy=Xi&w1Kl%ZOWl$RV*posg{CR-~51B~qR}ZJ%{d;iz(p0MB zKY&m47*H9vWc)xuNI&q}0@=mCNjB}q>k|8^PRB&5YJHVfCQqQUUFtaTYXbi1rQW(; zh2hPjo2^$qPNXSQ-;j!7Riepzn`vt47<^;XA$+<+henA%C#BB4ROxyiJ-y_T==_O0 z)bQkQ(*Mg}y|-Z!q%I980%sAI7xPq~ZRCLcHtxF&ea2i#n_l(kK^G z*4H*dbo#qH*1Mw$t4yborr(G0JGH5JVB1Fo(}&=9e$#OLEOC^0C5vt@bD~gPgtH%% zpdtBj)a8)}hL|VZt38gK%ITlDY5WSZ(pZ;{8+sg9l%K`9qnq%iA*J}sr-!1DBN|Nd zVhG;lIhd;S=i#FzWAWfx30Uu|CAxmF82gED!aB3FXvJtF`g_7cT(v=+Iy)StGj@oh zfnBeVgSH8_94bpa7Xh_#u)sr1v~dd`Yc1*A#X0!a(Piz~v`YGsXj?18YkLn93Eyh6 zQ=t)SFL9vX)jL~>l{fZGoQKbz9iWYTu4zc~S?nFyz|HwGmE;**5Y66Kj0Q_SBnp`? zk*97EEibJ>HU1~4?2JBgrffaB+kTWNI=@7wOU2+oYC2ueTaWv*^|^U&aWrO6D)l;I zN^86WD1I#gb-t_c4fogR-h_R)YRebIr(gc$I)gPa7>qdXp zX?OH zj#ZQ4z1o!ARwU0%HR#FX*F}!;p>!iv!z!uc=+vVRP{Td}zBw`F|JsSa`CIWzuZvE& zcZ)7oI-*h=d4dX4Mc01H3A*omYGt;s1nO}kP}W)l;-%UkaOC?fE}Ca?HTUM5Cd?1u zayO>@$NH1g=tM{;4u{jK^Pz%O!*{8}@F)5(tW`b&8%7*~<+VFN-yxIli;IBqLu%p5 ztKB&L$bBU|x*_vhUIm%R1A{`9|`{;uKqFMZG7^cVl=c>dAx z{G;RfN5}Jzj^`g8&p$eze{?+m=y?9o@%*FX`A5g|kB;Xb9nU{Ho_};a|LA!B(eeDF zI-dXkI-dX3_i@{NCkoDcjxsOq#<^h=1@8MAQO3ShQJ&s65sh0; z`jR%$#V?l8XN&5&2KQjH_TzWr*5buA#_mH_PD)g$nZpGvHEgw6xRstTu0{_E*OJ>3 za;?(Mx>RoTaQemH7tb>PiQ7x1ps_H6w5ukQ8CzQg=Uj&(n^XfV5^qBLHd|2FppT-k zt!bh~Z@&?7Z+GH8WiSpe8ALq=-$_eSGF`d*7Tr%WiB0(eyn^q|R)2edhIY1~fenX= z{9<2})!9hu1*NDXV?QZeIE?POCx_4fsG-wlhO`cZx8v}sH&DLkBKqLLWh@r<3~8B9 zMe>~^P{{cgNUl8un~sm-=7m2Nm~KBtURGP8Jul0!_c9$i;id{&;J6#VRXa&qHZ7%b z%n*BD-Np9{q~gmuE9sYEW~d}Hla8=2r%P|`K?}xzC$j^h$js_aQCqwRoqpjDb~yhM z$xCVTy-t@%-k)vgWzK3;GBK8#zPQD~#ZXQ>Ba_|;`aqS^PKlN+(x#uEEyZQk^U!ps z6w*3uC%)yGLcf?8BKLa|MCY}L8wjZ<>bE4R*gz@vj&Q(wa1alP^rl~a7U62ENWo6= z4K%rJGH&fH#?>!#u;Hl3H0fO`KDFtC=uA@p*VmDU$48&S-z*3F$F8JHmYv7rN5|sn3j5Kk z2z{FT;UvBJdn{R5!S}sYz7}b(dx(new20WzzRkuK){1UL>z~}gXJkqwnBzR zj}62TZ+_6{Mw4Lt!-Z_Yttyz2d6kA2IAB-eK+lLJv&*Z!*|wpUaB$!_c`zdhHl5Lf z-P4Q#+x+JHP-+u1l59k_U_1b@)#Wko}!QDs3S->>cq&+Q_?=DIz; z8@V5!-R+8tOT}1Z_ZA4(4#$zF*3qkquDHCY3EJakf`-%`EVW1q&pthtjTbqfk^bed zpN&|Y$YXQ$<=J(`ddhD>fP=^1!MD6Z;bFlMYS5d&<{ijkru!pVV^uwz+iMB$Yv$1B zcC#R6r6e=c(!>U%Pq0bvgJAN=b{urq954Jbh|RozmL1m9!}q&%;Av|njyyYw-QBqf zzPEkE?Q(jodUh3j65C5sp6_M@z7dclz6{&h_(N9l1M0lAo!4LFQ-`i1>gRrou28OK zk{;r8*|N_ubNroy;U^<-74628pudt$wP6qGJb zW@j~b;~lkwMefbp=z!Q7N?(iPZBzFUNA)B4R74xh=jExx&j{JnkyY6F`V$(~FdS|K z9AFb#{XnC(md0C3g05el;MFlt_z3Um!9iznQ<5RmO5mUKR|mGHZyuajl}$|ttHX^y zjyT8qCDt!@1RL)xI$v@tds=sqS#6s}C5~*tCprVLW>Y48fU3Y>;yQj&zl(k?(WHLf zUBqD4aJqKca@-daN7S;Wg22E2G&f zCxVA#B9kd+SmkIAu2T3;)MdAU*Tfd;`{op=_}RcX*YhG3iC&uZ%b0z=M^Rg$3X|gu z>2_f!bvW}Jk2*?l@iGjG7K_-GLz{40uNtG*v_yF+H=*2;?~{EzhMl z!e|xPaXS;n$*h47;R4FIM((tiEBLJ6&P?XMp_4PdbDK`@X|Z{60>tKo5_L64+Py{* zXRS$P6IC`~ovc!NyV9C%u$(ArGgt;6YR_TUdGomK9}QWUkoR{aV!-6ndZysiiKk>P zfa@|#;hn-OW|bb!!ZInnHMk1yshi`q^*`yln8Pr($$+^P`cjEoy=dN$PL!{(i|vwc zgq>s3X`@anDN%@o8{y;Fp6+5=^u&$+8kG-*Zh>sc*8SkGy@em!^)zo?F5A7Ql*8Z8 zqYmpUc-`!6a9VsDUQ)gt=J(VR?LTYrm2)Sl*V`KE)UCljkKRa|>i*D(i6!{!b{UY0 zwZ=uk#xT8c0jWN354Xc|p~A+NnXBnz^}BU=F|Q9uy||Z!MJI_ClLV1K`x8-Bn}a%J z4{*zh=c4LVG4gWbNNgTDn^*+pk{G@ZzF4h=B&QxHe(D}nN@^aC`DjSK=dXjODn>{x zvzgL`)@Zt$kl0zM(b6+x>8jtkRL|B1XjUDn8?-qkd-x(Up1YKzGM4H2u$FG{Nf}85}hrS|YWb zSQ)F};&Jsv`(Zm8GVp>tPr4$yyrPfSx6G%X9vwv2MyZl$TN51H&hJGmQ$f5Jd_Yy^xXs{8tjD8~sj@p4D zb?2hnrghxCa5a+hTbo){JRq~~sL+B%9_Y!>ezG#-i>Mn$Lh_L5cv6x(XYZ;&ee@RN z)4z0SR;VVuDgY$&D^F0cd^COhEER>nJVewxucGx)i@3`LA>6(5?a0t>7P-7pjvSZ1 zO~>rqAljy1E|}fdf}iei#F4(|k>1pi+|~Wf-2J4zVjN16Ab>d=1J1e)4wTV;E z!PnFAncdfsT<2`MtTg~`d@y4h1+uLF^%0nSDURP~#~#P37SXmsd16uR%kB)W$4X{7 zcuIl+9XjGHlb1=O9glUGC+A9v9Ce^Mzly5ASqp(m{_M!PRQh<^2FN$gN4cx=nbzxd zP}6Zyq}1dI;%iSJx9Cf>-7J(mZ4sl!J7wr2rDT3QdO)w)Vf^LZF)XdN17BB`0-Z4> z_|eb7sC%G-x_)}b%metnLXVGz-*G0Q5i|q`KMVkQ*M-E7*IB-aafLC?12}!jNPe%! zm-KMLQJlKkl--%0$_fNq(ZL*lxU<2Zt@Bf)hi;nURZ&7ZbL%0xk=N9E{YYn|F&od$ zU50HF4VsHoqBKZRLM176p1qFZCY30;M5YiWLuML&$M?m3eeS&<-+RBm^Uw2o zooB7R_u6}{_1b$qd%c(R(?ncZ7fc*(yWrO7i(nzzK-&a&fBA8el^z$O=)rwj0U zE=DWCnFJ&){JBtHFc}WVq(S-!84~R;h2Qvcuq`^W45;*XVsgdS(C@1xaCU_dGL8C$DhMt-e^j2GFm3m zEqCJ>#Y-$FezgHF?0hk?5jUcT^7j&%4gn@hp*O3n-oa+AW9g$4{jBiFF<>;;kDihE z#10vF!wQMCaSr!6lHgW7R@}Hukjt+`EAK|JyQb)}jD8IJ)o3M?d}AMTrT93BmKer_ zHctnBx&mw1Q)$t*ehb;^G>A+)WJ;f3|r#MS=6FTeY2{WhqIg+Sz}H5*dl_}p6$r2eIrBE&n1)5qhsjPhAj56 z)_F#_a|bNh=F2_{PGZe_BbZt75+qillWC3{Lp+y-llp7M@FAs}Y0Mf&j!2wiZ=27j zjv-O3?`(56_nscvbt9cjx*-XQNv~<|IVB>!X%wTqU_bM|rJAX_Hiu-gFX)bYs!V06 z8+p~A&Mt3#$w;+=Gs0QL@Ud%tXuRVJ@s~+~4Gk8h`$NbF+ zvM^>4k>V8doM(rz?>=VoRD%^hW$nm+y!W|q#6l)ezloCMzmvl%) zDk$HyBInm!qwN}BsgO29PuKM`mbPsqs^>9J%)$=}^LK;$7-ihjR|!&Y&9F>u9$Y&t zj66k0;y z4pSxTHfj~0z$8wbgRKUO(5QJIxwP1vDZjOc7E3sjJ--Ws=~{L0Dph1#FPNj_{{7TE zb}I0VjWDdHjMQcaXZ94aaV-|ZCx0RhaK&} zzNrjkWLjb7Rae|4ydO%ggyP%D!JrB8rN#1V{WH2UhHe9e$~RHNns+_Oko)4 z4D11WuM(_i*^A?DibJ`s6FC|!$^9ZaihEyZDwa=l#goTBq1?QYOm^8(d^Pb1rr&kr zS2Y_@^L-lN-&l8qTfhKj4EEWxrqLQx~S1!GhRT?W{i9C2o4s$fG@iXaY;Io zCQo_p>DrkPIldN@--e-nRSFSz=z{H@s(4CU5w@ishbajh8Ymoz<%$MitM&>vSigmI zZxODphB!HCQ-XTon+1L&2DiNp#kDr;@w)35My#V|9 zt@bQ(W!ZA%Zc)U>S2}e1n-(HHRtNq2-Z1M%E7Flkhap*4p2ja(22X1w+3{yYXx~wm zp4SwDphr?v|H?vkO#KyFHb@uVm5x9sg_rdHkOERCxe&No$*6DZ%&fIhVLYT`Xnt=j znR99f(fTCL_7tv0M&~Zg5MN6TPUMq$Wm9Q%NCE`fKBh~gZ6P)yjpS?;p|8%5h5B8p zxWdDi%`~eZS#js-y9Ew-r8Jp{CA83S6WU0@Qd4w%?@jc^+u*2U1+4BJLGE&pG93dj}xVpzRn9i3cLM-Giy zjK>pGXpUmkQdbM@WBKD#`&3~lk{cs zIIJ<)1G3jLNn2SbDbm)j44E$rA02w>r)YIr);$_G)rZlA&c0xA%!Cfvo(PX7F2*Q9 zK9*SbVbWW@hu#?CfDhM)lJtpuVwYV(?zoL*5B92%P+3!QbE-B>%iK-od(XmGYcA3# ziFxp0-E9(DR7lqfguI$=0h?ylGviEhX}QNS=E;q8-iA15qIq;OzU;HWIZGALd&p60 z)Cu4+^BZl-UB$CZn@hxVmGQo!7;EqKGS4$>Xj@z;dZok@casEm=KNNAG5k4+ zZ&u>;T~~z_9zvLVb_Gs3EDU#t*}$#wVnqGGe7s{@K-zV+@ly6u_OtyCVj0sw0**Rb zI8+DY+RhYGc+!`Mj-L+I?q#Iw#Ch_mRFQOhXk(*B9PBspBTp52s90Pp4cXDmSsnb6 zQNOwsn|g*qo4+SLez${d;p{dTsnR4RdK8g`3Gbnj&!tADHD>q@|HK_RS8R4nX#qJfO@P?~p3m6J9}kCoTg zfr%d^a80EL*?;yS+4`n|m)oERsgpx7)x(h8EIAsNdL1H6z9Bp?ZleM~epp?e$8f^; zvUwtL*i``3~1@0%$ zR7Dww_ewN>{xvwWX(n-L5ad%PG_f@sVrXoTGR<&{qtl*UBDd7iN{|0@j95@CD)8D=nt~;XUMP!K`eJ{8J%DDn2mTGz({SaWVE`wh`QWU zD*NdGId)1M_8BiH_lv($w`8ol18YJU9b!9M9dr2!rn%F$wBqF+JIfdv5YO+TU zF8IHqvjXjSrQe3ovj@IWkAYk4xtco!)M&cg1H8DPdQ`YOboqFcZm?Jr_KDVqrD1=XXyqBOUp zgu}~R?u7aZitwdciwq0mvD0Rcq7`>1gXsHCoG!iv6R*}ma!e<1SN1re7M; zMP8!wQzhKanqg?JJiJM21pgNnbd-lRJL|Lyd2hB7duB+%X@eQqB`*p^!gt{Dtw#3L zz$xaY;seI&B_Gc6Uzr zV^0?PdXOS3ZOA^M4kkNeaI4IC%yW@tRHuqU)-^X;?j}hFNDVEKUk!JzSCaF*BwBQ2 z84;RMPo}t*QvSnwx<61AR9QH{Jn8;vo<4{V?UR{Ib2T$UZ%PP3*w^eNSkwrwNa5l`(mB;PT z4#3rUjqCPzvD|onjMZ7jDQ;oNz1zV|cdsNwS;WAmJGnUb*aZxnPzingA~eLPlGUx! z#U1K7w0qwtuzq_7W4?9LptCoqInR+TiPE78J&n|HP&jyV1o*Tf*6_{$08V=&%eB3; zlD%EvMYB)KQCY=@7+TW_9)^ivJaiN^#y65x&vV)4>s+2`*dV%aZab>oJ40ISrMUWj z;{2q!x3FDRnyvS+q5Y4J(q{MDG;i@1eE7K*f;0#7UmvL^W&X;H&}jqC*o`ivAh7~l z7B`S#rV}vSWef~;Hs!iB@>mm%-5k+P2bj-}m+|YoPPl*j5hQ&R=I=Pzh&@9D_2P!? z;Y^-hHQA906Y#ZC0yC~(OLc*Q<@+FaNivKWRRoJY*8OALM5wow&Tg0o7jQlC zdVPa#y;Vgkd~|t z{^sZNyG%cSjYGnTQ?fA3tc!(ZcgKLSpvGHJKqajmIv1M@bC`J{&bZRk4y?~pQmw3r z3w&hoykZ2HCM3Z1Ej`qI+&*HRbcM~9R|m<<(}_u!;G5fZGicH@gbL#(a`sFk*4#IN z`{va+q<1w+d=A4G)rMqFcNne}eD5tbT1hjetKi0B9^8u#rZ>K}(-8x<(5B0yeXCZI z@U}L3EW!bo6_qkm0$)+JEM1gZJ`0a`tRp(_FVK(eQ8;$Z3bIWm9m2PV;r4tFR6F8J zf(>JMPc^TT=P#$S8C8*VU19@jH5Sm2jBMhlCXPCl(e&_xX^>j#4PWN8fX0quY%5Qo zL6Z`B39@xG!c*{FdCxisIUS838-_AX_K)b4{19?ngM(@X@{pr1j``d~5;?{n($h<@ zCrl7~xTB7r^GZpiWCrARIimD}Y*?iK3Ibb?(FgVmF!JChs^;>5wqH$xwZSIfBYP6& zJ0B#=%9JauzRaMdb%!DC(+sFGv&V24glERp(UHmeP&DH#>gWWM(t&6?e7+g13Xh{- z(&mG3cM^X4DvwQTBAH>*D*PzXYI1O66u?MROy+%IR{9C9L2OfnuesU8)#`U50UETd zfF)XotH@m)W%L>JxcBq?`3VpX?a6?KJG$d}Jc1z#UH&5P?xVHv0pk@kAU8PD3 zbCRHb(o~+jjR3Q~JcMecX_JMCgRpTF3Th@~P?IYwup(s&&YhWvQNp)L z)1zGQt)4`lesEwkY(#PEG7;DjzaMsO*}&|&{f646?ZZWlk#s_rD`@526Rh2MTx=NVM zdbAUry3FBT-Fu>@QA0LqYjUS9O(uQ0N9p+00zZBzm`1-`jSly1fpo7WJKUm}-IWjM zD95LmDm{b#zW6LzH6XxbI;n`0E}fu1^Xh-)=d|yAC90_tAvr!9Y=UaZ)4`D}ThNEL zn{?nw&NwItOvHCzPLsKwQeZO#F+sK)sg)!&2BncnaxVWmE@PP+Lcx0q-K{@`e7g3M zwB*&W32k37LV&5JXZH#uZfL;fkT_no0EekcMToC$7mLU0ipXhaRXQZymhXJA6dl@M zQ=gg~GXJ3|T5HuI1nk9CZ$Hx6*6|e5s1g-kpy|N55ve%Aon z4;gXmeT(VV)uGhc#G8p7c9+!;4TJbE6RCy;1yxg3Jlc4Q+SSG)C#sM5UKN7cm}C%G zKNjC@qe1tgNxu3L+Ow&N@s>V9%HADe;&*Xiki|^ex;!7WS53zmPnO~c6;&8rqKH1K zj&Mq~6ij|^CHW7=5k*f2s#KmsD&2F5*rhPcmX5%3{fl&H^9or1VFdCAwUC&-vBdHk zUr-CglD-$QqT$I3s2sE%%|M6tM_%Gd-o8dWRi=TEKyTE?wQ>%`t{}~>`$;tGjU{=7 z+4V)ouYwynx!;MPYeuo-_n>vbBX4_UMAAEpRQVSla9b$ z*dXc(ryh+XdAsA_nNK(At%@P-dXjkLxI654EKPhmD~a7PRXlk%ot}}IOq`c|Br#KR z$i>TXP}?62r9vI3w6BJ7T`-ge=iX<}Z#JMuK1I=OpXB(j4Mn)To&L}kR|Qt_fl&B( zAAVY-NaOwXz^X?^a4r75;Jhn`hmVVJ?d->Tm@TLw&rw&SJnJzuSPrAb983(c=nA#4rJ~k_&_v^JiTg8i#jbZPA7mW8|%TC-0Pfc=3%s#^T1sC z${dnyrZuLri2d+{Aj^m8Djof&RqA-d2&|NpS&+`W;E_CBdN<4 z(@8o`MDF2ewuv}1-)1aCvHBC0b(X(T@wG7|C|jJ=_!|Ccety;~|2|hgfA3}&P`%4r z=wu;5Jk5Fq$8Y%uM*oN}am`gw2lTjSxh!^Ov7o6E)xZhoVGymIk$_F6Ca*W1f?jmzeBF6-BC^xW`+ z#`rHZ|HJQVFv-B%$7PMLi-+qPCqF+IPrr>`o(BK(ZjGmx&t`)SPMbG7P5J-hUFYKG zH1Sugf9cZIb&WuH@0m7w{@~I570-A1TwVVrlg_`V|1O8>ugC>WySh&P#X|fs z$v>#pZFKVe-&->JS1lPCd~a#oU&iFR(bLIejkk}Nvx~3q-vrkFx76Ro{i#vwU4E@m zKbQ6=%^$)#ZT0i=aoOPWLr4GX(*6skKbH2V>i&)A`_lfUvtOz1|B?6m9Q@S!|H>*b zwg1InjDF1Z|Bdx~r$1HyZ=8Sb@K19({#U*Ijc3ab?b83(9Pa*qv9G@k>>upwU(o-j vef^F6zuMQYsQ%Nw{yWWI?dw-m|7l-8Y5tvk{Uh~vafkj`fIm6&?|uG1bNPr7 literal 0 HcmV?d00001 diff --git a/test/data/model_2.2.x/version b/test/data/model_2.2.x/version new file mode 100644 index 000000000..e3a4f1933 --- /dev/null +++ b/test/data/model_2.2.x/version @@ -0,0 +1 @@ +2.2.0 \ No newline at end of file diff --git a/test/data/model_2.1.x/vocab.src.0.json b/test/data/model_2.2.x/vocab.src.0.json similarity index 100% rename from test/data/model_2.1.x/vocab.src.0.json rename to test/data/model_2.2.x/vocab.src.0.json diff --git a/test/data/model_2.1.x/vocab.trg.0.json b/test/data/model_2.2.x/vocab.trg.0.json similarity index 100% rename from test/data/model_2.1.x/vocab.trg.0.json rename to test/data/model_2.2.x/vocab.trg.0.json diff --git a/test/integration/test_backwards_compatibility.py b/test/integration/test_backwards_compatibility.py index 7dd15b997..d663fd09d 100644 --- a/test/integration/test_backwards_compatibility.py +++ b/test/integration/test_backwards_compatibility.py @@ -43,8 +43,8 @@ def test_backwards_compatibility(): output_file = os.path.join(work_dir, "out") params = """{sockeye} --use-cpu --models {model} --input {input} --output {output} """.format( sockeye=sockeye.translate.__file__, - model="test/data/model_2.1.x", - input="test/data/model_2.1.x/model_input", + model="test/data/model_2.2.x", + input="test/data/model_2.2.x/model_input", output=output_file ) logger.info("Translating with params %s", params)