diff --git a/ctcdecode/__init__.py b/ctcdecode/__init__.py
index 0f6b3f1a..9f554e1d 100644
--- a/ctcdecode/__init__.py
+++ b/ctcdecode/__init__.py
@@ -4,7 +4,9 @@ class CTCBeamDecoder(object):
     def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100,
-                 num_processes=4, blank_id=0, log_probs_input=False):
+                 num_processes=4, blank_id=0, log_probs_input=False,
+                 space_symbol=" "):
+
         self.cutoff_top_n = cutoff_top_n
         self._beam_width = beam_width
         self._scorer = None
@@ -12,10 +14,12 @@ def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cu
         self._labels = ''.join(labels).encode()
         self._num_labels = len(labels)
         self._blank_id = blank_id
+        self._space_symbol = space_symbol
         self._log_probs = 1 if log_probs_input else 0
+
         if model_path:
             self._scorer = ctc_decode.paddle_get_scorer(alpha, beta, model_path.encode(), self._labels,
-                                                        self._num_labels)
+                                                        self._num_labels, self._space_symbol.encode())
         self._cutoff_prob = cutoff_prob

     def decode(self, probs, seq_lens=None):
@@ -33,10 +37,11 @@ def decode(self, probs, seq_lens=None):
         if self._scorer:
             ctc_decode.paddle_beam_decode_lm(probs, seq_lens, self._labels, self._num_labels, self._beam_width,
                                              self._num_processes, self._cutoff_prob, self.cutoff_top_n, self._blank_id,
-                                             self._log_probs ,self._scorer, output, timesteps, scores, out_seq_len)
+                                             self._space_symbol.encode(),
+                                             self._log_probs, self._scorer, output, timesteps, scores, out_seq_len)
         else:
             ctc_decode.paddle_beam_decode(probs, seq_lens, self._labels, self._num_labels, self._beam_width, self._num_processes,
-                                          self._cutoff_prob, self.cutoff_top_n, self._blank_id, self._log_probs,
+                                          self._cutoff_prob, self.cutoff_top_n, self._blank_id, self._space_symbol.encode(), self._log_probs,
                                           output, timesteps, scores, out_seq_len)

         return output, scores, timesteps, out_seq_len
diff --git a/ctcdecode/src/binding.cpp b/ctcdecode/src/binding.cpp
index 4ba3a08c..4fb11e8c 100644
--- a/ctcdecode/src/binding.cpp
+++ b/ctcdecode/src/binding.cpp
@@ -31,13 +31,13 @@ int beam_decode(at::Tensor th_probs,
                 double cutoff_prob,
                 size_t cutoff_top_n,
                 size_t blank_id,
+                const std::string &space_symbol,
                 bool log_input,
                 void *scorer,
                 at::Tensor th_output,
                 at::Tensor th_timesteps,
                 at::Tensor th_scores,
-                at::Tensor th_out_length)
-{
+                at::Tensor th_out_length){
     std::vector<std::string> new_vocab;
     utf8_to_utf8_char_vec(labels, new_vocab);
     Scorer *ext_scorer = NULL;
@@ -66,7 +66,7 @@ int beam_decode(at::Tensor th_probs,
     }

     std::vector<std::vector<std::pair<double, Output>>> batch_results =
-    ctc_beam_search_decoder_batch(inputs, new_vocab, beam_size, num_processes, cutoff_prob, cutoff_top_n, blank_id, log_input, ext_scorer);
+    ctc_beam_search_decoder_batch(inputs, new_vocab, beam_size, num_processes, cutoff_prob, cutoff_top_n, blank_id, space_symbol, log_input, ext_scorer);
     auto outputs_accessor = th_output.accessor<int, 3>();
     auto timesteps_accessor = th_timesteps.accessor<int, 3>();
     auto scores_accessor = th_scores.accessor<float, 2>();
@@ -89,7 +89,9 @@ int beam_decode(at::Tensor th_probs,
         }
     }
     return 1;
-}
+}
+
+
 int paddle_beam_decode(at::Tensor th_probs,
                        at::Tensor th_seq_lens,
@@ -100,15 +102,18 @@ int paddle_beam_decode(at::Tensor th_probs,
                        double cutoff_prob,
                        size_t cutoff_top_n,
                        size_t blank_id,
+                       const char* space_symbol,
                        int log_input,
                        at::Tensor th_output,
                        at::Tensor th_timesteps,
                        at::Tensor th_scores,
                        at::Tensor th_out_length){
+    std::string space_symbol_string(space_symbol);

     return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes,
-                       cutoff_prob, cutoff_top_n, blank_id, log_input, NULL, th_output, th_timesteps, th_scores, th_out_length);
-}
+                       cutoff_prob, cutoff_top_n, blank_id, space_symbol_string,
+                       log_input, NULL, th_output, th_timesteps, th_scores, th_out_length);
+}

 int paddle_beam_decode_lm(at::Tensor th_probs,
                           at::Tensor th_seq_lens,
@@ -119,28 +124,32 @@ int paddle_beam_decode_lm(at::Tensor th_probs,
                           double cutoff_prob,
                           size_t cutoff_top_n,
                           size_t blank_id,
-                          int log_input,
-                          void *scorer,
+                          const char* space_symbol,
+                          bool log_input,
+                          void *scorer,
                           at::Tensor th_output,
                           at::Tensor th_timesteps,
                           at::Tensor th_scores,
                           at::Tensor th_out_length){
+    std::string space_symbol_string(space_symbol);
     return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes,
-                       cutoff_prob, cutoff_top_n, blank_id, log_input, scorer, th_output, th_timesteps, th_scores, th_out_length);
-}
-
+                       cutoff_prob, cutoff_top_n, blank_id, space_symbol_string, log_input, scorer, th_output, th_timesteps, th_scores, th_out_length);
+}
 void* paddle_get_scorer(double alpha,
                         double beta,
                         const char* lm_path,
                         const char* labels,
-                        int vocab_size) {
+                        int vocab_size,
+                        const char* space_symbol) {
     std::vector<std::string> new_vocab;
     utf8_to_utf8_char_vec(labels, new_vocab);
-    Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab);
+    // convert the char* space symbol into a std::string for the Scorer constructor
+    std::string space_symbol_string(space_symbol);
+    Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab, space_symbol_string);
     return static_cast<void*>(scorer);
-}
+}

 int is_character_based(void *scorer){
     Scorer *ext_scorer = static_cast<Scorer*>(scorer);
diff --git a/ctcdecode/src/binding.h b/ctcdecode/src/binding.h
index 61a00302..ee6272d0 100644
--- a/ctcdecode/src/binding.h
+++ b/ctcdecode/src/binding.h
@@ -7,6 +7,7 @@ int paddle_beam_decode(THFloatTensor *th_probs,
                        double cutoff_prob,
                        size_t cutoff_top_n,
                        size_t blank_id,
+                       const char* space_symbol,
                        int log_input,
                        THIntTensor *th_output,
                        THIntTensor *th_timesteps,
@@ -22,7 +23,8 @@ int paddle_beam_decode_lm(THFloatTensor *th_probs,
                           double cutoff_prob,
                           size_t cutoff_top_n,
                           size_t blank_id,
-                          bool log_input,
+                          const char* space_symbol,
+                          bool log_input,
                           int *scorer,
                           THIntTensor *th_output,
                           THIntTensor *th_timesteps,
@@ -33,7 +35,8 @@ void* paddle_get_scorer(double alpha,
                         double beta,
                         const char* lm_path,
                         const char* labels,
-                        int vocab_size);
+                        int vocab_size,
+                        const char* space_symbol);

 int is_character_based(void *scorer);
 size_t get_max_order(void *scorer);
diff --git a/ctcdecode/src/ctc_beam_search_decoder.cpp b/ctcdecode/src/ctc_beam_search_decoder.cpp
index 5fad1807..133dd08a 100644
--- a/ctcdecode/src/ctc_beam_search_decoder.cpp
+++ b/ctcdecode/src/ctc_beam_search_decoder.cpp
@@ -21,6 +21,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
     double cutoff_prob,
     size_t cutoff_top_n,
     size_t blank_id,
+    const std::string &space_symbol,
     int log_input,
     Scorer *ext_scorer) {
   // dimension check
@@ -36,7 +37,9 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
   // size_t blank_id = vocabulary.size();

   // assign space id
-  auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
+  // look up the configurable space symbol instead of the hard-coded " "
+  auto it = std::find(vocabulary.begin(), vocabulary.end(), space_symbol);
+  // auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
   int space_id = it - vocabulary.begin();
   // if no space in vocabulary
   if ((size_t)space_id >= vocabulary.size()) {
@@ -176,7 +179,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
       std::vector<int> timesteps;
       prefixes[i]->get_path_vec(output, timesteps);
       auto prefix_length = output.size();
-      auto words = ext_scorer->split_labels(output);
+      auto words = ext_scorer->split_labels(output, space_symbol);
       // remove word insert
       approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
       // remove language model weight:
@@ -198,6 +201,7 @@ ctc_beam_search_decoder_batch(
     double cutoff_prob,
     size_t cutoff_top_n,
     size_t blank_id,
+    const std::string &space_symbol,
     int log_input,
     Scorer *ext_scorer) {
   VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
@@ -216,6 +220,7 @@ ctc_beam_search_decoder_batch(
         cutoff_prob,
         cutoff_top_n,
         blank_id,
+        space_symbol,
         log_input,
         ext_scorer));
   }
diff --git a/ctcdecode/src/ctc_beam_search_decoder.h b/ctcdecode/src/ctc_beam_search_decoder.h
index ce3048d0..8bb1210f 100644
--- a/ctcdecode/src/ctc_beam_search_decoder.h
+++ b/ctcdecode/src/ctc_beam_search_decoder.h
@@ -25,6 +25,8 @@
 *     in desending order.
 */

+const std::string DEFAULT_SPACE_SYMBOL = std::string(" ");
+
 std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
     const std::vector<std::vector<double>> &probs_seq,
     const std::vector<std::string> &vocabulary,
@@ -32,6 +34,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
     double cutoff_prob = 1.0,
     size_t cutoff_top_n = 40,
     size_t blank_id = 0,
+    const std::string &space_symbol = DEFAULT_SPACE_SYMBOL,
     int log_input = 0,
     Scorer *ext_scorer = nullptr);

@@ -45,6 +48,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
 *     num_processes: Number of threads for beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
+*     space_symbol: Symbol that separates words in the vocabulary; defaults to " ".
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *     n-gram language model scoring and word insertion term.
 *     Default null, decoding the input sample without scorer.
@@ -61,6 +65,7 @@ ctc_beam_search_decoder_batch(
     double cutoff_prob = 1.0,
     size_t cutoff_top_n = 40,
     size_t blank_id = 0,
+    const std::string &space_symbol = DEFAULT_SPACE_SYMBOL,
     int log_input = 0,
     Scorer *ext_scorer = nullptr);
diff --git a/ctcdecode/src/decoder_utils.cpp b/ctcdecode/src/decoder_utils.cpp
index dab6ceb4..c6dcafc8 100644
--- a/ctcdecode/src/decoder_utils.cpp
+++ b/ctcdecode/src/decoder_utils.cpp
@@ -155,7 +155,8 @@ bool add_word_to_dictionary(
   std::vector<int> int_word;

   for (auto &c : characters) {
-    if (c == " ") {
+    // if (c == " ") {
+    if (c == "|") {  // NOTE: the dictionary space symbol is hard-coded to "|" here rather than taken from space_symbol
       int_word.push_back(SPACE_ID);
     } else {
       auto int_c = char_map.find(c);
diff --git a/ctcdecode/src/scorer.cpp b/ctcdecode/src/scorer.cpp
index f55e7a0b..df545153 100644
--- a/ctcdecode/src/scorer.cpp
+++ b/ctcdecode/src/scorer.cpp
@@ -16,7 +16,8 @@ using namespace lm::ngram;
 Scorer::Scorer(double alpha,
                double beta,
                const std::string& lm_path,
-               const std::vector<std::string>& vocab_list) {
+               const std::vector<std::string>& vocab_list,
+               const std::string &space_symbol) {
   this->alpha = alpha;
   this->beta = beta;

@@ -28,7 +29,7 @@ Scorer::Scorer(double alpha,
   dict_size_ = 0;
   SPACE_ID_ = -1;

-  setup(lm_path, vocab_list);
+  setup(lm_path, vocab_list, space_symbol);
 }

 Scorer::~Scorer() {
@@ -41,11 +42,12 @@ Scorer::~Scorer() {
 }

 void Scorer::setup(const std::string& lm_path,
-                   const std::vector<std::string>& vocab_list) {
+                   const std::vector<std::string>& vocab_list,
+                   const std::string &space_symbol) {
   // load language model
   load_lm(lm_path);
   // set char map for scorer
-  set_char_map(vocab_list);
+  set_char_map(vocab_list, space_symbol);
   // fill the dictionary for FST
   if (!is_character_based()) {
     fill_dictionary(true);
@@ -79,10 +81,14 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
   model->NullContextWrite(&state);
   for (size_t i = 0; i < words.size(); ++i) {
     lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
+    // OOV word encountered
     if (word_index == 0) {
       return OOV_SCORE;
     }
+    // Alternative: let the language model score OOV words itself instead of returning
+    // the hard-coded OOV_SCORE; in practice this did not work better.
+    // See: https://github.com/parlance/ctcdecode/issues/62
     cond_prob = model->BaseScore(&state, word_index, &out_state);
     tmp_state = state;
     state = out_state;
@@ -132,7 +138,7 @@ std::string Scorer::vec2str(const std::vector<int>& input) {
   return word;
 }

-std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
+std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels, const std::string &space_symbol) {
   if (labels.empty()) return {};

   std::string s = vec2str(labels);
@@ -140,18 +146,20 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
   if (is_character_based_) {
     words = split_utf8_str(s);
   } else {
-    words = split_str(s, " ");
+    // words = split_str(s, " ");
+    words = split_str(s, space_symbol);  // split on the configurable space symbol instead of " "
   }
   return words;
 }

-void Scorer::set_char_map(const std::vector<std::string>& char_list) {
+void Scorer::set_char_map(const std::vector<std::string>& char_list, const std::string &space_symbol) {
   char_list_ = char_list;
   char_map_.clear();

   for (size_t i = 0; i < char_list_.size(); i++) {
-    if (char_list_[i] == " ") {
-      SPACE_ID_ = i;
+    // if (char_list_[i] == " ") {
+    if (char_list_[i] == space_symbol) {  // match the configurable space symbol instead of " "
+      SPACE_ID_ = i;
     }
     // The initial state of FST is state 0, hence the index of chars in
     // the FST should start from 1 to avoid the conflict with the initial
diff --git a/ctcdecode/src/scorer.h b/ctcdecode/src/scorer.h
index 5ebc719c..84c62815 100644
--- a/ctcdecode/src/scorer.h
+++ b/ctcdecode/src/scorer.h
@@ -43,7 +43,8 @@ class Scorer {
   Scorer(double alpha,
          double beta,
          const std::string &lm_path,
-         const std::vector<std::string> &vocabulary);
+         const std::vector<std::string> &vocabulary,
+         const std::string &space_symbol);
   ~Scorer();

   double get_log_cond_prob(const std::vector<std::string> &words);
@@ -67,7 +68,7 @@ class Scorer {

   // trransform the labels in index to the vector of words (word based lm) or
   // the vector of characters (character based lm)
-  std::vector<std::string> split_labels(const std::vector<int> &labels);
+  std::vector<std::string> split_labels(const std::vector<int> &labels, const std::string &space_symbol);

   // language model weight
   double alpha;
@@ -80,7 +81,8 @@ class Scorer {
 protected:
   // necessary setup: load language model, set char map, fill FST's dictionary
   void setup(const std::string &lm_path,
-             const std::vector<std::string> &vocab_list);
+             const std::vector<std::string> &vocab_list,
+             const std::string &space_symbol);

   // load language model from given path
   void load_lm(const std::string &lm_path);
@@ -89,7 +91,7 @@
   void fill_dictionary(bool add_space);

   // set char map
-  void set_char_map(const std::vector<std::string> &char_list);
+  void set_char_map(const std::vector<std::string> &char_list, const std::string &space_symbol);

   double get_log_prob(const std::vector<std::string> &words);
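
For reference, a minimal usage sketch of the patched Python API (not part of the diff). The label set, tensor shapes, and the choice of "|" as the separator below are illustrative assumptions; for word-level LM scoring the value passed as space_symbol should also appear in labels, and a KenLM binary can be supplied via model_path to exercise the scorer path.

    import torch
    from ctcdecode import CTCBeamDecoder

    # "_" is the CTC blank (blank_id=0); "|" marks word boundaries instead of " "
    labels = ["_", "|", "a", "b", "c"]

    decoder = CTCBeamDecoder(
        labels,
        model_path=None,        # or a KenLM binary to enable the external scorer
        beam_width=100,
        blank_id=0,
        log_probs_input=False,
        space_symbol="|",       # new argument introduced by this patch
    )

    # decode() expects probabilities shaped batch x time x vocab
    probs = torch.softmax(torch.randn(1, 50, len(labels)), dim=2)
    output, scores, timesteps, out_seq_len = decoder.decode(probs)

    best = output[0][0][:out_seq_len[0][0]]
    print("".join(labels[int(i)] for i in best))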