Skip to content

Commit

Permalink
Fix text to multi option benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-piles committed Nov 29, 2024
1 parent 944f843 commit 923931e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
14 changes: 10 additions & 4 deletions src/text_to_multi_option_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
TextToMultiOptionExtractor,
)

- LABELED_DATA_PATH = join(Path(__file__).parent, "extractors", "text_to_multi_option_extractor", "labeled_data")
+ LABELED_DATA_PATH = join(
+ Path(__file__).parent, "trainable_entity_extractor", "extractors", "text_to_multi_option_extractor", "labeled_data"
+ )


def get_extraction_data(filter_by: list[str] = None):
Expand All @@ -38,7 +40,10 @@ def get_extraction_data(filter_by: list[str] = None):
for i, text_value in enumerate(texts_values):
values = [Option(id=x, label=x) for x in text_value["values"]]
language_iso = "es" if "cejil" in task_name else "en"
- labeled_data = LabeledData(values=values, entity_name=str(i), language_iso=language_iso)
+ labeled_data = LabeledData(
+ values=values, entity_name=str(i), language_iso=language_iso, source_text=text_value["text"]
+ )

extraction_sample = TrainingSample(segment_selector_texts=[text_value["text"]], labeled_data=labeled_data)
samples.append(extraction_sample)

Expand Down Expand Up @@ -72,14 +77,15 @@ def get_benchmark():
for extraction_data in extractions_data:
start = time()
extractor = TextToMultiOptionExtractor(extraction_identifier=extraction_data.extraction_identifier)
- train_set, test_set = ExtractorBase.get_train_test_sets(extraction_data, limit_samples=False)
+ train_set, test_set = ExtractorBase.get_train_test_sets(extraction_data)
values_list = [x.labeled_data.values for x in test_set.samples]
truth_one_hot = PdfMultiOptionMethod.one_hot_to_options_list(values_list, extraction_data.options)
extractor.create_model(train_set)

tags_texts = [x.segment_selector_texts for x in test_set.samples]
test_data = [
- PredictionSample(segment_selector_texts=tag_text, entity_name=str(i)) for i, tag_text in enumerate(tags_texts)
+ PredictionSample(segment_selector_texts=tag_text, entity_name=str(i), source_text=" ".join(tag_text))
+ for i, tag_text in enumerate(tags_texts)
]
suggestions = extractor.get_suggestions(test_data)
values_list = [x.values for x in suggestions]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Extractor cejil_countries 5.5 94.12%
Extractor cyrilla_keywords [TextSetFit] 4.0 68.29%
Extractor cejil_judge 0.0 100.0%
Extractor action [TextSingleLabelSetFit] 10.1 88.3%
- Extractor issues [Setfit] 13.0 76.7%
+ Extractor issues [Setfit] 27.7 78.3%
Extractor issues [BertBase] - 70.8%
Extractor issues [BertLarge] 98.3 90.2%
Extractor themes [Setfit] 5.3 66.1%
Expand Down

0 comments on commit 923931e

Please sign in to comment.