Add ents to doc
cjer committed Sep 4, 2021
1 parent d444776 commit f8894d2
Showing 2 changed files with 73 additions and 29 deletions.
95 changes: 70 additions & 25 deletions api_main.py
@@ -14,6 +14,7 @@
from io import StringIO
from operator import itemgetter
from itertools import groupby
import iobes

os.environ['CUDA_VISIBLE_DEVICES'] = ''

@@ -347,6 +348,71 @@ def add_dep_info(docs, md_sents, dep_tree, include_yap_outputs):
doc.dep_tree = dep


def iter_token_attrs(doc, attr):
    for token in doc:
        yield getattr(token, attr)


def iter_morph_attrs(doc, attr):
    for token in doc:
        for morph in token:
            yield getattr(morph, attr)

NEMO_FIELDS_TOKEN = ['nemo_single', 'nemo_multi_align_token', 'nemo_morph_align_token']
NEMO_FIELDS_MORPH = ['nemo_morph', 'nemo_multi_align_morph']

def to_dict(span, text):
    return {
        'text': ' '.join(text[span[1]:span[2]]),
        'label': span[0],
        'start': span[1],
        'end': span[2]
    }

def get_spans(doc, token_fields=None, morph_fields=None, add_full_text=False):
    spans = {}
    if morph_fields:
        try:
            morph_text = list(iter_morph_attrs(doc, 'form'))
        except KeyError:
            pass
        morph_spans = []
        for f in morph_fields:
            try:
                labels = list(iter_morph_attrs(doc, f))
                span = {
                    'scenario': f,
                    'ents': [to_dict(x, morph_text)
                             for x in iobes.parse_spans_iobes(labels)]
                }
                if add_full_text:
                    span['text'] = morph_text
                morph_spans.append(span)
            except KeyError:
                pass
        spans['morph'] = morph_spans
    if token_fields:
        tok_text = list(iter_token_attrs(doc, 'text'))
        tok_spans = []
        for f in token_fields:
            try:
                labels = list(iter_token_attrs(doc, f))
                span = {
                    'scenario': f,
                    'ents': [to_dict(x, tok_text)
                             for x in iobes.parse_spans_iobes(labels)]
                }
                if add_full_text:
                    span['text'] = tok_text
                tok_spans.append(span)
            except KeyError:
                pass
        spans['token'] = tok_spans

    return spans
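
# Illustrative sketch of the structure get_spans attaches to a doc; the scenario
# name and entity values below are hypothetical:
#
# {'token': [{'scenario': 'nemo_multi_align_token',
#             'ents': [{'text': 'בנק ישראל', 'label': 'ORG', 'start': 2, 'end': 4}]}]}
#
# The 'morph' key appears only when morph_fields is given. Span offsets are token
# (or morpheme) indices with an exclusive end, matching the text[start:end] slice
# in to_dict above.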


description = """
NEMO API helps you do awesome stuff with Hebrew named entities and morphology 🐠
@@ -388,17 +454,6 @@ def add_dep_info(docs, md_sents, dep_tree, include_yap_outputs):


#query objects for FastAPI documentation
# sent_query = Query( None,
# description="Hebrew sentences seprated by '\\n'",
# example="עשרות אנשים מגיעים מתאילנד לישראל.\nתופעה זו התבררה אתמול בוועדת העבודה והרווחה של הכנסת.",
# )


# tokenized_query = Query( False,
# description="Are sentences pre-tokenized? If so, we split each sentence by space char. Else, we use a built in tokenizer."
# )


multi_model_query = Query(MultiModelName.token_multi,
                          description="Name of an available token-multi model.",
                          )
@@ -489,7 +544,9 @@ def multi_to_single(
    doc_set_token_attr(docs, 'nemo_multi', ner_multi_preds)

    doc_add_multi_align_tok(docs, ner_multi_preds)


    for doc in docs:
        doc.ents = get_spans(doc, token_fields=['nemo_multi_align_token'], morph_fields=[])
    return docs


@@ -651,16 +708,4 @@ def morph_hybrid_align_tokens(q: NEMOQuery,
                              include_yap_outputs: Optional[bool]=False):
    return morph_hybrid(q, multi_model_name, morph_model_name, align_tokens=True,
                        verbose=verbose, include_yap_outputs=include_yap_outputs)


#
# @app.post("/run_separate_nemo/")
# def run_separate_nemo(command: str, model_name: str, sentence: str):
# if command in available_commands:
# if command == 'run_ner_model':
# with Temp('r', encoding='utf8') as temp_output:
# nemo.run_ner_model(model_name, None, temp_output.name, text_input=sentence)
# output_text = temp_output.read()
# return { 'nemo_output': output_text }
# else:
# return {'error': 'command not supported'}
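
The new ents payload is built by running iobes.parse_spans_iobes over the IOBES tag
sequences stored on tokens or morphemes and converting each span with to_dict. Below is
a minimal, self-contained sketch of that conversion; the sample tokens and labels are
made up, and only iobes.parse_spans_iobes and the to_dict logic come from the diff above.

import iobes

def to_dict(span, text):  # same conversion as in api_main.py above
    return {'text': ' '.join(text[span[1]:span[2]]),
            'label': span[0], 'start': span[1], 'end': span[2]}

tokens = ['דנה', 'עובדת', 'בבנק', 'ישראל']   # hypothetical tokenized sentence
labels = ['S-PER', 'O', 'B-ORG', 'E-ORG']      # hypothetical IOBES tags, one per token

ents = [to_dict(s, tokens) for s in iobes.parse_spans_iobes(labels)]
# ents is roughly (assuming exclusive span ends, as consumed by to_dict):
# [{'text': 'דנה', 'label': 'PER', 'start': 0, 'end': 1},
#  {'text': 'בבנק ישראל', 'label': 'ORG', 'start': 2, 'end': 4}]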

7 changes: 3 additions & 4 deletions schema.py
@@ -76,6 +76,7 @@ def __next__(self):

class Doc(BaseModel):
    text: Optional[str] = None
    ents: Optional[dict] = None
    tokens: List[Token]
    #morphs: Optional[List[Morpheme]] = [] # better? and add token_id to Morpheme
    ma_lattice: Optional[str] = None
@@ -89,13 +90,11 @@ def __iter__(self):
    def __next__(self):
        return self.tokens.__next__()

    @classmethod
    def iter_token_attrs(self, attr):
        for i, token in enumerate(self):
        for i, token in enumerate(self.tokens):
            yield i, getattr(token, attr)

    @classmethod
    def iter_morph_attrs(self, attr):
        for i, token in enumerate(self):
        for i, token in enumerate(self.tokens):
            for morph in token:
                yield i, getattr(morph, attr)
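
A short usage sketch of the corrected iterators, which are now plain instance methods
walking self.tokens and yielding (token_index, value) pairs. The wrapper functions below
are illustrative, not repo code; 'text' and 'form' are the attributes api_main.py reads.

from schema import Doc  # assuming schema.py is importable as schema

def token_texts(doc: Doc) -> dict:
    # token index -> token surface text
    return dict(doc.iter_token_attrs('text'))

def morph_forms(doc: Doc) -> list:
    # morpheme forms in document order; morphemes of the same token share an index
    return [form for _, form in doc.iter_morph_attrs('form')]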
