Refactored CrossEncoder into our own wrapper class to support head training (#88)

* Refactored CrossEncoder into our own wrapper class to support head training

* Fix typo in comment

* fixing tests

* Fixing mypy errors

* Fixing doc build

* Still fixing doc build

* Keep fixing doc build

* minor bug fix

* mypy was updated `(-_-)`

* change type annotation of `pairs` argument

* `_logits_list` -> `_activations_list`

* `get_features` -> `_get_features_or_predictions`

---------

Co-authored-by: Алексеев Илья <44509110+voorhs@users.noreply.github.com>
Co-authored-by: voorhs <ilya_alekseev_2016@list.ru>
3 people authored Jan 11, 2025
1 parent 8a61a6c commit 2bf20ec
Showing 8 changed files with 219 additions and 90 deletions.
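The central API change: the wrapper no longer receives a pre-built CrossEncoder but takes a model name, constructs the encoder itself, and can optionally train its own classifier head. A hedged before/after sketch (the model name is illustrative):

# Before #88: the caller constructed the CrossEncoder and passed it in.
from sentence_transformers import CrossEncoder
scorer = CrossEncoderWithLogreg(CrossEncoder("cross-encoder-model"))

# After #88: the wrapper owns construction and, optionally, head training.
from autointent._transformers import NLITransformer
scorer = NLITransformer("cross-encoder-model", device="cpu", train_classifier=True)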
2 changes: 1 addition & 1 deletion Makefile
@@ -24,7 +24,7 @@ lint:

.PHONY: sync
sync:
poetry install --sync --with dev,test,lint,typing,docs
poetry sync

.PHONY: docs
docs:
3 changes: 3 additions & 0 deletions autointent/_transformers/__init__.py
@@ -0,0 +1,3 @@
from ._nli_transformer import NLITransformer

__all__ = ["NLITransformer"]
autointent/_transformers/_nli_transformer.py
@@ -1,17 +1,21 @@
"""CrossEncoderWithLogreg class for cross-encoder-based binary classification with logistic regression."""
"""NLITransformer class for cross-encoder-based estimation of meaning closeness.
Can be used to rank retrieved sentences by meaning closeness to a provided utterance.
"""

import itertools as it
import logging
from pathlib import Path
from random import shuffle
from typing import Any, TypeVar
from typing import Any

import joblib
import numpy as np
import numpy.typing as npt
import torch
from sentence_transformers import CrossEncoder
from sklearn.linear_model import LogisticRegressionCV
from torch import nn

from autointent.custom_types import LabelType

@@ -54,15 +58,13 @@ def construct_samples(
return pairs, labels


CrossEncoderType = TypeVar("CrossEncoderType", bound="CrossEncoderWithLogreg")


class CrossEncoderWithLogreg:
class NLITransformer:
r"""
Cross-encoder with logistic regression for binary classification.
Cross-encoder for NLI.
This class uses a SentenceTransformers CrossEncoder model to extract features
and LogisticRegressionCV for classification.
At its heart, this class uses a SentenceTransformers CrossEncoder model to extract features.
Then it uses either the model's own classifier or our custom-trained LogisticRegressionCV
(a custom classifier layer in the future) to rank documents by similarity score to the query.
:ivar cross_encoder: The CrossEncoder model used to extract features.
:ivar batch_size: Batch size for processing text pairs.
@@ -72,10 +74,8 @@ class CrossEncoderWithLogreg:
Examples
--------
Creating and fitting the NLITransformer:
>>> from autointent.modules import CrossEncoderWithLogreg
>>> from sentence_transformers import CrossEncoder
>>> model = CrossEncoder("cross-encoder-model")
>>> scorer = CrossEncoderWithLogreg(model)
>>> from autointent._transformers import NLITransformer
>>> scorer = NLITransformer("cross-encoder-model")
>>> utterances = ["What is your name?", "How old are you?"]
>>> labels = [1, 0]
>>> scorer.fit(utterances, labels)
@@ -87,43 +87,64 @@ class CrossEncoderWithLogreg:
Saving and loading the model:
>>> scorer.save("outputs/")
>>> loaded_scorer = CrossEncoderWithLogreg.load("outputs/")
>>> loaded_scorer = NLITransformer.load("outputs/")
"""

def __init__(self, model: CrossEncoder, batch_size: int = 326) -> None:
def __init__(
self,
model: str,
device: str = "cpu",
train_classifier: bool = False,
batch_size: int = 326,
max_length: int | None = None,
classifier_head: LogisticRegressionCV | None = None,
) -> None:
"""
Initialize the CrossEncoderWithLogreg.
Initialize the NLITransformer.
:param model: The CrossEncoder model to use.
:param model: The CrossEncoder model name to use.
:param device: Device to run operations on, e.g., "cpu" or "cuda".
:param train_classifier: Whether to train a custom classifier, defaults to False.
:param batch_size: Batch size for processing text pairs, defaults to 326.
:param max_length: Optional max length for input sequences for the cross-encoder.
:param classifier_head: Optional pre-trained classifier head (mainly used when restoring a saved model).
"""
self.cross_encoder = model
self.cross_encoder = CrossEncoder(model, trust_remote_code=True, device=device, max_length=max_length) # type: ignore[arg-type]
self.train_classifier = False
self.batch_size = batch_size
self.max_length = max_length
self._clf = classifier_head

if classifier_head is not None or train_classifier:
self.train_classifier = True
self._activations_list: list[npt.NDArray[Any]] = []
self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook)

def _classifier_hook(self, _module, input_tensor, _output_tensor) -> None: # type: ignore[no-untyped-def] # noqa: ANN001
self._activations_list.append(input_tensor[0].cpu().numpy())

@torch.no_grad()
def get_features(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
"""
Extract features from text pairs using the CrossEncoder model.
Extract features or get predictions using the CrossEncoder model.
If :py:attr:`~train_classifier` is ``True``, return raw activations from the
cross-encoder transformer. Otherwise, return predictions from the cross-encoder head.
:param pairs: List of text pairs.
:return: Numpy array of extracted features.
"""
logits_list: list[npt.NDArray[Any]] = []

def hook_function(module, input_tensor, output_tensor) -> None: # type: ignore[no-untyped-def] # noqa: ARG001, ANN001
logits_list.append(input_tensor[0].cpu().numpy())
if not self.train_classifier:
return np.array(self.cross_encoder.predict(pairs, batch_size=self.batch_size, activation_fct=nn.Sigmoid()))

handler = self.cross_encoder.model.classifier.register_forward_hook(hook_function)
# run the data through the model; the hook captures the features
self.cross_encoder.predict(pairs, batch_size=self.batch_size)

for i in range(0, len(pairs), self.batch_size):
batch = pairs[i : i + self.batch_size]
self.cross_encoder.predict(batch)
res = np.concatenate(self._activations_list, axis=0)
self._activations_list.clear()
return res # type: ignore[no-any-return]

handler.remove()

return np.concatenate(logits_list, axis=0)

def _fit(self, pairs: list[list[str]], labels: list[LabelType]) -> None:
def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:
"""
Train the logistic regression model on cross-encoder features.
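The hook registered in __init__ fires whenever the underlying model's classifier head runs, capturing the head's input, i.e. the pooled transformer features. A minimal standalone sketch of this forward-hook technique on a toy PyTorch model (all names illustrative, not autointent API):

import torch
from torch import nn

activations: list[torch.Tensor] = []

def classifier_hook(_module: nn.Module, inputs: tuple[torch.Tensor, ...], _output: torch.Tensor) -> None:
    # The hook sees the *inputs* to the classifier head: the pooled
    # features, before the head projects them to logits.
    activations.append(inputs[0].detach().cpu())

toy_model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))  # "backbone" + "head"
handle = toy_model[1].register_forward_hook(classifier_hook)

with torch.no_grad():
    toy_model(torch.randn(3, 8))  # the forward pass populates `activations`

handle.remove()
print(activations[0].shape)  # torch.Size([3, 4])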
@@ -137,8 +158,10 @@ def _fit(self, pairs: list[list[str]], labels: list[LabelType]) -> None:
logger.error(msg)
raise ValueError(msg)

features = self.get_features(pairs)
features = self._get_features_or_predictions(pairs)

# TODO: LogisticRegressionCV has class_weight="balanced". Is it better to use it instead of balance_factor in
# construct_samples?
clf = LogisticRegressionCV()
clf.fit(features, labels)
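For context, a minimal sketch of what this head-training step amounts to; the feature matrix and labels here are synthetic stand-ins:

import numpy as np
from sklearn.linear_model import LogisticRegressionCV

features = np.random.randn(20, 4)  # stand-in for cross-encoder activations
labels = np.array([0, 1] * 10)     # stand-in for same/different-intent pair labels

clf = LogisticRegressionCV()       # cross-validated choice of regularization strength
clf.fit(features, labels)
probs = clf.predict_proba(features)[:, 1]  # P(same intent), as consumed by predict()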

@@ -151,18 +174,53 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
:param utterances: List of utterances (texts).
:param labels: Intent class labels corresponding to the utterances.
"""
if not self.train_classifier:
return # do nothing if the classifier is not to be re-trained

pairs, labels_ = construct_samples(utterances, labels, balancing_factor=1)
self._fit(pairs, labels_) # type: ignore[arg-type]

def predict(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
"""
Predict probabilities of two utterances having the same intent label.
:param pairs: List of text pairs to classify.
:return: Numpy array of probabilities.
"""
features = self.get_features(pairs)
return self._clf.predict_proba(features)[:, 1] # type: ignore[no-any-return]
if self.train_classifier and self._clf is None:
msg = "Classifier is not trained yet"
raise ValueError(msg)

features = self._get_features_or_predictions(pairs)

if self._clf is not None:
return np.array(self._clf.predict_proba(features)[:, 1])

return features

def rank(
self,
query: str,
query_docs: list[str],
top_k: int | None = None,
) -> list[dict[str, Any]]:
"""
Rank documents by meaning closeness to the query.
:param query: The reference document.
:param query_docs: List of documents to rank.
:param top_k: How many documents to return.
:return: List of dictionaries with the ranked items.
"""
query_doc_pairs = [(query, doc) for doc in query_docs]
scores = self.predict(query_doc_pairs)

if top_k is None:
top_k = len(query_docs)

results = [{"corpus_id": i, "score": scores[i]} for i in range(len(query_docs))]
results.sort(key=lambda x: x["score"], reverse=True)
return results[:top_k]

def save(self, path: str) -> None:
"""
@@ -178,21 +236,13 @@ def save(self, path: str) -> None:
clf_path = dump_dir / "classifier.joblib"
joblib.dump(self._clf, clf_path)

def set_classifier(self, clf: LogisticRegressionCV) -> None:
"""
Set the logistic regression classifier.
:param clf: LogisticRegressionCV instance.
"""
self._clf = clf

@classmethod
def load(cls, path: str) -> "CrossEncoderWithLogreg":
def load(cls, path: str) -> "NLITransformer":
"""
Load the model and classifier from disk.
:param path: Directory path containing the saved model and classifier.
:return: Initialized CrossEncoderWithLogreg instance.
:return: Initialized NLITransformer instance.
"""
dump_dir = Path(path)

@@ -202,9 +252,5 @@ def load(cls, path: str) -> "CrossEncoderWithLogreg":

# Load sentence transformer model
crossencoder_dir = str(dump_dir / "crossencoder")
model = CrossEncoder(crossencoder_dir)

res = cls(model)
res.set_classifier(clf)

return res
return cls(crossencoder_dir, classifier_head=clf)
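Restoring now goes through the constructor, with the saved head passed as classifier_head. A sketch mirroring the docstring example above (model name and path illustrative):

from autointent._transformers import NLITransformer

scorer = NLITransformer("cross-encoder-model")
scorer.save("outputs/")                     # dumps the crossencoder/ dir and classifier.joblib
restored = NLITransformer.load("outputs/")  # rebuilds the wrapper with classifier_head from disk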
27 changes: 17 additions & 10 deletions autointent/context/data_handler/_data_handler.py
@@ -30,17 +30,16 @@ class DataHandler:
"""Data handler class."""

def __init__(
self,
dataset: Dataset,
force_multilabel: bool = False,
random_seed: int = 0,
self, dataset: Dataset, force_multilabel: bool = False, random_seed: int = 0, split_train: bool = True
) -> None:
"""
Initialize the data handler.
:param dataset: Training dataset.
:param force_multilabel: If True, force the dataset to be multilabel.
:param random_seed: Seed for random number generation.
:param split_train: Whether to split the train set (by default it is split into parts
used for scoring and threshold search).
"""
set_seed(random_seed)

@@ -50,7 +49,7 @@ def __init__(

self.n_classes = self.dataset.n_classes

self._split(random_seed)
self._split(random_seed, split_train)

self.regexp_patterns = [
RegexPatterns(
@@ -191,11 +190,11 @@ def dump(self, filepath: str | Path) -> None:
"""
self.dataset.to_json(filepath)

def _split(self, random_seed: int) -> None:
def _split(self, random_seed: int, split_train: bool) -> None:
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
has_test_split = any(split.startswith(Split.TEST) for split in self.dataset)

if Split.TRAIN in self.dataset:
if split_train and Split.TRAIN in self.dataset:
self._split_train(random_seed)

if Split.TEST not in self.dataset:
@@ -252,13 +251,21 @@ def _split_validation_from_test(self, random_seed: int) -> None:
)

def _split_validation_from_train(self, random_seed: int) -> None:
for idx in range(2):
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
if Split.TRAIN in self.dataset:
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_{idx}",
split=Split.TRAIN,
test_size=0.2,
random_seed=random_seed,
)
else:
for idx in range(2):
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_{idx}",
test_size=0.2,
random_seed=random_seed,
)

def _split_test(self, test_size: float, random_seed: int) -> None:
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(