diff --git a/src/text_to_multi_option_benchmark.py b/src/text_to_multi_option_benchmark.py index d80c6f1..a0e4657 100644 --- a/src/text_to_multi_option_benchmark.py +++ b/src/text_to_multi_option_benchmark.py @@ -5,7 +5,9 @@ from time import time import rich +from fastfit import FastFit from sklearn.metrics import f1_score +from transformers import AutoTokenizer, pipeline from trainable_entity_extractor.data.ExtractionData import ExtractionData from trainable_entity_extractor.data.ExtractionIdentifier import ExtractionIdentifier @@ -73,7 +75,7 @@ def get_benchmark(): # action # themes # issues - extractions_data: list[ExtractionData] = get_extraction_data(filter_by=["issues"]) + extractions_data: list[ExtractionData] = get_extraction_data(filter_by=["action"]) for extraction_data in extractions_data: start = time() extractor = TextToMultiOptionExtractor(extraction_identifier=extraction_data.extraction_identifier) @@ -131,3 +133,7 @@ def check_results(): if __name__ == "__main__": get_benchmark() # check_results() + # model = FastFit.from_pretrained("/home/gabo/ssd/projects/trainable-entity-extractor/models_data/text_benchmark/action/TextFastFit/fast_fit_model") + # tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2") + # classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, trust_remote_code=True) + # print('oh!') diff --git a/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/PdfToTextExtractor.py b/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/PdfToTextExtractor.py index 0c64e8d..4e28fc3 100644 --- a/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/PdfToTextExtractor.py +++ b/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/PdfToTextExtractor.py @@ -11,6 +11,9 @@ from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextFastSegmentSelector import ( PdfToTextFastSegmentSelector, ) +from 
trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextNearFastSegmentSelector import ( + PdfToTextNearFastSegmentSelector, +) from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextRegexMethod import PdfToTextRegexMethod from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextSegmentSelector import ( PdfToTextSegmentSelector, @@ -42,15 +45,22 @@ class PdfToTextExtractor(ToTextExtractor): fast_segment_selector_methods = [ pdf_to_text_method_builder(PdfToTextFastSegmentSelector, x) for x in text_to_text_methods ] + + near_fast_segment_selector_methods = [ + pdf_to_text_method_builder(PdfToTextNearFastSegmentSelector, x) for x in text_to_text_methods + ] segment_selector_methods = [pdf_to_text_method_builder(PdfToTextSegmentSelector, x) for x in text_to_text_methods] t5_methods = [ pdf_to_text_method_builder(PdfToTextFastSegmentSelector, MT5TrueCaseEnglishSpanishMethod), pdf_to_text_method_builder(PdfToTextSegmentSelector, MT5TrueCaseEnglishSpanishMethod), ] - METHODS: list[type[ToTextExtractorMethod]] = ( - stand_alone_methods + fast_segment_selector_methods + segment_selector_methods + t5_methods - ) + METHODS: list[type[ToTextExtractorMethod]] = list() + METHODS += stand_alone_methods + METHODS += fast_segment_selector_methods + METHODS += near_fast_segment_selector_methods + METHODS += segment_selector_methods + METHODS += t5_methods def create_model(self, extraction_data: ExtractionData) -> tuple[bool, str]: if not extraction_data or not extraction_data.samples: @@ -65,32 +75,34 @@ def create_model(self, extraction_data: ExtractionData) -> tuple[bool, str]: def get_train_test_sets(extraction_data: ExtractionData) -> (ExtractionData, ExtractionData): samples_with_label_segments_boxes = [x for x in extraction_data.samples if x.labeled_data.label_segments_boxes] - if len(samples_with_label_segments_boxes) < 2 and len(extraction_data.samples) > 10: + if len(extraction_data.samples) <= 10: + 
return extraction_data, extraction_data + + if len(samples_with_label_segments_boxes) < 2: return PdfToTextExtractor.split_80_20(extraction_data) if len(samples_with_label_segments_boxes) < 10: - test_extraction_data = ExtractorBase.get_extraction_data_from_samples( + train_extraction_data = ExtractorBase.get_extraction_data_from_samples( extraction_data, samples_with_label_segments_boxes ) - return extraction_data, test_extraction_data + return train_extraction_data, extraction_data samples_without_label_segments_boxes = [ x for x in extraction_data.samples if not x.labeled_data.label_segments_boxes ] - train_size = int(len(samples_with_label_segments_boxes) * 0.8) - train_set: list[TrainingSample] = ( - samples_with_label_segments_boxes[:train_size] + samples_without_label_segments_boxes - ) + train_size = int(len(extraction_data.samples) * 0.7) - if len(extraction_data.samples) < 15: - test_set: list[TrainingSample] = samples_with_label_segments_boxes[-10:] - else: - test_set = samples_with_label_segments_boxes[train_size:] + if len(samples_with_label_segments_boxes) >= train_size: + train_extraction_data = ExtractorBase.get_extraction_data_from_samples( + extraction_data, samples_with_label_segments_boxes[:train_size] + ) + test_extraction_data = ExtractorBase.get_extraction_data_from_samples( + extraction_data, samples_with_label_segments_boxes[train_size:] + samples_without_label_segments_boxes + ) + return train_extraction_data, test_extraction_data - train_extraction_data = ExtractorBase.get_extraction_data_from_samples(extraction_data, train_set) - test_extraction_data = ExtractorBase.get_extraction_data_from_samples(extraction_data, test_set) - return train_extraction_data, test_extraction_data + return PdfToTextExtractor.split_80_20(extraction_data) @staticmethod def split_80_20(extraction_data): diff --git a/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/methods/PdfToTextNearFastSegmentSelector.py 
b/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/methods/PdfToTextNearFastSegmentSelector.py new file mode 100644 index 0000000..ddcaa27 --- /dev/null +++ b/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/methods/PdfToTextNearFastSegmentSelector.py @@ -0,0 +1,41 @@ +from trainable_entity_extractor.data.PdfDataSegment import PdfDataSegment +from trainable_entity_extractor.data.PredictionSample import PredictionSample +from trainable_entity_extractor.extractors.ToTextExtractorMethod import ToTextExtractorMethod +from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextFastSegmentSelector import ( + PdfToTextFastSegmentSelector, +) +from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextSegmentSelector import ( + PdfToTextSegmentSelector, +) + +from trainable_entity_extractor.extractors.segment_selector.FastAndPositionsSegmentSelector import ( + FastAndPositionsSegmentSelector, +) +from trainable_entity_extractor.extractors.segment_selector.NearFastSegmentSelector import NearFastSegmentSelector + + +class PdfToTextNearFastSegmentSelector(PdfToTextFastSegmentSelector): + + def create_segment_selector_model(self, extraction_data): + segments = list() + + for sample in extraction_data.samples: + segments.extend(sample.pdf_data.pdf_data_segments) + + fast_segment_selector = NearFastSegmentSelector(self.extraction_identifier) + fast_segment_selector.create_model(segments=segments) + return True, "" + + def predict(self, predictions_samples: list[PredictionSample]) -> list[str]: + if not predictions_samples: + return [""] * len(predictions_samples) + + fast_segment_selector = NearFastSegmentSelector(self.extraction_identifier) + + for sample in predictions_samples: + selected_segments = fast_segment_selector.predict(sample.pdf_data.pdf_data_segments) + self.mark_predicted_segments(selected_segments) + sample.segment_selector_texts = self.get_predicted_texts(sample.pdf_data) + + 
semantic_metadata_extraction = self.SEMANTIC_METHOD(self.extraction_identifier, self.get_name()) + return semantic_metadata_extraction.predict(predictions_samples) diff --git a/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/test/test_pdf_to_text_extractor.py b/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/test/test_pdf_to_text_extractor.py index 7f912c5..6680a12 100644 --- a/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/test/test_pdf_to_text_extractor.py +++ b/src/trainable_entity_extractor/extractors/pdf_to_text_extractor/test/test_pdf_to_text_extractor.py @@ -42,8 +42,8 @@ def test_get_train_test_with_few_samples(self): self.assertEqual(train_set.extraction_identifier, self.extraction_identifier) self.assertEqual(test_set.extraction_identifier, self.extraction_identifier) - self.assertEqual(len(train_set.samples), 9) - self.assertEqual(len(test_set.samples), 4) + self.assertEqual(9, len(train_set.samples)) + self.assertEqual(9, len(test_set.samples)) def test_get_train_test_without_enough_labeled_segments(self): pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier) @@ -55,8 +55,8 @@ def test_get_train_test_without_enough_labeled_segments(self): train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data) - self.assertEqual(len(train_set.samples), 20) - self.assertEqual(len(test_set.samples), 9) + self.assertEqual(9, len(train_set.samples)) + self.assertEqual(20, len(test_set.samples)) def test_get_train_test_without_labeled_segments(self): pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier) @@ -68,8 +68,8 @@ def test_get_train_test_without_labeled_segments(self): train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data) - self.assertEqual(len(train_set.samples), 80) - self.assertEqual(len(test_set.samples), 20) + self.assertEqual(80, len(train_set.samples)) + self.assertEqual(20, len(test_set.samples)) def 
test_get_train_test_without_enough_data(self): pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier) @@ -81,8 +81,8 @@ def test_get_train_test_without_enough_data(self): train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data) - self.assertEqual(len(train_set.samples), 180) - self.assertEqual(len(test_set.samples), 20) + self.assertEqual(160, len(train_set.samples)) + self.assertEqual(40, len(test_set.samples)) def test_get_train_test_only_labels_with_segments(self): pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier) @@ -94,5 +94,5 @@ def test_get_train_test_only_labels_with_segments(self): train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data) - self.assertEqual(len(train_set.samples), 160) - self.assertEqual(len(test_set.samples), 40) + self.assertEqual(140, len(train_set.samples)) + self.assertEqual(60, len(test_set.samples)) diff --git a/src/trainable_entity_extractor/extractors/segment_selector/FastSegmentSelector.py b/src/trainable_entity_extractor/extractors/segment_selector/FastSegmentSelector.py index 4c2106b..d5eb91b 100644 --- a/src/trainable_entity_extractor/extractors/segment_selector/FastSegmentSelector.py +++ b/src/trainable_entity_extractor/extractors/segment_selector/FastSegmentSelector.py @@ -152,7 +152,11 @@ def predict(self, segments): model = lgb.Booster(model_file=self.model_path) predictions = model.predict(x) - return [segment for i, segment in enumerate(segments) if predictions[i] > 0.5] + return self.predictions_scores_to_segments(segments, predictions) + + @staticmethod + def predictions_scores_to_segments(segments: list[PdfDataSegment], prediction_scores: list[float]): + return [segment for i, segment in enumerate(segments) if prediction_scores[i] > 0.5] def load_repeated_words(self): self.previous_words = [] diff --git a/src/trainable_entity_extractor/extractors/segment_selector/NearFastSegmentSelector.py 
b/src/trainable_entity_extractor/extractors/segment_selector/NearFastSegmentSelector.py new file mode 100644 index 0000000..0ec1eac --- /dev/null +++ b/src/trainable_entity_extractor/extractors/segment_selector/NearFastSegmentSelector.py @@ -0,0 +1,23 @@ +from trainable_entity_extractor.data.PdfDataSegment import PdfDataSegment +from trainable_entity_extractor.extractors.segment_selector.FastAndPositionsSegmentSelector import ( + FastAndPositionsSegmentSelector, +) + + +class NearFastSegmentSelector(FastAndPositionsSegmentSelector): + @staticmethod + def predictions_scores_to_segments(segments: list[PdfDataSegment], prediction_scores: list[float]): + predicted_segments = [] + for i, (segment, prediction) in enumerate(zip(segments, prediction_scores)): + if prediction > 0.5: + predicted_segments.append(segment) + continue + + if i + 1 < len(prediction_scores) and prediction_scores[i + 1] > 0.5: + predicted_segments.append(segment) + continue + + if i != 0 and prediction_scores[i - 1] > 0.5: + predicted_segments.append(segment) + + return predicted_segments diff --git a/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/TextToMultiOptionExtractor.py b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/TextToMultiOptionExtractor.py index 8dc9763..2bd96f0 100644 --- a/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/TextToMultiOptionExtractor.py +++ b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/TextToMultiOptionExtractor.py @@ -1,7 +1,4 @@ -import json import os -from os.path import join, exists -from pathlib import Path from typing import Type from trainable_entity_extractor.config import config_logger @@ -41,7 +38,6 @@ from trainable_entity_extractor.extractors.text_to_multi_option_extractor.methods.TextSingleLabelSetFitMultilingual import ( TextSingleLabelSetFitMultilingual, ) -from trainable_entity_extractor.extractors.text_to_multi_option_extractor.methods.TextTfIdf
import TextTfIdf class TextToMultiOptionExtractor(ExtractorBase): @@ -139,6 +135,7 @@ def get_performance(extraction_data, method_instance): try: performance = method_instance.performance(extraction_data) except: + config_logger.info("ERROR", exc_info=True) performance = 0 config_logger.info(f"\nPerformance {method_instance.get_name()}: {performance}%") return performance diff --git a/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/methods/TextFastFit.py b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/methods/TextFastFit.py new file mode 100644 index 0000000..7e65d03 --- /dev/null +++ b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/methods/TextFastFit.py @@ -0,0 +1,126 @@ +import os +import shutil +from os.path import join, exists + +import pandas as pd +from datasets import load_dataset, DatasetDict +from fastfit import FastFitTrainer, FastFit, sample_dataset +from transformers import AutoTokenizer, pipeline + +from trainable_entity_extractor.data.ExtractionData import ExtractionData +from trainable_entity_extractor.data.Option import Option +from setfit import SetFitModel, TrainingArguments, Trainer + +from trainable_entity_extractor.data.PredictionSample import PredictionSample +from trainable_entity_extractor.extractors.ExtractorBase import ExtractorBase +from trainable_entity_extractor.extractors.bert_method_scripts.AvoidAllEvaluation import AvoidAllEvaluation +from trainable_entity_extractor.extractors.bert_method_scripts.EarlyStoppingAfterInitialTraining import ( + EarlyStoppingAfterInitialTraining, +) +from trainable_entity_extractor.extractors.bert_method_scripts.get_batch_size import get_batch_size, get_max_steps +from trainable_entity_extractor.extractors.text_to_multi_option_extractor.TextToMultiOptionMethod import ( + TextToMultiOptionMethod, +) + + +class TextFastFit(TextToMultiOptionMethod): + + model_name = "sentence-transformers/paraphrase-mpnet-base-v2" + + def 
can_be_used(self, extraction_data: ExtractionData) -> bool: + if extraction_data.multi_value: + return False + + return True + + def get_data_path(self): + model_folder_path = join(self.extraction_identifier.get_path(), self.get_name()) + + if not exists(model_folder_path): + os.makedirs(model_folder_path) + + return join(model_folder_path, "data.csv") + + def get_model_path(self): + model_folder_path = join(self.extraction_identifier.get_path(), self.get_name()) + + if not exists(model_folder_path): + os.makedirs(model_folder_path) + + model_path = join(model_folder_path, "fast_fit_model") + + os.makedirs(model_path, exist_ok=True) + + return str(model_path) + + @staticmethod + def eval_encodings(example): + example["label"] = eval(example["label"]) + return example + + def get_dataset_from_data(self, extraction_data: ExtractionData): + data = list() + texts = [self.get_text(sample.labeled_data.source_text) for sample in extraction_data.samples] + labels = list() + + for sample in extraction_data.samples: + labels.append("no_label") + if sample.labeled_data.values: + labels[-1] = self.options[self.options.index(sample.labeled_data.values[0])].label + + for text, label in zip(texts[:10000], labels[:10000]): + data.append([text, label]) + + df = pd.DataFrame(data) + df.columns = ["text", "label"] + + df.to_csv(self.get_data_path()) + dataset_csv = load_dataset("csv", data_files=self.get_data_path()) + dataset = dataset_csv["train"] + + return dataset + + def train(self, extraction_data: ExtractionData): + shutil.rmtree(self.get_model_path(), ignore_errors=True) + + train_dataset = self.get_dataset_from_data(extraction_data) + dataset_dict = DatasetDict() + dataset_dict["train"] = train_dataset + dataset_dict["validation"] = train_dataset + dataset_dict["test"] = train_dataset + trainer = FastFitTrainer( + model_name_or_path="sentence-transformers/paraphrase-mpnet-base-v2", + label_column_name="label", + text_column_name="text", + output_dir=self.get_model_path(), + 
max_steps=get_max_steps(len(extraction_data.samples)), + evaluation_strategy="steps", + save_strategy="steps", + eval_steps=20000, + save_steps=20000, + load_best_model_at_end=True, + max_text_length=128, + dataloader_drop_last=False, + num_repeats=4, + optim="adafactor", + clf_loss_factor=0.1, + fp16=True, + dataset=dataset_dict, + ) + + model = trainer.train() + model.save_pretrained(self.get_model_path()) + + def predict(self, predictions_samples: list[PredictionSample]) -> list[list[Option]]: + model = FastFit.from_pretrained(self.get_model_path()) + tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2") + classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) + + texts = [self.get_text(sample.source_text) for sample in predictions_samples] + predictions = list() + for text in texts: + prediction = classifier(text) + prediction_labels = [x["label"] for x in prediction if x["score"] > 0.5] + predictions.append([x for x in self.options if x.label in prediction_labels]) + + return predictions diff --git a/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt index cab7ab8..9986298 100644 --- a/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt +++ b/src/trainable_entity_extractor/extractors/text_to_multi_option_extractor/results/results.txt @@ -11,6 +11,7 @@ Extractor cejil_countries 5.5 94.12% Extractor cyrilla_keywords [TextSetFit] 4.0 68.29% Extractor cejil_judge 0.0 100.0% Extractor action [TextSingleLabelSetFit] 10.1 88.3% +Extractor action [FastFit] 174.0 90.5% Extractor issues [Setfit] 27.7 78.3% Extractor issues [BertBase] - 70.8% Extractor issues [BertLarge] 98.3 90.2%