Skip to content

Commit

Permalink
Alternative, faster greedy search (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhieber authored May 30, 2021
1 parent ba8f849 commit ef908e3
Show file tree
Hide file tree
Showing 12 changed files with 339 additions and 130 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [2.3.17]
### Added
- Added an alternative, faster implementation of greedy search. The '--greedy' flag to `sockeye.translate` will enable it. This implementation does not support hypothesis scores, batch decoding, or lexical constraints.

## [2.3.16]

### Added
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '2.3.16'
__version__ = '2.3.17'
6 changes: 6 additions & 0 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,12 @@ def add_inference_args(params):
type=int_greater_or_equal(1),
default=5,
help='Size of the beam. Default: %(default)s.')
decode_params.add_argument('--greedy', '-g',
action="store_true",
default=False,
help='Enables an alternative, faster greedy decoding implementation. It does not '
'support batch decoding, ensembles, or lexical constraints, and hypothesis scores '
'are not normalized. Default: %(default)s.')

decode_params.add_argument('--beam-search-stop',
choices=[C.BEAM_SEARCH_STOP_ALL, C.BEAM_SEARCH_STOP_FIRST],
Expand Down
331 changes: 241 additions & 90 deletions sockeye/beam_search.py

Large diffs are not rendered by default.

49 changes: 26 additions & 23 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from . import lexicon
from . import utils
from . import vocab
from .beam_search import get_beam_search, CandidateScorer
from .beam_search import get_search_algorithm, CandidateScorer, GreedySearch
from .model import SockeyeModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -708,7 +708,8 @@ def __init__(self,
max_input_length: Optional[int] = None,
max_output_length: Optional[int] = None,
softmax_temperature: Optional[float] = None,
prevent_unk: bool = False) -> None:
prevent_unk: bool = False,
greedy: bool = False) -> None:
self.context = context
self.dtype = C.DTYPE_FP32 if models[0].dtype == C.DTYPE_INT8 else models[0].dtype
self._scorer = scorer
Expand Down Expand Up @@ -741,7 +742,7 @@ def __init__(self,
utils.check_condition(self.beam_search_stop == C.BEAM_SEARCH_STOP_ALL,
"nbest_size > 1 requires beam_search_stop to be set to 'all'")

self._beam_search = get_beam_search(
self._search = get_search_algorithm(
models=self.models,
beam_size=self.beam_size,
context=self.context,
Expand All @@ -755,22 +756,24 @@ def __init__(self,
avoid_list=avoid_list,
hybridize=hybridize,
softmax_temperature=softmax_temperature,
prevent_unk=prevent_unk)
prevent_unk=prevent_unk,
greedy=greedy)

self._concat_translations = partial(_concat_nbest_translations if self.nbest_size > 1 else _concat_translations,
stop_ids=self.stop_ids,
scorer=self._scorer) # type: Callable

logger.info("Translator (%d model(s) beam_size=%d beam_search_stop=%s max_input_length=%s "
logger.info("Translator (%d model(s) beam_size=%d algorithm=%s, beam_search_stop=%s max_input_length=%s "
"nbest_size=%s ensemble_mode=%s max_batch_size=%d avoiding=%d dtype=%s softmax_temperature=%s)",
len(self.models),
self.beam_size,
"GreedySearch" if isinstance(self._search, GreedySearch) else "BeamSearch",
self.beam_search_stop,
self.max_input_length,
self.nbest_size,
"None" if len(self.models) == 1 else ensemble_mode,
self.max_batch_size,
0 if self._beam_search.global_avoid_trie is None else len(self._beam_search.global_avoid_trie),
0 if self._search.global_avoid_trie is None else len(self._search.global_avoid_trie),
self.dtype,
softmax_temperature)

Expand Down Expand Up @@ -1065,30 +1068,30 @@ def _translate_nd(self,
raw_avoid_list: List[Optional[constrained.RawConstraintList]],
max_output_lengths: mx.nd.NDArray) -> List[Translation]:
"""
Translates source of source_length.
Translates source of source_length and returns list of Translations.
:param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
:param source_length: Valid source lengths.
:param restrict_lexicon: Lexicon to use for vocabulary restriction.
:param raw_constraints: A list of optional constraint lists.
:return: Sequence of translations.
:return: List of translations.
"""
return self._get_best_from_beam(*self._beam_search(source,
source_length,
restrict_lexicon,
raw_constraints,
raw_avoid_list,
max_output_lengths))

def _get_best_from_beam(self,
best_hyp_indices: np.ndarray,
best_word_indices: np.ndarray,
seq_scores: np.ndarray,
lengths: np.ndarray,
estimated_reference_lengths: Optional[mx.nd.NDArray] = None,
constraints: List[Optional[constrained.ConstrainedHypothesis]] = [],
beam_histories: Optional[List[BeamHistory]] = None) -> List[Translation]:
return self._get_best_translations(*self._search(source,
source_length,
restrict_lexicon,
raw_constraints,
raw_avoid_list,
max_output_lengths))

def _get_best_translations(self,
best_hyp_indices: np.ndarray,
best_word_indices: np.ndarray,
seq_scores: np.ndarray,
lengths: np.ndarray,
estimated_reference_lengths: Optional[mx.nd.NDArray] = None,
constraints: List[Optional[constrained.ConstrainedHypothesis]] = [],
beam_histories: Optional[List[BeamHistory]] = None) -> List[Translation]:
"""
Return the nbest (aka n top) entries from the n-best list.
Expand Down
3 changes: 2 additions & 1 deletion sockeye/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def run_translate(args: argparse.Namespace):
max_output_length=args.max_output_length,
hybridize=hybridize,
softmax_temperature=args.softmax_temperature,
prevent_unk=args.prevent_unk)
prevent_unk=args.prevent_unk,
greedy=args.greedy)
read_and_translate(translator=translator,
output_handler=output_handler,
chunk_size=args.chunk_size,
Expand Down
8 changes: 5 additions & 3 deletions test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def check_train_translate(train_params: str,
seed=seed)

# Test equivalence of batch decoding
translate_params_batch = translate_params + " --batch-size 2"
test_translate_equivalence(data, translate_params_batch, compare_output)
if 'greedy' not in translate_params:
translate_params_batch = translate_params + " --batch-size 2"
test_translate_equivalence(data, translate_params_batch, compare_output)

# Run translate with restrict-lexicon
data = run_translate_restrict(data, translate_params)
Expand All @@ -61,7 +62,8 @@ def check_train_translate(train_params: str,
# Only run scoring under these conditions. Why?
# - translate splits up too-long sentences and translates them in sequence, invalidating the score, so skip that
# - scoring requires valid translation output to compare against
if '--max-input-length' not in translate_params and _translate_output_is_valid(data['test_outputs']):
if '--max-input-length' not in translate_params and _translate_output_is_valid(data['test_outputs']) \
and 'greedy' not in translate_params:
test_scoring(data, translate_params, compare_output)

# Test correct prediction of target factors if enabled
Expand Down
4 changes: 2 additions & 2 deletions test/integration/test_seq_copy_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@
" --weight-init-scale=3.0 --weight-init-xavier-factor-type=avg"
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 0"
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01",
"--beam-size 1",
"--beam-size 1 --greedy",
True, 0, 0),
# Basic transformer with source factor, beam-search-stop first decoding
# Basic transformer with source and target factors, beam-search-stop first decoding
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
" --transformer-feed-forward-num-hidden 16"
Expand Down
10 changes: 10 additions & 0 deletions test/system/test_seq_copy_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@
False,
1.02,
0.98),
("greedy",
"--encoder transformer --decoder transformer"
" --max-updates 4000"
" --num-layers 2 --transformer-attention-heads 4 --transformer-model-size 32"
" --transformer-feed-forward-num-hidden 64 --num-embed 32"
" --batch-size 16 --batch-type sentence" + COMMON_TRAINING_PARAMS,
"--beam-size 1 --greedy",
False,
1.02,
0.98),
("Copy:transformer:transformer:length_task_learned",
"--encoder transformer --decoder transformer"
" --max-updates 4000"
Expand Down
1 change: 1 addition & 0 deletions test/unit/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def test_model_parameters(test_params, expected_params):
output=None,
checkpoints=None,
models=['model'],
greedy=False,
beam_size=5,
nbest_size=1,
batch_size=1,
Expand Down
33 changes: 32 additions & 1 deletion test/unit/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,35 @@ def test_topk_func(batch_size, beam_size, target_vocab_size):
assert np.allclose(mx_values, np_values)


@pytest.mark.parametrize("target_vocab_size", [2, 10, 500, 1024])
def test_greedytop1(target_vocab_size):
    """
    Checks that GreedyTop1 selects the same word index as the numpy_topk reference
    for a single hypothesis (batch_size == beam_size == 1), both without and with
    target factors.

    :param target_vocab_size: Number of entries in the random score vector. Supplied
                              by the parametrization and used as-is, so that every
                              listed vocabulary size is actually exercised.
    """
    batch_size = 1
    beam_size = 1
    # Random model scores. Shape: (batch_size * beam_size, target_vocab_size)
    scores = mx.nd.random.uniform(0, 1, (batch_size * beam_size, target_vocab_size))
    expected_hyp_index, expected_word_index, expected_value = numpy_topk(scores, k=beam_size, offset=None)
    # With a single hypothesis the only possible hypothesis index is 0.
    assert expected_hyp_index[0] == 0
    assert expected_value.shape == (1, 1)

    greedy_top1 = sockeye.beam_search.GreedyTop1()
    greedy_top1.initialize()

    # Without target factors the output holds just the best word index.
    best_word_index = greedy_top1(scores, None, None)
    best_word_index = best_word_index.asnumpy()
    assert best_word_index.shape == (1, 1)
    assert best_word_index[0, 0] == expected_word_index[0]

    # With target factors the output has an additional column carrying the factor value.
    target_factors = mx.nd.ones((1, 1), dtype='int32')
    best_word_index_with_factors = greedy_top1(scores, None, target_factors)
    best_word_index_with_factors = best_word_index_with_factors.asnumpy()
    assert best_word_index_with_factors.shape == (1, 2)
    assert best_word_index_with_factors[0, 0] == expected_word_index[0]
    assert best_word_index_with_factors[0, 1] == target_factors.asscalar()




@pytest.mark.parametrize("batch_size, beam_size, target_vocab_size, top_n",
[(1, 5, 200, 0),
(5, 5, 200, 0),
Expand Down Expand Up @@ -249,6 +278,7 @@ def test_update_scores():
assert (scores[1] == np.array([1.] + pad_dist[1].asnumpy().tolist())).all() # 2 finished, force pad, keep score
assert (scores[2] == (1. + target_dists[2]).asnumpy()).all() # 3 scores + previous scores


def test_prevent_unk_update_scores():
vocab_size = 10
batch_beam_size = 3
Expand Down Expand Up @@ -280,6 +310,7 @@ def test_prevent_unk_update_scores():
assert scores[2, C.UNK_ID] == np.inf # 3 scores of <unk> should be np.inf
assert (scores[2] == (1. + target_dists[2] + unk_dist[2]).asnumpy()).all() # 3 scores + previous scores


class _TestInference(sockeye.beam_search._Inference):

def __init__(self, output_vocab_size: int):
Expand Down Expand Up @@ -382,4 +413,4 @@ def test_beam_search():
assert inference.states[1] == max_length

print(best_hyp_indices)
print(best_word_indices)
print(best_word_indices)
18 changes: 9 additions & 9 deletions test/unit/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_get_best_word_indices_for_kth_hypotheses():
[([[], [], [], []], [None, None], np.array([0, 2], dtype='int32'), np.array([[1, 1, 1], [3, 3, 3]], dtype='int32')),
([[[1]], [], [[3]], []], [None, None], np.array([1, 3], dtype='int32'), np.array([[1, 0, 0], [3, 2, 2]], dtype='int32'))
])
def test_get_best_from_beam(raw_constraints, beam_histories, expected_best_ids, expected_best_indices):
def test_get_best_translations(raw_constraints, beam_histories, expected_best_ids, expected_best_indices):
best_hyp_indices = np.array([[0, 1, 0, 1],
[0, 1, 1, 0],
[2, 3, 2, 3],
Expand Down Expand Up @@ -421,14 +421,14 @@ def test_get_best_from_beam(raw_constraints, beam_histories, expected_best_ids,

constraints = [sockeye.lexical_constraints.ConstrainedHypothesis(rc, _EOS) for rc in raw_constraints]

actual_result = sockeye.inference.Translator._get_best_from_beam(translator,
best_hyp_indices,
best_word_indices,
seq_scores,
lengths,
None,
constraints,
beam_histories)
actual_result = sockeye.inference.Translator._get_best_translations(translator,
best_hyp_indices,
best_word_indices,
seq_scores,
lengths,
None,
constraints,
beam_histories)

for expected_translation, actual_translation in zip(expected_result, actual_result):
assert expected_translation.target_ids == actual_translation.target_ids
Expand Down

0 comments on commit ef908e3

Please sign in to comment.