diff --git a/README.md b/README.md index e7a91929..7e17b0fb 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,16 @@ git clone --recursive https://github.com/parlance/ctcdecode.git cd ctcdecode && pip install . ``` +For faster installation use (replace `N` with the number of CPUs available): + +```bash +# get the code +git clone --recursive https://github.com/parlance/ctcdecode.git +cd ctcdecode +MAX_JOBS=N python3 setup.py build +python3 setup.py install +``` + ## How to Use ```python @@ -32,7 +42,8 @@ decoder = CTCBeamDecoder( beam_width=100, num_processes=4, blank_id=0, - log_probs_input=False + log_probs_input=False, + is_token_based=False ) beam_results, beam_scores, timesteps, out_lens = decoder.decode(output) ``` @@ -52,6 +63,7 @@ beam_results, beam_scores, timesteps, out_lens = decoder.decode(output) - `num_processes` Parallelize the batch using num_processes workers. You probably want to pass the number of cpus your computer has. You can find this in python with `import multiprocessing` then `n_cpus = multiprocessing.cpu_count()`. Default 4. - `blank_id` This should be the index of the CTC blank token (probably 0). - `log_probs_input` If your outputs have passed through a softmax and represent probabilities, this should be false, if they passed through a LogSoftmax and represent negative log likelihood, you need to pass True. If you don't understand this, run `print(output[0][0].sum())`, if it's a negative number you've probably got NLL and need to pass True, if it sums to ~1.0 you should pass False. Default False. + - `is_token_based` If you use an LM based on custom tokens (e.g., BPEs), set this to True. Default False. ### Inputs to the `decode` method - `output` should be the output activations from your model. 
If your output has passed through a SoftMax layer, you shouldn't need to alter it (except maybe to transpose), but if your `output` represents negative log likelihoods (raw logits), you either need to pass it through an additional `torch.nn.functional.softmax` or you can pass `log_probs_input=True` to the decoder. Your output should be BATCHSIZE x N_TIMESTEPS x N_LABELS so you may need to transpose it before passing it to the decoder. Note that if you pass things in the wrong order, the beam search will probably still run, you'll just get back nonsense results. @@ -79,7 +91,8 @@ decoder = OnlineCTCBeamDecoder( beam_width=100, num_processes=4, blank_id=0, - log_probs_input=False + log_probs_input=False, + is_token_based=False ) state1 = ctcdecode.DecoderState(decoder) diff --git a/ctcdecode/__init__.py b/ctcdecode/__init__.py index 1715b778..7d1f197a 100644 --- a/ctcdecode/__init__.py +++ b/ctcdecode/__init__.py @@ -21,6 +21,7 @@ class CTCBeamDecoder(object): num_processes (int): Parallelize the batch using num_processes workers. blank_id (int): Index of the CTC blank token (probably 0) used when training your model. log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1. + is_token_based (bool): True if you use tokens (e.g., BPEs). 
""" def __init__( @@ -35,6 +36,7 @@ def __init__( num_processes=4, blank_id=0, log_probs_input=False, + is_token_based=False, ): self.cutoff_top_n = cutoff_top_n self._beam_width = beam_width @@ -44,9 +46,10 @@ def __init__( self._num_labels = len(labels) self._blank_id = blank_id self._log_probs = 1 if log_probs_input else 0 + self._is_token_based = 1 if is_token_based else 0 if model_path: self._scorer = ctc_decode.paddle_get_scorer( - alpha, beta, model_path.encode(), self._labels, self._num_labels + alpha, beta, model_path.encode(), self._labels, self._num_labels, self._is_token_based ) self._cutoff_prob = cutoff_prob @@ -124,6 +127,9 @@ def decode(self, probs, seq_lens=None): def character_based(self): return ctc_decode.is_character_based(self._scorer) if self._scorer else None + + def token_based(self): + return ctc_decode.is_token_based(self._scorer) if self._scorer else None def max_order(self): return ctc_decode.get_max_order(self._scorer) if self._scorer else None @@ -158,6 +164,7 @@ class OnlineCTCBeamDecoder(object): num_processes (int): Parallelize the batch using num_processes workers. blank_id (int): Index of the CTC blank token (probably 0) used when training your model. log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1. + is_token_based (bool): True if you use tokens (e.g., BPEs). 
""" def __init__( self, @@ -171,6 +178,7 @@ def __init__( num_processes=4, blank_id=0, log_probs_input=False, + is_token_based=False, ): self._cutoff_top_n = cutoff_top_n self._beam_width = beam_width @@ -180,9 +188,10 @@ def __init__( self._num_labels = len(labels) self._blank_id = blank_id self._log_probs = 1 if log_probs_input else 0 + self._is_token_based = 1 if is_token_based else 0 if model_path: self._scorer = ctc_decode.paddle_get_scorer( - alpha, beta, model_path.encode(), self._labels, self._num_labels + alpha, beta, model_path.encode(), self._labels, self._num_labels, self._is_token_based ) self._cutoff_prob = cutoff_prob @@ -240,6 +249,9 @@ def decode(self, probs, states, is_eos_s, seq_lens=None): def character_based(self): return ctc_decode.is_character_based(self._scorer) if self._scorer else None + def token_based(self): + return ctc_decode.is_token_based(self._scorer) if self._scorer else None + def max_order(self): return ctc_decode.get_max_order(self._scorer) if self._scorer else None diff --git a/ctcdecode/src/binding.cpp b/ctcdecode/src/binding.cpp index c14a8221..7e772902 100644 --- a/ctcdecode/src/binding.cpp +++ b/ctcdecode/src/binding.cpp @@ -144,8 +144,9 @@ void* paddle_get_scorer(double alpha, double beta, const char* lm_path, vector new_vocab, - int vocab_size) { - Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab); + int vocab_size, + bool is_token_based) { + Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab, is_token_based); return static_cast(scorer); } @@ -272,6 +273,10 @@ int is_character_based(void *scorer){ Scorer *ext_scorer = static_cast(scorer); return ext_scorer->is_character_based(); } +int is_token_based(void *scorer){ + Scorer *ext_scorer = static_cast(scorer); + return ext_scorer->is_token_based(); +} size_t get_max_order(void *scorer){ Scorer *ext_scorer = static_cast(scorer); return ext_scorer->get_max_order(); @@ -293,6 +298,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("paddle_get_scorer", 
&paddle_get_scorer, "paddle_get_scorer"); m.def("paddle_release_scorer", &paddle_release_scorer, "paddle_release_scorer"); m.def("is_character_based", &is_character_based, "is_character_based"); + m.def("is_token_based", &is_token_based, "is_token_based"); m.def("get_max_order", &get_max_order, "get_max_order"); m.def("get_dict_size", &get_dict_size, "get_max_order"); m.def("reset_params", &reset_params, "reset_params"); diff --git a/ctcdecode/src/binding.h b/ctcdecode/src/binding.h index e6bd2eb1..59ce9162 100644 --- a/ctcdecode/src/binding.h +++ b/ctcdecode/src/binding.h @@ -34,7 +34,8 @@ void* paddle_get_scorer(double alpha, double beta, const char* lm_path, std::vector labels, - int vocab_size); + int vocab_size, + bool is_token_based); void* paddle_get_decoder_state(const std::vector &vocabulary, @@ -50,6 +51,7 @@ void paddle_release_state(void* state); int is_character_based(void *scorer); +int is_token_based(void *scorer); size_t get_max_order(void *scorer); size_t get_dict_size(void *scorer); void reset_params(void *scorer, double alpha, double beta); diff --git a/ctcdecode/src/ctc_beam_search_decoder.cpp b/ctcdecode/src/ctc_beam_search_decoder.cpp index 0d0365d3..a0b65ea8 100644 --- a/ctcdecode/src/ctc_beam_search_decoder.cpp +++ b/ctcdecode/src/ctc_beam_search_decoder.cpp @@ -43,7 +43,7 @@ DecoderState::DecoderState(const std::vector &vocabulary, root.score = root.log_prob_b_prev = 0.0; prefixes.push_back(&root); - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + if (ext_scorer != nullptr && !(ext_scorer->is_character_based() || ext_scorer->is_token_based())) { auto fst_dict = static_cast(ext_scorer->dictionary); fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); root.set_dictionary(dict_ptr); @@ -119,10 +119,10 @@ DecoderState::next(const std::vector> &probs_seq) // language model scoring if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { + (c == space_id || ext_scorer->is_character_based() || 
ext_scorer->is_token_based())) { PathTrie *prefix_to_score = nullptr; // skip scoring the space - if (ext_scorer->is_character_based()) { + if (ext_scorer->is_character_based() || ext_scorer->is_token_based()) { prefix_to_score = prefix_new; } else { prefix_to_score = prefix; @@ -171,7 +171,7 @@ DecoderState::decode() } // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + if (ext_scorer != nullptr && !(ext_scorer->is_character_based() || ext_scorer->is_token_based())) { for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) { auto prefix = prefixes_copy[i]; if (!prefix->is_empty() && prefix->character != space_id) { diff --git a/ctcdecode/src/scorer.cpp b/ctcdecode/src/scorer.cpp index c3550b3a..0136ad6b 100644 --- a/ctcdecode/src/scorer.cpp +++ b/ctcdecode/src/scorer.cpp @@ -16,12 +16,14 @@ using namespace lm::ngram; Scorer::Scorer(double alpha, double beta, const std::string& lm_path, - const std::vector& vocab_list) { + const std::vector& vocab_list, + bool is_token_based) { this->alpha = alpha; this->beta = beta; dictionary = nullptr; - is_character_based_ = true; + is_character_based_ = !is_token_based; + is_token_based_ = is_token_based; language_model_ = nullptr; max_order_ = 0; @@ -47,7 +49,7 @@ void Scorer::setup(const std::string& lm_path, // set char map for scorer set_char_map(vocab_list); // fill the dictionary for FST - if (!is_character_based()) { + if (!(is_character_based() || is_token_based())) { fill_dictionary(true); } } @@ -126,8 +128,10 @@ void Scorer::reset_params(float alpha, float beta) { std::string Scorer::vec2str(const std::vector& input) { std::string word; - for (auto ind : input) { - word += char_list_[ind]; + for (size_t i = 0; i < input.size(); ++i) { + word += char_list_[input[i]]; + if(is_token_based_ && i + 1 < input.size()) + word += " "; } return word; } @@ -135,12 +139,18 @@ std::string Scorer::vec2str(const std::vector& input) { 
std::vector Scorer::split_labels(const std::vector& labels) { if (labels.empty()) return {}; - std::string s = vec2str(labels); std::vector words; - if (is_character_based_) { - words = split_utf8_str(s); - } else { - words = split_str(s, " "); + if(is_token_based_) { + for (auto ind : labels) + words.push_back(char_list_[ind]); + } + else { + std::string s = vec2str(labels); + if (is_character_based_) { + words = split_utf8_str(s); + } else { + words = split_str(s, " "); + } } return words; } @@ -169,7 +179,7 @@ std::vector Scorer::make_ngram(PathTrie* prefix) { std::vector prefix_vec; std::vector prefix_steps; - if (is_character_based_) { + if (is_character_based_ || is_token_based_) { new_node = current_node->get_path_vec(prefix_vec, prefix_steps, -1, 1); current_node = new_node; } else { diff --git a/ctcdecode/src/scorer.h b/ctcdecode/src/scorer.h index 5ebc719c..8245bec3 100644 --- a/ctcdecode/src/scorer.h +++ b/ctcdecode/src/scorer.h @@ -43,7 +43,8 @@ class Scorer { Scorer(double alpha, double beta, const std::string &lm_path, - const std::vector &vocabulary); + const std::vector &vocabulary, + bool is_token_based); ~Scorer(); double get_log_cond_prob(const std::vector &words); @@ -58,6 +59,9 @@ class Scorer { // retrun true if the language model is character based bool is_character_based() const { return is_character_based_; } + + // return true if the language model is token based (e.g., BPE) + bool is_token_based() const { return is_token_based_; } // reset params alpha & beta void reset_params(float alpha, float beta); @@ -99,6 +103,7 @@ class Scorer { private: void *language_model_; bool is_character_based_; + bool is_token_based_; size_t max_order_; size_t dict_size_;