Skip to content

Commit

Permalink
Feat/logreg retrieval (#81)
Browse files Browse the repository at this point in the history
* feat: added logregretrieval

* fix: predict for logregretrieval

* fix: predict for logregretrieval

* fix: added kwargs

* fixed: dump & load modules and added tests

* fix: fixed dump for RetrieverEmbedding

* fix: fixed docstring

* fix: change vector_index to embedder

* fix: change to the ScoringMetricFn

* fix: multilabel and fix scorer metric

* fix: load and dump

* fix: lint

* fix: lint

* fix: mypy

* fix: docs

* fix: docs

* feat: change predict in RetrievalEmbedding

* feat: change predict in RetrievalEmbedding

* feat: update logregembedding

* feat: update docstring

* fix: fixed retrieval test

* fix: fixed retrieval and logreg test

* fix: added cv to the docs example

* fix: fixed score func

* fix: added accuracy for scorer in logreg

* fix: added predict_proba

* test: update tests

* feat: divide retrieval and logreg

* fix: fixed setup_environment

* fix: fixed import

* fix: deleted dump and load

* fix: rename classifier and label encoder

* fix: fixed multilabel

* feat: updated tests

* feat: update multiclass.yaml

* fix: added cv

* fix: lint

* fix: fixed metric in multilabel.yaml

* fix: fixed _classifier

* fix: fixed label encoder

* fix: fixed scoring

* fix: fixed split in score

* fix: fixed split in score

* fix: type

* feat: updated predict() in logreg

* feat: updated test

* fix: fixed lint

* fix: no-any return

* make changes

* fix

* remove `k` completely

* remove k from search space

* fix another `k` issue

* finally?

---------

Co-authored-by: voorhs <[email protected]>
  • Loading branch information
Darinochka and voorhs authored Jan 28, 2025
1 parent 1ff18cf commit 36214e5
Show file tree
Hide file tree
Showing 11 changed files with 336 additions and 61 deletions.
6 changes: 4 additions & 2 deletions autointent/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ThresholdDecision,
TunableDecision,
)
from .embedding import RetrievalEmbedding
from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer

T = TypeVar("T", bound=Module)
Expand All @@ -20,7 +20,9 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
return {module.name: module for module in modules}


RETRIEVAL_MODULES_MULTICLASS: dict[str, type[EmbeddingModule]] = _create_modules_dict([RetrievalEmbedding])
RETRIEVAL_MODULES_MULTICLASS: dict[str, type[EmbeddingModule]] = _create_modules_dict(
[RetrievalAimedEmbedding, LogregAimedEmbedding]
)

RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS

Expand Down
8 changes: 0 additions & 8 deletions autointent/modules/abc/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,3 @@

class EmbeddingModule(Module, ABC):
"""Base class for embedding modules."""

def __init__(self, k: int) -> None:
"""
Initialize embedding module.
:param k: number of closest neighbors to consider during inference
"""
self.k = k
5 changes: 3 additions & 2 deletions autointent/modules/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""These modules are used only for optimization as they use proxy metrics for choosing best embedding model."""

from ._retrieval import RetrievalEmbedding
from ._logreg import LogregAimedEmbedding
from ._retrieval import RetrievalAimedEmbedding

__all__ = ["RetrievalEmbedding"]
__all__ = ["LogregAimedEmbedding", "RetrievalAimedEmbedding"]
173 changes: 173 additions & 0 deletions autointent/modules/embedding/_logreg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""LogregAimedEmbedding class for a proxy optimzation of embedding."""

from typing import Literal

import numpy as np
from numpy.typing import NDArray
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.multioutput import MultiOutputClassifier
from sklearn.preprocessing import LabelEncoder

from autointent import Context, Embedder
from autointent.context.optimization_info import RetrieverArtifact
from autointent.custom_types import ListOfLabels
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
from autointent.modules.abc import EmbeddingModule


class LogregAimedEmbedding(EmbeddingModule):
r"""
Module for configuring embeddings optimized for linear classification.
The main purpose of this module is to be used at embedding node for optimizing
embedding configuration using its logreg classification quality as a sort of proxy metric.
:ivar classifier: The trained logistic regression model.
:ivar label_encoder: Label encoder for converting labels to numerical format.
:ivar name: Name of the module, defaults to "logreg".
Examples
--------
.. testcode::
from autointent.modules.embedding import LogregAimedEmbedding
utterances = ["bye", "how are you?", "good morning"]
labels = [0, 1, 1]
retrieval = LogregAimedEmbedding(
embedder_name="sergeyzh/rubert-tiny-turbo",
cv=2
)
retrieval.fit(utterances, labels)
"""

_classifier: LogisticRegressionCV | MultiOutputClassifier
_label_encoder: LabelEncoder | None
name = "logreg"
supports_multiclass = True
supports_multilabel = True
supports_oos = False

def __init__(
self,
embedder_name: str,
cv: int = 3,
embedder_device: str = "cpu",
embedder_batch_size: int = 32,
embedder_max_length: int | None = None,
embedder_use_cache: bool = True,
) -> None:
"""
Initialize the LogregAimedEmbedding.
:param cv: the number of folds used in LogisticRegressionCV
:param embedder_name: Name of the embedder used for creating embeddings.
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
:param batch_size: Batch size for embedding generation.
:param max_length: Maximum sequence length for embeddings. None if not set.
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
"""
self.embedder_name = embedder_name
self.embedder_device = embedder_device
self.embedder_batch_size = embedder_batch_size
self.embedder_max_length = embedder_max_length
self.embedder_use_cache = embedder_use_cache
self.cv = cv

@classmethod
def from_context(
cls,
context: Context,
cv: int,
embedder_name: str,
) -> "LogregAimedEmbedding":
"""
Create a LogregAimedEmbedding instance using a Context object.
:param cv: the number of folds used in LogisticRegressionCV
:param context: The context containing configurations and utilities.
:param embedder_name: Name of the embedder to use.
:return: Initialized LogregAimedEmbedding instance.
"""
return cls(
cv=cv,
embedder_name=embedder_name,
embedder_device=context.get_device(),
embedder_batch_size=context.get_batch_size(),
embedder_max_length=context.get_max_length(),
embedder_use_cache=context.get_use_cache(),
)

def clear_cache(self) -> None:
pass

def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
"""
Train the logistic regression model using the provided utterances and labels.
:param utterances: List of text data to index.
:param labels: List of corresponding labels for the utterances.
"""
self._validate_task(labels)

self._embedder = Embedder(
device=self.embedder_device,
model_name_or_path=self.embedder_name,
batch_size=self.embedder_batch_size,
max_length=self.embedder_max_length,
use_cache=self.embedder_use_cache,
)
embeddings = self._embedder.embed(utterances)

if self._multilabel:
self._label_encoder = None
base_clf = LogisticRegression()
self._classifier = MultiOutputClassifier(base_clf)
else:
self._label_encoder = LabelEncoder()
labels = self._label_encoder.fit_transform(labels)
self._classifier = LogisticRegressionCV(cv=self.cv)

self._classifier.fit(embeddings, labels)

def score(
self,
context: Context,
split: Literal["validation", "test"],
) -> dict[str, float | str]:
"""
Evaluate the embedding model using a specified metric function.
:param context: The context containing test data and labels.
:param split: Target split
:return: Computed metrics value for the test set or error code of metrics
"""
if split == "validation":
utterances = context.data_handler.validation_utterances(0)
labels = context.data_handler.validation_labels(0)
elif split == "test":
utterances = context.data_handler.test_utterances()
labels = context.data_handler.test_labels()
else:
message = f"Invalid split '{split}' provided. Expected one of 'validation', or 'test'."
raise ValueError(message)

probas = self.predict(utterances)
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
return self.score_metrics((labels, probas), metrics_dict)

def get_assets(self) -> RetrieverArtifact:
"""
Get the classifier artifacts for this module.
:return: A RetrieverArtifact object containing embedder information.
"""
return RetrieverArtifact(embedder_name=self.embedder_name)

def predict(self, utterances: list[str]) -> NDArray[np.float64]:
embeddings = self._embedder.embed(utterances)
probas = self._classifier.predict_proba(embeddings)

if self._multilabel:
probas = np.stack(probas, axis=1)[..., 1]

return probas # type: ignore[no-any-return]
56 changes: 18 additions & 38 deletions autointent/modules/embedding/_retrieval.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
"""RetrievalEmbedding class for managing and interacting with a vector database for retrieval tasks."""
"""RetrievalAimedEmbedding class for a proxy optimization of embedding."""

from pathlib import Path
from typing import Literal

from autointent import VectorIndex
from autointent.context import Context
from autointent import Context, VectorIndex
from autointent.context.optimization_info import RetrieverArtifact
from autointent.custom_types import ListOfLabels
from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
from autointent.modules.abc import EmbeddingModule


class RetrievalEmbedding(EmbeddingModule):
class RetrievalAimedEmbedding(EmbeddingModule):
r"""
Module for managing retrieval operations using a vector database.
Module for configuring embeddings optimized for retrieval tasks.
RetrievalEmbedding provides methods for indexing, querying, and managing a vector database for tasks
such as nearest neighbor retrieval.
The main purpose of this module is to be used at embedding node for optimizing
embedding configuration using its retrieval quality as a sort of proxy metric.
:ivar vector_index: The vector index used for nearest neighbor retrieval.
:ivar name: Name of the module, defaults to "retrieval".
Expand All @@ -26,25 +24,22 @@ class RetrievalEmbedding(EmbeddingModule):
.. testcode::
from autointent.modules.embedding import RetrievalEmbedding
from autointent.modules.embedding import RetrievalAimedEmbedding
utterances = ["bye", "how are you?", "good morning"]
labels = [0, 1, 1]
retrieval = RetrievalEmbedding(
retrieval = RetrievalAimedEmbedding(
k=2,
embedder_name="sergeyzh/rubert-tiny-turbo",
)
retrieval.fit(utterances, labels)
predictions = retrieval.predict(["how is the weather today?"])
print(predictions)
.. testoutput::
([[1, 1]], [[0.1525942087173462, 0.18616724014282227]], [['good morning', 'how are you?']])
"""

_vector_index: VectorIndex
name = "retrieval"
supports_multiclass = True
supports_multilabel = True
supports_oos = False

def __init__(
self,
Expand All @@ -56,7 +51,7 @@ def __init__(
embedder_use_cache: bool = True,
) -> None:
"""
Initialize the RetrievalEmbedding.
Initialize the RetrievalAimedEmbedding.
:param k: Number of nearest neighbors to retrieve.
:param embedder_name: Name of the embedder used for creating embeddings.
Expand All @@ -65,28 +60,27 @@ def __init__(
:param max_length: Maximum sequence length for embeddings. None if not set.
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
"""
self.k = k
self.embedder_name = embedder_name
self.embedder_device = embedder_device
self.embedder_batch_size = embedder_batch_size
self.embedder_max_length = embedder_max_length
self.embedder_use_cache = embedder_use_cache

super().__init__(k=k)

@classmethod
def from_context(
cls,
context: Context,
k: int,
embedder_name: str,
) -> "RetrievalEmbedding":
) -> "RetrievalAimedEmbedding":
"""
Create a RetrievalEmbedding instance using a Context object.
Create an instance using a Context object.
:param context: The context containing configurations and utilities.
:param k: Number of nearest neighbors to retrieve.
:param embedder_name: Name of the embedder to use.
:return: Initialized RetrievalEmbedding instance.
:return: Initialized RetrievalAimedEmbedding instance.
"""
return cls(
k=k,
Expand All @@ -104,6 +98,8 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
:param utterances: List of text data to index.
:param labels: List of corresponding labels for the utterances.
"""
self._validate_task(labels)

self._vector_index = VectorIndex(
self.embedder_name,
self.embedder_device,
Expand Down Expand Up @@ -151,22 +147,6 @@ def clear_cache(self) -> None:
"""Clear cached data in memory used by the vector index."""
self._vector_index.clear_ram()

def dump(self, path: str) -> None:
"""
Save the module's metadata and vector index to a specified directory.
:param path: Path to the directory where assets will be dumped.
"""
self._vector_index.dump(Path(path))

def load(self, path: str) -> None:
"""
Load the module's metadata and vector index from a specified directory.
:param path: Path to the directory containing the dumped assets.
"""
self._vector_index = VectorIndex.load(Path(path))

def predict(self, utterances: list[str]) -> tuple[list[ListOfLabels], list[list[float]], list[list[str]]]:
"""
Predict the nearest neighbors for a list of utterances.
Expand Down
8 changes: 4 additions & 4 deletions tests/assets/configs/multilabel.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
- node_type: embedding
metric: retrieval_hit_rate_intersecting
metric: scoring_accuracy
search_space:
- module_name: retrieval
k: [10]
- module_name: logreg
cv: [2]
embedder_name:
- sentence-transformers/all-MiniLM-L6-v2
- avsolatorio/GIST-small-Embedding-v0
Expand Down Expand Up @@ -33,4 +33,4 @@
- module_name: threshold
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
- module_name: tunable
- module_name: adaptive
- module_name: adaptive
File renamed without changes.
Loading

0 comments on commit 36214e5

Please sign in to comment.