From e3b461e3f6d63f16c6ee8325d80d762b038c6437 Mon Sep 17 00:00:00 2001 From: Florian Borchert Date: Wed, 6 Mar 2024 20:50:16 +0100 Subject: [PATCH] predict_no_context (#31) --- pyproject.toml | 2 +- xmen/linkers/__init__.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3acae2a..7ef04d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "xmen" -version = "1.0.4" +version = "1.0.5" description = "An extensible toolkit for Cross-lingual (x) Medical Entity Normalization." license = "Apache-2.0" authors = ["Florian Borchert "] diff --git a/xmen/linkers/__init__.py b/xmen/linkers/__init__.py index db03aa2..74e0255 100644 --- a/xmen/linkers/__init__.py +++ b/xmen/linkers/__init__.py @@ -6,6 +6,7 @@ from xmen.log import logger from xmen.reranking import Reranker +from xmen.data import from_spans class EntityLinker(ABC): @@ -31,6 +32,42 @@ def get_logger(self): def predict(self, passages: list, entities: list) -> list: pass + def predict_no_context( + self, entities: str | list[str], label: str | list[str] = None, batch_size: int = None + ) -> list: + """ + Generates candidate concepts for the given entities (one or more) without any context. + + Args: + - entities (str | list[str]): The entity or entities for which to generate candidates. + - label (str | list[str]): The label or labels for the entities. If a single label is provided, it will be used for all entities. + - batch_size (int): The batch size to use for prediction. If None, the default batch size of the model will be used. + """ + is_str = False + if isinstance(entities, str): + is_str = True + entities = [entities] + assert label is None or isinstance(label, str) + label = [label] + elif label is None or isinstance(label, str): + label = [label] * len(entities) + assert len(entities) == len(label) + + spans = [] + sentences = [] + indices = [] + for e, l in zip(entities, label): + indices.append(len(sentences)) + spans.append([{"char_start_index": 0, "char_end_index": len(e), "label": l, "span": e}]) + sentences.append(e) + ds = from_spans(entities=spans, sentences=sentences) + result = self.predict_batch(ds, batch_size) + if is_str: + assert len(result["entities"]) == 1 + return result["entities"][0] + else: + return result["entities"] + class RerankedLinker(EntityLinker): def __init__(self, linker: EntityLinker, ranker: Reranker):