Add config validation (#104)
* init validation
* init validation
* add validation
* add validation to pipeline
* simplify fields
* make module name literal
* fix fields
* fix task types
* update name and metrics
* update
* add config tests
* update metrics
* fix
* fix naming
* fix docs
Samoed authored Jan 31, 2025
1 parent 23efe32 commit 599794b
Showing 25 changed files with 561 additions and 61 deletions.
8 changes: 7 additions & 1 deletion autointent/_pipeline/_pipeline.py
@@ -13,6 +13,7 @@
 from autointent.custom_types import ListOfGenericLabels, NodeType
 from autointent.metrics import PREDICTION_METRICS_MULTILABEL
 from autointent.nodes import InferenceNode, NodeOptimizer
+from autointent.nodes.schemes import OptimizationConfig
 from autointent.utils import load_default_search_space, load_search_space

 from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
@@ -72,10 +73,12 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
         Create pipeline optimizer from dictionary search space.

         :param search_space: Dictionary config
         :param seed: random seed
         """
         if isinstance(search_space, Path | str):
             search_space = load_search_space(search_space)
-        nodes = [NodeOptimizer(**node) for node in search_space]
+        validated_search_space = OptimizationConfig(search_space).model_dump()  # type: ignore[arg-type]
+        nodes = [NodeOptimizer(**node) for node in validated_search_space]
         return cls(nodes=nodes, seed=seed)

@classmethod
@@ -84,6 +87,9 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
         Create pipeline optimizer with default search space for given classification task.

         :param multilabel: Whether the task is multi-label or single-label.
         :param seed: random seed
+        :return: Pipeline
         """
         return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed)

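For illustration, a minimal usage sketch of the validated entry point. The import path for Pipeline is an assumption based on the package layout, and the config values mirror tests/assets/configs/multilabel.yaml shown later in this diff:

from autointent import Pipeline  # import path assumed

search_space = [
    {
        "node_type": "embedding",
        "target_metric": "scoring_accuracy",
        "search_space": [
            {
                "module_name": "logreg_embedding",
                "cv": [2],
                "embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"],
            }
        ],
    }
]

# Malformed configs (e.g. the old module name "logreg", or a wrongly typed
# field) now raise pydantic.ValidationError here, before any training starts.
pipeline = Pipeline.from_search_space(search_space, seed=42)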
6 changes: 1 addition & 5 deletions autointent/custom_types.py
@@ -5,7 +5,7 @@
"""

from enum import Enum
-from typing import Literal, TypeAlias, TypedDict
+from typing import Literal, TypeAlias


class LogLevel(Enum):
@@ -46,10 +46,6 @@ class LogLevel(Enum):
"""


-class BaseMetadataDict(TypedDict):
-    """Base metadata dictionary for storing additional information."""


class NodeType(str, Enum):
"""Enumeration of node types in the AutoIntent pipeline."""

2 changes: 1 addition & 1 deletion autointent/modules/__init__.py
@@ -24,7 +24,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
     [RetrievalAimedEmbedding, LogregAimedEmbedding]
 )

-RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS
+RETRIEVAL_MODULES_MULTILABEL: dict[str, type[EmbeddingModule]] = RETRIEVAL_MODULES_MULTICLASS

 SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = _create_modules_dict(
     [
9 changes: 3 additions & 6 deletions autointent/modules/decision/_threshold.py
@@ -24,9 +24,6 @@ class ThresholdDecision(DecisionModule):
     ThresholdDecision uses a predefined threshold (or array of thresholds) to predict
     labels for single-label or multi-label classification tasks.

-    :ivar metadata_dict_name: Filename for saving metadata to disk.
-    :ivar multilabel: If True, the model supports multi-label classification.
-    :ivar n_classes: Number of classes in the dataset.
     :ivar tags: Tags for predictions (if any).
     :ivar name: Name of the predictor, defaults to "adaptive".
@@ -78,17 +75,17 @@ class ThresholdDecision(DecisionModule):

     def __init__(
         self,
-        thresh: float | npt.NDArray[Any],
+        thresh: float | list[float],
     ) -> None:
         """
         Initialize threshold predictor.

         :param thresh: Threshold for the scores, shape (n_classes,) or float
         """
-        self.thresh = thresh
+        self.thresh = thresh if isinstance(thresh, float) else np.array(thresh)

     @classmethod
-    def from_context(cls, context: Context, thresh: float | npt.NDArray[Any] = 0.5) -> "ThresholdDecision":
+    def from_context(cls, context: Context, thresh: float | list[float] = 0.5) -> "ThresholdDecision":
         """
         Initialize from context.
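A short sketch of the new constructor behavior, assuming only what the diff above shows (the import path is a guess from the file layout): a plain float is stored unchanged, while a list of per-class thresholds is converted to a NumPy array.

import numpy as np

from autointent.modules.decision import ThresholdDecision  # import path assumed

# Per-class thresholds can now come straight from YAML/JSON as a plain list:
module = ThresholdDecision(thresh=[0.5, 0.3, 0.7])
assert isinstance(module.thresh, np.ndarray)

# A single global threshold stays a float:
module = ThresholdDecision(thresh=0.5)
assert module.thresh == 0.5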
10 changes: 5 additions & 5 deletions autointent/modules/embedding/_logreg.py
@@ -22,8 +22,8 @@ class LogregAimedEmbedding(EmbeddingModule):
     The main purpose of this module is to be used at embedding node for optimizing
     embedding configuration using its logreg classification quality as a sort of proxy metric.

-    :ivar classifier: The trained logistic regression model.
-    :ivar label_encoder: Label encoder for converting labels to numerical format.
+    :ivar _classifier: The trained logistic regression model.
+    :ivar _label_encoder: Label encoder for converting labels to numerical format.
     :ivar name: Name of the module, defaults to "logreg".
Examples
@@ -42,7 +42,7 @@ class LogregAimedEmbedding(EmbeddingModule):

     _classifier: LogisticRegressionCV | MultiOutputClassifier
     _label_encoder: LabelEncoder | None
-    name = "logreg"
+    name = "logreg_embedding"
     supports_multiclass = True
     supports_multilabel = True
     supports_oos = False
@@ -62,8 +62,8 @@ def __init__(
         :param cv: the number of folds used in LogisticRegressionCV
         :param embedder_name: Name of the embedder used for creating embeddings.
         :param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
-        :param batch_size: Batch size for embedding generation.
-        :param max_length: Maximum sequence length for embeddings. None if not set.
+        :param embedder_batch_size: Batch size for embedding generation.
+        :param embedder_max_length: Maximum sequence length for embeddings. None if not set.
         :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.embedder_name = embedder_name
2 changes: 2 additions & 0 deletions autointent/modules/regexp/_regexp.py
@@ -26,6 +26,8 @@ class RegexPatternsCompiled(TypedDict):
 class RegExp(Module):
     """Regular expressions based intent detection module."""

+    name = "regexp"
+
     @classmethod
     def from_context(cls, context: Context) -> "RegExp":
         """Initialize from context."""
2 changes: 1 addition & 1 deletion autointent/modules/scoring/_linear.py
@@ -39,7 +39,7 @@ class LinearScorer(ScoringModule):
     .. testoutput::

         [[0.50000032 0.49999968]
-         [0.50000032 0.49999968]]
+         [0.44031667 0.55968333]]
"""

2 changes: 2 additions & 0 deletions autointent/nodes/__init__.py
@@ -3,13 +3,15 @@
 from ._inference_node import InferenceNode
 from ._nodes_info import DecisionNodeInfo, EmbeddingNodeInfo, NodeInfo, RegExpNodeInfo, ScoringNodeInfo
 from ._optimization import NodeOptimizer
+from .schemes import OptimizationConfig

 __all__ = [
     "DecisionNodeInfo",
     "EmbeddingNodeInfo",
     "InferenceNode",
     "NodeInfo",
     "NodeOptimizer",
+    "OptimizationConfig",
     "RegExpNodeInfo",
     "ScoringNodeInfo",
 ]
12 changes: 9 additions & 3 deletions autointent/nodes/_nodes_info/_embedding.py
@@ -7,7 +7,10 @@
 from autointent.metrics import (
     RETRIEVAL_METRICS_MULTICLASS,
     RETRIEVAL_METRICS_MULTILABEL,
+    SCORING_METRICS_MULTICLASS,
+    SCORING_METRICS_MULTILABEL,
     RetrievalMetricFn,
+    ScoringMetricFn,
 )
 from autointent.modules import RETRIEVAL_MODULES_MULTICLASS, RETRIEVAL_MODULES_MULTILABEL
 from autointent.modules.abc import Module
@@ -18,12 +21,15 @@
 class EmbeddingNodeInfo(NodeInfo):
     """Retrieval node info."""

-    metrics_available: ClassVar[Mapping[str, RetrievalMetricFn]] = (
-        RETRIEVAL_METRICS_MULTICLASS | RETRIEVAL_METRICS_MULTILABEL
+    metrics_available: ClassVar[Mapping[str, RetrievalMetricFn | ScoringMetricFn]] = (
+        RETRIEVAL_METRICS_MULTICLASS
+        | RETRIEVAL_METRICS_MULTILABEL
+        | SCORING_METRICS_MULTILABEL
+        | SCORING_METRICS_MULTICLASS
     )

     modules_available: ClassVar[Mapping[str, type[Module]]] = (
-        RETRIEVAL_MODULES_MULTICLASS | RETRIEVAL_MODULES_MULTILABEL  # type: ignore[has-type]
+        RETRIEVAL_MODULES_MULTICLASS | RETRIEVAL_MODULES_MULTILABEL
     )

     node_type = NodeType.embedding
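A one-assert sketch of what this union widens; the metric key is taken from the test config at the end of this diff:

from autointent.nodes import EmbeddingNodeInfo

# Scoring metrics are now valid optimization targets for the embedding node,
# which is what lets a config pair node_type: embedding with
# target_metric: scoring_accuracy.
assert "scoring_accuracy" in EmbeddingNodeInfo.metrics_available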
10 changes: 5 additions & 5 deletions autointent/nodes/_optimization/_node_optimizer.py
@@ -35,11 +35,11 @@ def __init__(
"""
         self.node_type = node_type
         self.node_info = NODES_INFO[node_type]
-        self.decision_metric_name = target_metric
+        self.target_metric = target_metric

         self.metrics = metrics if metrics is not None else []
-        if self.decision_metric_name not in self.metrics:
-            self.metrics.append(self.decision_metric_name)
+        if self.target_metric not in self.metrics:
+            self.metrics.append(self.target_metric)

         self.modules_search_spaces = search_space  # TODO search space validation
         self._logger = logging.getLogger(__name__)  # TODO solve duplicate logging messages problem
@@ -73,7 +73,7 @@ def fit(self, context: Context) -> None:

         self._logger.debug("scoring %s module...", module_name)
         metrics_score = module.score(context, "validation", self.metrics)
-        metric_value = metrics_score[self.decision_metric_name]
+        metric_value = metrics_score[self.target_metric]

         context.callback_handler.log_metrics(metrics_score)
         context.callback_handler.end_module()
@@ -91,7 +91,7 @@ def fit(self, context: Context) -> None:
                 module_name,
                 module_kwargs,
                 metric_value,
-                self.decision_metric_name,
+                self.target_metric,
                 module.get_assets(),  # retriever name / scores / predictions
                 module_dump_dir,
                 module=module if not context.is_ram_to_clear() else None,
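To illustrate the rename, a hedged sketch of constructing a NodeOptimizer directly; the keyword names follow the validator fields defined in schemes.py below, and the scoring module name is hypothetical:

from autointent.nodes import NodeOptimizer

node = NodeOptimizer(
    node_type="scoring",
    target_metric="scoring_accuracy",
    metrics=None,                                # target metric is appended automatically
    search_space=[{"module_name": "linear"}],    # hypothetical module name
)
assert node.target_metric == "scoring_accuracy"  # was node.decision_metric_name
assert "scoring_accuracy" in node.metrics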
129 changes: 129 additions & 0 deletions autointent/nodes/schemes.py
@@ -0,0 +1,129 @@
"""Schemes."""

import inspect
from collections.abc import Iterator
from typing import Any, Literal, TypeAlias, Union, get_type_hints

from pydantic import BaseModel, Field, RootModel

from autointent.custom_types import NodeType
from autointent.modules.abc import Module
from autointent.nodes import DecisionNodeInfo, EmbeddingNodeInfo, RegExpNodeInfo, ScoringNodeInfo


def generate_models_and_union_type_for_classes(
    classes: list[type[Module]],
) -> type[BaseModel]:
    """Dynamically generates Pydantic models for class constructors and creates a union type."""
    models: dict[str, type[BaseModel]] = {}

    for cls in classes:
        init_signature = inspect.signature(cls.from_context)
        globalns = getattr(cls.from_context, "__globals__", {})
        type_hints = get_type_hints(cls.from_context, globalns, None)  # Resolve forward refs

        fields = {"module_name": (Literal[cls.name], Field(...))}

        for param_name, param in init_signature.parameters.items():
            if param_name in ("self", "cls", "context"):
                continue

            param_type: TypeAlias = type_hints.get(param_name, Any)  # type: ignore[valid-type] # noqa: PYI042
            field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...)

            fields[param_name] = (list[param_type], field)  # type: ignore[assignment]

        model_name = f"{cls.__name__}InitModel"
        models[cls.__name__] = type(
            model_name,
            (BaseModel,),
            {
                "__annotations__": {k: v[0] for k, v in fields.items()},
                **{k: v[1] for k, v in fields.items()},
            },
        )

    return Union[tuple(models.values())]  # type: ignore[return-value] # noqa: UP007


DecisionSearchSpaceType: TypeAlias = generate_models_and_union_type_for_classes(  # type: ignore[valid-type]
    list(DecisionNodeInfo.modules_available.values())
)
DecisionMetrics: TypeAlias = Literal[tuple(DecisionNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]


class DecisionNodeValidator(BaseModel):
    """Search space configuration for the Decision node."""

    node_type: NodeType = NodeType.decision
    target_metric: DecisionMetrics
    metrics: list[DecisionMetrics] | None = None
    search_space: list[DecisionSearchSpaceType]


EmbeddingSearchSpaceType: TypeAlias = generate_models_and_union_type_for_classes(  # type: ignore[valid-type]
    list(EmbeddingNodeInfo.modules_available.values())
)
EmbeddingMetrics: TypeAlias = Literal[tuple(EmbeddingNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]


class EmbeddingNodeValidator(BaseModel):
    """Search space configuration for the Embedding node."""

    node_type: NodeType = NodeType.embedding
    target_metric: EmbeddingMetrics
    metrics: list[EmbeddingMetrics] | None = None
    search_space: list[EmbeddingSearchSpaceType]


ScoringSearchSpaceType: TypeAlias = generate_models_and_union_type_for_classes(  # type: ignore[valid-type]
    list(ScoringNodeInfo.modules_available.values())
)
ScoringMetrics: TypeAlias = Literal[tuple(ScoringNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]


class ScoringNodeValidator(BaseModel):
    """Search space configuration for the Scoring node."""

    node_type: NodeType = NodeType.scoring
    target_metric: ScoringMetrics
    metrics: list[ScoringMetrics] | None = None
    search_space: list[ScoringSearchSpaceType]


RegexpSearchSpaceType: TypeAlias = generate_models_and_union_type_for_classes(  # type: ignore[valid-type]
    list(RegExpNodeInfo.modules_available.values())
)
RegexpMetrics: TypeAlias = Literal[tuple(RegExpNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]


class RegexNodeValidator(BaseModel):
    """Search space configuration for the Regexp node."""

    node_type: NodeType = NodeType.regexp
    target_metric: RegexpMetrics
    metrics: list[RegexpMetrics] | None = None
    search_space: list[RegexpSearchSpaceType]


SearchSpaceTypes: TypeAlias = RegexNodeValidator | EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator


class OptimizationConfig(RootModel[list[SearchSpaceTypes]]):
    """Optimizer configuration."""

    def __iter__(
        self,
    ) -> Iterator[SearchSpaceTypes]:
        """Iterate over the root."""
        return iter(self.root)

    def __getitem__(self, item: int) -> SearchSpaceTypes:
        """
        Get an item directly from the root.

        :param item: Index
        :return: Item
        """
        return self.root[item]
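To make the dynamically generated validators concrete, here is a sketch of round-tripping a config through OptimizationConfig; the values mirror tests/assets/configs/multilabel.yaml, and the failure case is purely illustrative:

from pydantic import ValidationError

from autointent.nodes.schemes import OptimizationConfig

config = [
    {
        "node_type": "embedding",
        "target_metric": "scoring_accuracy",
        "search_space": [
            {
                # module_name is checked against Literal[cls.name]; every other
                # parameter becomes list[...] because it is a grid of candidates.
                "module_name": "logreg_embedding",
                "cv": [2],
                "embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"],
            }
        ],
    }
]

validated = OptimizationConfig(config)
first_node = validated[0]      # __getitem__ delegates to the root list
for node in validated:         # and so does __iter__
    print(node.node_type, node.target_metric)

# An unknown module name fails every member of the generated union:
try:
    OptimizationConfig([{**config[0], "search_space": [{"module_name": "no_such_module"}]}])
except ValidationError as e:
    print(e.error_count(), "validation error(s)")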
@@ -5,7 +5,7 @@
 from random import seed, shuffle


-def main():
+def main() -> None:
     parser = ArgumentParser()
     parser.add_argument("--input-path", type=str, required=True)
     parser.add_argument("--output-path", type=str, required=True)
@@ -30,7 +30,7 @@ def main():
     json.dump(res, open(args.output_path, "w"), indent=4, ensure_ascii=False)


-def update_counter(counter: defaultdict, labels: list[int]):
+def update_counter(counter: defaultdict, labels: list[int]) -> None:
     for lab in labels:
         counter[lab] += 1
2 changes: 1 addition & 1 deletion scripts/data/make_multilabel_dataset.py
@@ -41,7 +41,7 @@ def get_multilabel_version(intent_records, config_string, seed):
     return res


-def main():
+def main() -> None:
     parser = ArgumentParser()
     parser.add_argument("--input-path", type=str, required=True, help="path to intent records")
     parser.add_argument("--output-path", type=str, required=True)
2 changes: 1 addition & 1 deletion scripts/transform_json_to_dataset.py
@@ -1,4 +1,4 @@
-from datasets import Dataset, load_dataset, DatasetDict
+from datasets import Dataset, DatasetDict, load_dataset


 def transform_dataset(
2 changes: 1 addition & 1 deletion tests/assets/configs/multilabel.yaml
@@ -1,7 +1,7 @@
 - node_type: embedding
   target_metric: scoring_accuracy
   search_space:
-    - module_name: logreg
+    - module_name: logreg_embedding
       cv: [2]
       embedder_name:
         - sentence-transformers/all-MiniLM-L6-v2
Empty file added tests/configs/__init__.py