Skip to content

Commit

Permalink
handle empty sentences input
Browse files: browse the repository at this point in the history
  • Loading branch information
cjer committed Aug 29, 2021
1 parent 5b92c5c commit 60c69f1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
12 changes: 11 additions & 1 deletion api_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def ncrf_decode(model, data, temp_input):
data.raw_dir = temp_input
#data.decode_dir = temp_output
data.generate_instance('raw')
_, _, _, _, _, preds, _ = evaluate(data, model, 'raw', data.nbest)
_, _, _, _, _, preds, _ = evaluate(data, model, 'raw', data.nbest, calc_fmeasure=False)
if data.nbest==1:
preds = [sent[0] for sent in preds]
return preds
Expand Down Expand Up @@ -329,6 +329,8 @@ def load_all_models():
def run_ner_model(q: NEMOQuery,
model_name: Optional[ModelName]=ModelName.token_single,
):
if not q.sentences.strip():
return []
model = loaded_models[model_name]
temp_input = temporary_filename()
tok_sents = create_input_file(q.sentences, temp_input, q.tokenized)
Expand All @@ -346,6 +348,8 @@ def run_ner_model(q: NEMOQuery,
def multi_to_single(q: NEMOQuery,
multi_model_name: Optional[MultiModelName]=multi_model_query,
):
if not q.sentences.strip():
return []
model_out = run_ner_model(q, multi_model_name)
tok_sents, ner_multi_preds = zip(*[(x.tokenized_text, x.ncrf_preds) for x in model_out])
ner_single_preds = [[fix_multi_biose(label) for label in sent] for sent in ner_multi_preds]
Expand All @@ -366,6 +370,8 @@ def multi_to_single(q: NEMOQuery,
def multi_align_hybrid(q: NEMOQuery,
multi_model_name: Optional[MultiModelName]=multi_model_query,
include_dep_tree: Optional[bool]=False):
if not q.sentences.strip():
return []
model_out = run_ner_model(q, multi_model_name)
tok_sents, ner_multi_preds = zip(*[(x.tokenized_text, x.ncrf_preds) for x in model_out])
ner_single_preds = [[fix_multi_biose(label) for label in sent] for sent in ner_multi_preds]
Expand Down Expand Up @@ -404,6 +410,8 @@ def multi_align_hybrid(q: NEMOQuery,
def morph_yap(q: NEMOQuery,
morph_model_name: Optional[MorphModelName]=morph_model_query,
):
if not q.sentences.strip():
return []
tok_sents = get_sents(q.sentences, q.tokenized)
yap_out = run_yap_joint(tok_sents)
md_sents = (bclm.get_sentences_list(bclm.read_lattices(StringIO(yap_out['md_lattice'])), ['form']).apply(lambda x: [t[0] for t in x] )).to_list()
Expand Down Expand Up @@ -437,6 +445,8 @@ def morph_hybrid(q: NEMOQuery,
morph_model_name: Optional[MorphModelName]=morph_model_query,
align_tokens: Optional[bool] = False,
include_dep_tree: Optional[bool]=False):
if not q.sentences.strip():
return []
model_out = run_ner_model(q, multi_model_name)
tok_sents, ner_multi_preds = zip(*[(x.tokenized_text, x.ncrf_preds) for x in model_out])
ner_single_preds = [[fix_multi_biose(label) for label in sent] for sent in ner_multi_preds]
Expand Down
12 changes: 9 additions & 3 deletions ncrf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def lr_decay(optimizer, epoch, decay_rate, init_lr):



def evaluate(data, model, name, nbest=None):
def evaluate(data, model, name, nbest=None, calc_fmeasure=True):
if name == "train":
instances = data.train_Ids
elif name == "dev":
Expand Down Expand Up @@ -179,8 +179,14 @@ def evaluate(data, model, name, nbest=None):
pred_results += pred_label
gold_results += gold_label
decode_time = time.time() - start_time
speed = len(instances)/decode_time
acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, data.tagScheme, verbose=False)
if decode_time:
speed = len(instances)/decode_time
else:
speed = -1
if calc_fmeasure:
acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, data.tagScheme, verbose=False)
else:
acc, p, r, f = (1, 1, 1, 1)
if nbest and not data.sentence_classification:
return speed, acc, p, r, f, nbest_pred_results, pred_scores
return speed, acc, p, r, f, pred_results, pred_scores
Expand Down

0 comments on commit 60c69f1

Please sign in to comment.