diff --git a/mltb2/fasttext.py b/mltb2/fasttext.py index 8d2dda7..ee3c65b 100644 --- a/mltb2/fasttext.py +++ b/mltb2/fasttext.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Philip May +# Copyright (c) 2023-2024 Philip May # This software is distributed under the terms of the MIT license # which is available at https://opensource.org/licenses/MIT @@ -11,6 +11,7 @@ import os from dataclasses import dataclass, field +from typing import List, Optional import fasttext from fasttext.FastText import _FastText @@ -51,12 +52,15 @@ def get_model_path_and_download() -> str: return model_full_path - def __call__(self, text: str, num_lang: int = 10): + def __call__(self, text: str, num_lang: int = 10, always_detect_lang: Optional[List[str]] = None): """Identify languages of a given text. Args: text: the text for which the language is to be recognized - num_lang: number of returned languages + num_lang: number of returned language probabilities + always_detect_lang: A list of languages that should always be returned + even if not detected. If the language is not detected, the probability + is set to 0.0. Returns: A dict from language to probability. This dict contains no more than ``num_lang`` elements. @@ -76,4 +80,8 @@ def __call__(self, text: str, num_lang: int = 10): languages = predictions[0] probabilities = predictions[1] lang_to_prob = {lang[9:]: prob for lang, prob in zip(languages, probabilities)} + if always_detect_lang is not None: + for lang in always_detect_lang: + if lang not in lang_to_prob: + lang_to_prob[lang] = 0.0 return lang_to_prob diff --git a/tests/test_fasttext.py b/tests/test_fasttext.py index 9a1c2f8..30fc787 100644 --- a/tests/test_fasttext.py +++ b/tests/test_fasttext.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Philip May +# Copyright (c) 2023-2024 Philip May # This software is distributed under the terms of the MIT license # which is available at https://opensource.org/licenses/MIT @@ -16,3 +16,11 @@ def test_fasttext_language_identification_call(): languages = language_identification("This is an English sentence.") assert languages is not None assert len(languages) == 10 + + +def test_fasttext_language_identification_call_with_always_detect_lang(): + language_identification = FastTextLanguageIdentification() + languages = language_identification("This is an English sentence.", always_detect_lang=["fake_language"]) + assert languages is not None + assert len(languages) == 11 + assert "fake_language" in languages