diff --git a/src/text_to_multi_option_benchmark.py b/src/text_to_multi_option_benchmark.py index 59f1dc4..d80c6f1 100644 --- a/src/text_to_multi_option_benchmark.py +++ b/src/text_to_multi_option_benchmark.py @@ -20,7 +20,9 @@ TextToMultiOptionExtractor, ) -LABELED_DATA_PATH = join(Path(__file__).parent, "extractors", "text_to_multi_option_extractor", "labeled_data") +LABELED_DATA_PATH = join( + Path(__file__).parent, "trainable_entity_extractor", "extractors", "text_to_multi_option_extractor", "labeled_data" +) def get_extraction_data(filter_by: list[str] = None): @@ -38,7 +40,10 @@ def get_extraction_data(filter_by: list[str] = None): for i, text_value in enumerate(texts_values): values = [Option(id=x, label=x) for x in text_value["values"]] language_iso = "es" if "cejil" in task_name else "en" - labeled_data = LabeledData(values=values, entity_name=str(i), language_iso=language_iso) + labeled_data = LabeledData( + values=values, entity_name=str(i), language_iso=language_iso, source_text=text_value["text"] + ) + extraction_sample = TrainingSample(segment_selector_texts=[text_value["text"]], labeled_data=labeled_data) samples.append(extraction_sample) @@ -72,14 +77,15 @@ def get_benchmark(): for extraction_data in extractions_data: start = time() extractor = TextToMultiOptionExtractor(extraction_identifier=extraction_data.extraction_identifier) - train_set, test_set = ExtractorBase.get_train_test_sets(extraction_data, limit_samples=False) + train_set, test_set = ExtractorBase.get_train_test_sets(extraction_data) values_list = [x.labeled_data.values for x in test_set.samples] truth_one_hot = PdfMultiOptionMethod.one_hot_to_options_list(values_list, extraction_data.options) extractor.create_model(train_set) tags_texts = [x.segment_selector_texts for x in test_set.samples] test_data = [ - PredictionSample(segment_selector_texts=tag_text, entity_name=str(i)) for i, tag_text in enumerate(tags_texts) + PredictionSample(segment_selector_texts=tag_text, entity_name=str(i), source_text=" ".join(tag_text)) + for i, tag_text in enumerate(tags_texts) ] suggestions = extractor.get_suggestions(test_data) values_list = [x.values for x in suggestions] diff --git a/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt index fd32974..cab7ab8 100644 --- a/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt +++ b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt @@ -11,7 +11,7 @@ Extractor cejil_countries 5.5 94.12% Extractor cyrilla_keywords [TextSetFit] 4.0 68.29% Extractor cejil_judge 0.0 100.0% Extractor action [TextSingleLabelSetFit] 10.1 88.3% -Extractor issues [Setfit] 13.0 76.7% +Extractor issues [Setfit] 27.7 78.3% Extractor issues [BertBase] - 70.8% Extractor issues [BertLarge] 98.3 90.2% Extractor themes [Setfit] 5.3 66.1%