Skip to content

Commit

Permalink
Add fast fit and near fast segment selector
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-piles committed Dec 2, 2024
1 parent 923931e commit ed439b0
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 33 deletions.
8 changes: 7 additions & 1 deletion src/text_to_multi_option_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from time import time

import rich
from fastfit import FastFit
from sklearn.metrics import f1_score
from transformers import AutoTokenizer, pipeline

from trainable_entity_extractor.data.ExtractionData import ExtractionData
from trainable_entity_extractor.data.ExtractionIdentifier import ExtractionIdentifier
Expand Down Expand Up @@ -73,7 +75,7 @@ def get_benchmark():
# action
# themes
# issues
extractions_data: list[ExtractionData] = get_extraction_data(filter_by=["issues"])
extractions_data: list[ExtractionData] = get_extraction_data(filter_by=["action"])
for extraction_data in extractions_data:
start = time()
extractor = TextToMultiOptionExtractor(extraction_identifier=extraction_data.extraction_identifier)
Expand Down Expand Up @@ -131,3 +133,7 @@ def check_results():
if __name__ == "__main__":
get_benchmark()
# check_results()
# model = FastFit.from_pretrained("/home/gabo/ssd/projects/trainable-entity-extractor/models_data/text_benchmark/action/TextFastFit/fast_fit_model")
# tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
# classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, trust_remote_code=True)
# print('oh!')
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextFastSegmentSelector import (
PdfToTextFastSegmentSelector,
)
from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextNearFastSegmentSelector import (
PdfToTextNearFastSegmentSelector,
)
from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextRegexMethod import PdfToTextRegexMethod
from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextSegmentSelector import (
PdfToTextSegmentSelector,
Expand Down Expand Up @@ -42,15 +45,22 @@ class PdfToTextExtractor(ToTextExtractor):
fast_segment_selector_methods = [
pdf_to_text_method_builder(PdfToTextFastSegmentSelector, x) for x in text_to_text_methods
]

near_fast_segment_selector_methods = [
pdf_to_text_method_builder(PdfToTextNearFastSegmentSelector, x) for x in text_to_text_methods
]
segment_selector_methods = [pdf_to_text_method_builder(PdfToTextSegmentSelector, x) for x in text_to_text_methods]
t5_methods = [
pdf_to_text_method_builder(PdfToTextFastSegmentSelector, MT5TrueCaseEnglishSpanishMethod),
pdf_to_text_method_builder(PdfToTextSegmentSelector, MT5TrueCaseEnglishSpanishMethod),
]

METHODS: list[type[ToTextExtractorMethod]] = (
stand_alone_methods + fast_segment_selector_methods + segment_selector_methods + t5_methods
)
METHODS: list[type[ToTextExtractorMethod]] = list()
METHODS += stand_alone_methods
METHODS += fast_segment_selector_methods
METHODS += near_fast_segment_selector_methods
METHODS += segment_selector_methods
METHODS += t5_methods

def create_model(self, extraction_data: ExtractionData) -> tuple[bool, str]:
if not extraction_data or not extraction_data.samples:
Expand All @@ -65,32 +75,34 @@ def create_model(self, extraction_data: ExtractionData) -> tuple[bool, str]:
def get_train_test_sets(extraction_data: ExtractionData) -> (ExtractionData, ExtractionData):
samples_with_label_segments_boxes = [x for x in extraction_data.samples if x.labeled_data.label_segments_boxes]

if len(samples_with_label_segments_boxes) < 2 and len(extraction_data.samples) > 10:
if len(extraction_data.samples) <= 10:
return extraction_data, extraction_data

if len(samples_with_label_segments_boxes) < 2:
return PdfToTextExtractor.split_80_20(extraction_data)

if len(samples_with_label_segments_boxes) < 10:
test_extraction_data = ExtractorBase.get_extraction_data_from_samples(
train_extraction_data = ExtractorBase.get_extraction_data_from_samples(
extraction_data, samples_with_label_segments_boxes
)
return extraction_data, test_extraction_data
return train_extraction_data, extraction_data

samples_without_label_segments_boxes = [
x for x in extraction_data.samples if not x.labeled_data.label_segments_boxes
]

train_size = int(len(samples_with_label_segments_boxes) * 0.8)
train_set: list[TrainingSample] = (
samples_with_label_segments_boxes[:train_size] + samples_without_label_segments_boxes
)
train_size = int(len(extraction_data.samples) * 0.7)

if len(extraction_data.samples) < 15:
test_set: list[TrainingSample] = samples_with_label_segments_boxes[-10:]
else:
test_set = samples_with_label_segments_boxes[train_size:]
if len(samples_with_label_segments_boxes) >= train_size:
train_extraction_data = ExtractorBase.get_extraction_data_from_samples(
extraction_data, samples_with_label_segments_boxes[:train_size]
)
test_extraction_data = ExtractorBase.get_extraction_data_from_samples(
extraction_data, samples_with_label_segments_boxes[train_size:] + samples_without_label_segments_boxes
)
return train_extraction_data, test_extraction_data

train_extraction_data = ExtractorBase.get_extraction_data_from_samples(extraction_data, train_set)
test_extraction_data = ExtractorBase.get_extraction_data_from_samples(extraction_data, test_set)
return train_extraction_data, test_extraction_data
return PdfToTextExtractor.split_80_20(extraction_data)

@staticmethod
def split_80_20(extraction_data):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from trainable_entity_extractor.data.PdfDataSegment import PdfDataSegment
from trainable_entity_extractor.data.PredictionSample import PredictionSample
from trainable_entity_extractor.extractors.ToTextExtractorMethod import ToTextExtractorMethod
from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextFastSegmentSelector import (
PdfToTextFastSegmentSelector,
)
from trainable_entity_extractor.extractors.pdf_to_text_extractor.methods.PdfToTextSegmentSelector import (
PdfToTextSegmentSelector,
)

from trainable_entity_extractor.extractors.segment_selector.FastAndPositionsSegmentSelector import (
FastAndPositionsSegmentSelector,
)
from trainable_entity_extractor.extractors.segment_selector.NearFastSegmentSelector import NearFastSegmentSelector


class PdfToTextNearFastSegmentSelector(PdfToTextFastSegmentSelector):
    """PDF-to-text method whose segment selector also keeps the neighbours of
    confidently selected segments (backed by NearFastSegmentSelector)."""

    def create_segment_selector_model(self, extraction_data):
        """Collect every segment from the training samples and fit the near-fast selector.

        Always reports success: returns (True, "").
        """
        training_segments = []
        for sample in extraction_data.samples:
            training_segments += sample.pdf_data.pdf_data_segments

        selector = NearFastSegmentSelector(self.extraction_identifier)
        selector.create_model(segments=training_segments)
        return True, ""

    def predict(self, predictions_samples: list[PredictionSample]) -> list[str]:
        """Select the relevant segments of each sample, then delegate text
        production to the configured semantic method."""
        if not predictions_samples:
            return []

        selector = NearFastSegmentSelector(self.extraction_identifier)

        for sample in predictions_samples:
            chosen_segments = selector.predict(sample.pdf_data.pdf_data_segments)
            self.mark_predicted_segments(chosen_segments)
            sample.segment_selector_texts = self.get_predicted_texts(sample.pdf_data)

        semantic_method = self.SEMANTIC_METHOD(self.extraction_identifier, self.get_name())
        return semantic_method.predict(predictions_samples)
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def test_get_train_test_with_few_samples(self):

self.assertEqual(train_set.extraction_identifier, self.extraction_identifier)
self.assertEqual(test_set.extraction_identifier, self.extraction_identifier)
self.assertEqual(len(train_set.samples), 9)
self.assertEqual(len(test_set.samples), 4)
self.assertEqual(9, len(train_set.samples))
self.assertEqual(9, len(test_set.samples))

def test_get_train_test_without_enough_labeled_segments(self):
pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier)
Expand All @@ -55,8 +55,8 @@ def test_get_train_test_without_enough_labeled_segments(self):

train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data)

self.assertEqual(len(train_set.samples), 20)
self.assertEqual(len(test_set.samples), 9)
self.assertEqual(9, len(train_set.samples))
self.assertEqual(20, len(test_set.samples))

def test_get_train_test_without_labeled_segments(self):
pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier)
Expand All @@ -68,8 +68,8 @@ def test_get_train_test_without_labeled_segments(self):

train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data)

self.assertEqual(len(train_set.samples), 80)
self.assertEqual(len(test_set.samples), 20)
self.assertEqual(80, len(train_set.samples))
self.assertEqual(20, len(test_set.samples))

def test_get_train_test_without_enough_data(self):
pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier)
Expand All @@ -81,8 +81,8 @@ def test_get_train_test_without_enough_data(self):

train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data)

self.assertEqual(len(train_set.samples), 180)
self.assertEqual(len(test_set.samples), 20)
self.assertEqual(160, len(train_set.samples))
self.assertEqual(40, len(test_set.samples))

def test_get_train_test_only_labels_with_segments(self):
pdf_to_text_extractor = PdfToTextExtractor(self.extraction_identifier)
Expand All @@ -94,5 +94,5 @@ def test_get_train_test_only_labels_with_segments(self):

train_set, test_set = pdf_to_text_extractor.get_train_test_sets(extraction_data)

self.assertEqual(len(train_set.samples), 160)
self.assertEqual(len(test_set.samples), 40)
self.assertEqual(140, len(train_set.samples))
self.assertEqual(60, len(test_set.samples))
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def predict(self, segments):
model = lgb.Booster(model_file=self.model_path)
predictions = model.predict(x)

return [segment for i, segment in enumerate(segments) if predictions[i] > 0.5]
return self.predictions_scores_to_segments(segments, predictions)

@staticmethod
def predictions_scores_to_segments(segments: list[PdfDataSegment], prediction_scores: list[float]):
return [segment for i, segment in enumerate(segments) if prediction_scores[i] > 0.5]

def load_repeated_words(self):
self.previous_words = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from trainable_entity_extractor.data.PdfDataSegment import PdfDataSegment
from trainable_entity_extractor.extractors.segment_selector.FastAndPositionsSegmentSelector import (
FastAndPositionsSegmentSelector,
)


class NearFastSegmentSelector(FastAndPositionsSegmentSelector):
    """Segment selector that keeps a segment when its own score is high OR when an
    immediate neighbour (previous or next) scored high, widening the selection
    around confident hits."""

    @staticmethod
    def predictions_scores_to_segments(segments: list[PdfDataSegment], prediction_scores: list[float]):
        """Return the segments whose score, or a direct neighbour's score, exceeds 0.5.

        Bug fix: the look-ahead guard was `len(prediction_scores) >= i + 1`, which
        allowed `i + 1 == len(prediction_scores)` and raised IndexError on the last
        segment whenever its own score was <= 0.5. The guard must be
        `i + 1 < len(prediction_scores)`.
        """
        predicted_segments = []
        for i, (segment, prediction) in enumerate(zip(segments, prediction_scores)):
            if prediction > 0.5:
                predicted_segments.append(segment)
                continue

            # look ahead: keep this segment if the next one is a confident hit
            if i + 1 < len(prediction_scores) and prediction_scores[i + 1] > 0.5:
                predicted_segments.append(segment)
                continue

            # look behind: keep this segment if the previous one is a confident hit
            if i != 0 and prediction_scores[i - 1] > 0.5:
                predicted_segments.append(segment)

        return predicted_segments
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import json
import os
from os.path import join, exists
from pathlib import Path
from typing import Type

from trainable_entity_extractor.config import config_logger
Expand Down Expand Up @@ -41,7 +38,6 @@
from trainable_entity_extractor.extractors.text_to_multi_option_extractor.methods.TextSingleLabelSetFitMultilingual import (
TextSingleLabelSetFitMultilingual,
)
from trainable_entity_extractor.extractors.text_to_multi_option_extractor.methods.TextTfIdf import TextTfIdf


class TextToMultiOptionExtractor(ExtractorBase):
Expand Down Expand Up @@ -139,6 +135,7 @@ def get_performance(extraction_data, method_instance):
try:
performance = method_instance.performance(extraction_data)
except:
config_logger.info("ERROR", exc_info=True)
performance = 0
config_logger.info(f"\nPerformance {method_instance.get_name()}: {performance}%")
return performance
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import shutil
from os.path import join, exists

import pandas as pd
from datasets import load_dataset, DatasetDict
from fastfit import FastFitTrainer, FastFit, sample_dataset
from transformers import AutoTokenizer, pipeline

from trainable_entity_extractor.data.ExtractionData import ExtractionData
from trainable_entity_extractor.data.Option import Option
from setfit import SetFitModel, TrainingArguments, Trainer

from trainable_entity_extractor.data.PredictionSample import PredictionSample
from trainable_entity_extractor.extractors.ExtractorBase import ExtractorBase
from trainable_entity_extractor.extractors.bert_method_scripts.AvoidAllEvaluation import AvoidAllEvaluation
from trainable_entity_extractor.extractors.bert_method_scripts.EarlyStoppingAfterInitialTraining import (
EarlyStoppingAfterInitialTraining,
)
from trainable_entity_extractor.extractors.bert_method_scripts.get_batch_size import get_batch_size, get_max_steps
from trainable_entity_extractor.extractors.text_to_multi_option_extractor.TextToMultiOptionMethod import (
TextToMultiOptionMethod,
)


class TextFastFit(TextToMultiOptionMethod):
    """Single-label text classification method backed by FastFit fine-tuning of a
    sentence-transformers base model."""

    # Base checkpoint, used both as the training starting point and as the
    # tokenizer at prediction time — keep the two in sync via this attribute.
    model_name = "sentence-transformers/paraphrase-mpnet-base-v2"

    def can_be_used(self, extraction_data: ExtractionData) -> bool:
        """FastFit here is single-label: refuse multi-value extractions."""
        return not extraction_data.multi_value

    def get_data_path(self) -> str:
        """Return the CSV path holding the training data, creating its folder if needed."""
        model_folder_path = join(self.extraction_identifier.get_path(), self.get_name())

        if not exists(model_folder_path):
            os.makedirs(model_folder_path)

        return join(model_folder_path, "data.csv")

    def get_model_path(self) -> str:
        """Return the directory where the fine-tuned FastFit model is stored, creating it if needed."""
        model_folder_path = join(self.extraction_identifier.get_path(), self.get_name())

        if not exists(model_folder_path):
            os.makedirs(model_folder_path)

        model_path = join(model_folder_path, "fast_fit_model")

        os.makedirs(model_path, exist_ok=True)

        return str(model_path)

    @staticmethod
    def eval_encodings(example):
        """Decode a stringified label column back into its Python value.

        NOTE(review): eval on file-loaded data is unsafe if the CSV can be
        tampered with — consider ast.literal_eval instead; confirm callers
        before changing.
        """
        example["label"] = eval(example["label"])
        return example

    def get_dataset_from_data(self, extraction_data: ExtractionData):
        """Dump (text, label) pairs to CSV and reload them as a `datasets` Dataset.

        Samples without values get the sentinel label "no_label"; only the first
        value of each sample is used (single-label method). The dataset is capped
        at 10000 rows.
        """
        texts = [self.get_text(sample.labeled_data.source_text) for sample in extraction_data.samples]

        labels = list()
        for sample in extraction_data.samples:
            labels.append("no_label")
            if sample.labeled_data.values:
                labels[-1] = self.options[self.options.index(sample.labeled_data.values[0])].label

        data = [[text, label] for text, label in zip(texts[:10000], labels[:10000])]

        df = pd.DataFrame(data)
        df.columns = ["text", "label"]

        df.to_csv(self.get_data_path())
        dataset_csv = load_dataset("csv", data_files=self.get_data_path())
        return dataset_csv["train"]

    def train(self, extraction_data: ExtractionData):
        """Fine-tune FastFit on the extraction data and save the model to get_model_path()."""
        shutil.rmtree(self.get_model_path(), ignore_errors=True)

        train_dataset = self.get_dataset_from_data(extraction_data)
        # FastFitTrainer expects train/validation/test splits; no held-out data is
        # available here, so the training set is reused for all three.
        dataset_dict = DatasetDict()
        dataset_dict["train"] = train_dataset
        dataset_dict["validation"] = train_dataset
        dataset_dict["test"] = train_dataset
        trainer = FastFitTrainer(
            # was a hard-coded duplicate of the class attribute; use the single source of truth
            model_name_or_path=self.model_name,
            label_column_name="label",
            text_column_name="text",
            output_dir=self.get_model_path(),
            max_steps=get_max_steps(len(extraction_data.samples)),
            evaluation_strategy="steps",
            save_strategy="steps",
            eval_steps=20000,
            save_steps=20000,
            load_best_model_at_end=True,
            max_text_length=128,
            dataloader_drop_last=False,
            num_repeats=4,
            optim="adafactor",
            clf_loss_factor=0.1,
            fp16=True,
            dataset=dataset_dict,
        )

        model = trainer.train()
        model.save_pretrained(self.get_model_path())

    def predict(self, predictions_samples: list[PredictionSample]) -> list[list[Option]]:
        """Classify each sample's text and map confident labels back to Option objects.

        Returns one (possibly empty) list of Options per input sample.
        """
        model = FastFit.from_pretrained(self.get_model_path())
        # tokenizer must match the checkpoint the model was fine-tuned from
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

        texts = [self.get_text(sample.source_text) for sample in predictions_samples]
        predictions: list[list[Option]] = list()
        for text in texts:
            prediction = classifier(text)
            # keep only labels the classifier is confident about
            prediction_labels = [x["label"] for x in prediction if x["score"] > 0.5]
            predictions.append([option for option in self.options if option.label in prediction_labels])

        return predictions
Loading

0 comments on commit ed439b0

Please sign in to comment.