🐛 Fix bug in evaluation with both mentions extractors and linkers (#34)
* 🐛 Fix bug in evaluation with both mentions extractors and linkers
* 🎨 Fix style

Signed-off-by: Marcos Martinez <[email protected]>
marmg authored Nov 15, 2022
1 parent 3ae7066 commit eebbbb1
Showing 2 changed files with 51 additions and 84 deletions.
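For context, the bug this commit fixes: evaluate() stored linker and mentions-extractor scores under the same "{dataset} {split}" key with two independent result.update(...) calls, so when a pipeline had both components the second call replaced the whole entry and the first evaluator's scores were silently lost. A minimal self-contained sketch of the failure and of the merge-based fix (illustrative names and numbers, not code from the repo):

result = {}
field_name = "ontonotes test"

# Old behaviour: each update() binds field_name to a *fresh* dict,
# so the mentions_extractor entry replaces the linker entry.
result.update({field_name: {"linker": {"overall_f1_macro": 0.71}}})
result.update({field_name: {"mentions_extractor": {"overall_f1_macro": 0.64}}})
assert "linker" not in result[field_name]  # linker scores are gone

# Fixed behaviour: merge into the existing per-field dict instead.
result = {}
for res_tmp in ({"linker": {"overall_f1_macro": 0.71}},
                {"mentions_extractor": {"overall_f1_macro": 0.64}}):
    if field_name not in result:
        result[field_name] = res_tmp
    else:
        result[field_name].update(res_tmp)
assert set(result[field_name]) == {"linker", "mentions_extractor"}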
zshot/evaluation/run_evaluation.py: 14 changes (3 additions, 11 deletions)
@@ -1,9 +1,7 @@
 import argparse
 
 import spacy
 
 from zshot import PipelineConfig
-from zshot.evaluation import load_medmentions, load_ontonotes
 from zshot.evaluation.metrics.seqeval.seqeval import Seqeval
 from zshot.evaluation.zshot_evaluate import evaluate
 from zshot.linker import LinkerTARS, LinkerSMXM, LinkerRegen
@@ -26,7 +24,8 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", default="ontonotes", type=str, help="Name or path to the validation data. Comma separated")
+    parser.add_argument("--dataset", default="ontonotes", type=str,
+                        help="Name or path to the validation data. Comma separated")
     parser.add_argument("--splits", required=False, default="train, test, validation", type=str,
                         help="Splits to evaluate. Comma separated")
     parser.add_argument("--mode", required=False, default="full", type=str,
@@ -62,8 +61,6 @@
                 )
             else:
                 configs[linker] = PipelineConfig(linker=LINKERS[linker]())
-        for mentions_extractor in mentions_extractors:
-            configs[mentions_extractor] = PipelineConfig(mentions_extractor=MENTION_EXTRACTORS[mentions_extractor]())
     elif args.mode == "mentions_extractor":
         for mentions_extractor in mentions_extractors:
             configs[mentions_extractor] = PipelineConfig(mentions_extractor=MENTION_EXTRACTORS[mentions_extractor]())
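A runnable sketch of the mode dispatch after this deletion. PipelineConfig and the registries are stubbed here for illustration; in the real script LINKERS maps names to LinkerTARS / LinkerSMXM / LinkerRegen and MENTION_EXTRACTORS to the mentions-extractor classes:

class PipelineConfig:  # stub standing in for zshot.PipelineConfig
    def __init__(self, linker=None, mentions_extractor=None):
        self.linker, self.mentions_extractor = linker, mentions_extractor

LINKERS = {"smxm": object}              # stub registry
MENTION_EXTRACTORS = {"spacy": object}  # stub registry

def build_configs(mode, linkers, mentions_extractors):
    configs = {}
    if mode == "linker":
        # after the fix, this branch builds linker configs only
        for linker in linkers:
            configs[linker] = PipelineConfig(linker=LINKERS[linker]())
    elif mode == "mentions_extractor":
        for me in mentions_extractors:
            configs[me] = PipelineConfig(mentions_extractor=MENTION_EXTRACTORS[me]())
    return configs

print(build_configs("linker", ["smxm"], ["spacy"]))  # only the 'smxm' config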
@@ -81,9 +78,4 @@
         nlp = spacy.blank("en") if "spacy" not in key else spacy.load("en_core_web_sm")
         nlp.add_pipe("zshot", config=config, last=True)
 
-        if args.dataset.lower() == "medmentions":
-            dataset = load_medmentions()
-        else:
-            dataset = load_ontonotes()
-
-        print(evaluate(nlp, dataset, splits=args.splits, metric=Seqeval()))
+        print(evaluate(nlp, args.dataset, splits=args.splits, metric=Seqeval()))
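Hedged usage sketch of the updated entry point: the dataset name string is now passed straight to evaluate(), which loads the dataset itself (see the second file below). Imports mirror the ones in this script; the choice of LinkerSMXM and the split are illustrative.

import spacy

from zshot import PipelineConfig
from zshot.evaluation.metrics.seqeval.seqeval import Seqeval
from zshot.evaluation.zshot_evaluate import evaluate
from zshot.linker import LinkerSMXM

nlp = spacy.blank("en")
nlp.add_pipe("zshot", config=PipelineConfig(linker=LinkerSMXM()), last=True)

# "medmentions" loads MedMentions; any other name falls back to OntoNotes.
print(evaluate(nlp, "ontonotes", splits=["test"], metric=Seqeval()))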
zshot/evaluation/zshot_evaluate.py: 121 changes (48 additions, 73 deletions)
@@ -1,25 +1,19 @@
 from typing import Optional, List, Union
 
 import spacy
-from datasets import Dataset, NamedSplit
 from evaluate import EvaluationModule
 from prettytable import PrettyTable
 
-from zshot.evaluation.evaluator import (
-    ZeroShotTokenClassificationEvaluator,
-    MentionsExtractorEvaluator,
-)
+from zshot.evaluation import load_medmentions, load_ontonotes
+from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
 from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
 
 
-def evaluate(
-    nlp: spacy.language.Language,
-    datasets: Union[Dataset, List[Dataset]],
-    splits: Optional[Union[NamedSplit, List[NamedSplit]]] = None,
-    metric: Optional[Union[str, EvaluationModule]] = None,
-    batch_size: Optional[int] = 16,
-) -> str:
-    """Evaluate a spacy zshot model
+def evaluate(nlp: spacy.language.Language,
+             datasets: Union[str, List[str]],
+             splits: Optional[Union[str, List[str]]] = None,
+             metric: Optional[Union[str, EvaluationModule]] = None,
+             batch_size: Optional[int] = 16) -> str:
+    """ Evaluate a spacy zshot model
     :param nlp: Spacy Language pipeline with ZShot components
     :param datasets: Dataset or list of datasets to evaluate
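Call-site impact of the new signature, as a short before/after sketch (names as in the diff):

# Before: callers loaded the dataset and passed Dataset objects.
#     dataset = load_ontonotes()
#     evaluate(nlp, dataset, splits=["test"])
# After: callers pass dataset names; str or List[str] both work,
# since a bare string is wrapped in a list (see the next hunk).
#     evaluate(nlp, "ontonotes", splits="test")
#     evaluate(nlp, ["ontonotes", "medmentions"], splits=["test", "validation"])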
@@ -32,42 +26,52 @@ def evaluate(
     linker_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
     mentions_extractor_evaluator = MentionsExtractorEvaluator("token-classification")
 
-    if not isinstance(splits, list):
+    if type(splits) == str:
         splits = [splits]
 
-    if not isinstance(datasets, list):
+    if type(datasets) == str:
         datasets = [datasets]
 
     result = {}
     field_names = ["Metric"]
-    for dataset in datasets:
+    for dataset_name in datasets:
+        if dataset_name.lower() == "medmentions":
+            dataset = load_medmentions()
+        else:
+            dataset = load_ontonotes()
+
         for split in splits:
-            field_name = f"{dataset.description} {split}"
+            field_name = f"{dataset_name} {split}"
             field_names.append(field_name)
             nlp.get_pipe("zshot").mentions = dataset[split].entities
             nlp.get_pipe("zshot").entities = dataset[split].entities
             if nlp.get_pipe("zshot").linker:
                 pipe = LinkerPipeline(nlp, batch_size)
-                result.update(
-                    {
-                        field_name: {
-                            "linker": linker_evaluator.compute(
-                                pipe, dataset[split], metric=metric
-                            )
-                        }
-                    }
-                )
+                res_tmp = {
+                    'linker': linker_evaluator.compute(pipe, dataset[split], metric=metric)
+                }
+                if field_name not in result:
+                    result.update(
+                        {
+                            field_name: res_tmp
+                        }
+                    )
+                else:
+                    result[field_name].update(res_tmp)
             if nlp.get_pipe("zshot").mentions_extractor:
                 pipe = MentionsExtractorPipeline(nlp, batch_size)
-                result.update(
-                    {
-                        field_name: {
-                            "mentions_extractor": mentions_extractor_evaluator.compute(
-                                pipe, dataset[split], metric=metric
-                            )
-                        }
-                    }
-                )
+                res_tmp = {
+                    'mentions_extractor': mentions_extractor_evaluator.compute(pipe, dataset[split],
+                                                                               metric=metric)
+                }
+                if field_name not in result:
+                    result.update(
+                        {
+                            field_name: res_tmp
+                        }
+                    )
+                else:
+                    result[field_name].update(res_tmp)
 
     table = PrettyTable()
     table.field_names = field_names
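With the merge logic above, a pipeline containing both a linker and a mentions extractor now produces a single entry per "{dataset_name} {split}" key holding both evaluator outputs. Illustrative shape (abbreviated keys, made-up numbers):

result = {
    "ontonotes test": {
        "linker": {"overall_f1_macro": 0.71, "overall_accuracy": 0.93},
        "mentions_extractor": {"overall_f1_macro": 0.64, "overall_accuracy": 0.90},
    }
}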
@@ -81,25 +85,11 @@ def evaluate(
     for field_name in field_names:
         if field_name == "Metric":
             continue
-        linker_precisions.append(
-            "{:.2f}%".format(
-                result[field_name]["linker"]["overall_precision_macro"] * 100
-            )
-        )
-        linker_recalls.append(
-            "{:.2f}%".format(
-                result[field_name]["linker"]["overall_recall_macro"] * 100
-            )
-        )
-        linker_accuracies.append(
-            "{:.2f}%".format(result[field_name]["linker"]["overall_accuracy"] * 100)
-        )
-        linker_micros.append(
-            "{:.2f}%".format(result[field_name]["linker"]["overall_f1_micro"] * 100)
-        )
-        linker_macros.append(
-            "{:.2f}%".format(result[field_name]["linker"]["overall_f1_macro"] * 100)
-        )
+        linker_precisions.append("{:.2f}%".format(result[field_name]['linker']['overall_precision_macro'] * 100))
+        linker_recalls.append("{:.2f}%".format(result[field_name]['linker']['overall_recall_macro'] * 100))
+        linker_accuracies.append("{:.2f}%".format(result[field_name]['linker']['overall_accuracy'] * 100))
+        linker_micros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_micro'] * 100))
+        linker_macros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_macro'] * 100))
 
     rows.append(["Linker Precision"] + linker_precisions)
     rows.append(["Linker Recall"] + linker_recalls)
@@ -117,30 +107,15 @@ def evaluate(
         if field_name == "Metric":
             continue
         mentions_extractor_precisions.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_precision_macro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_precision_macro'] * 100))
         mentions_extractor_recalls.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_recall_macro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_recall_macro'] * 100))
         mentions_extractor_accuracies.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_accuracy"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_accuracy'] * 100))
         mentions_extractor_micros.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_f1_micro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_micro'] * 100))
         mentions_extractor_macros.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_f1_macro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_macro'] * 100))
 
     rows.append(["Mentions extractor Precision"] + mentions_extractor_precisions)
     rows.append(["Mentions extractor Recall"] + mentions_extractor_recalls)
