Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor vector index caching logic #80

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions autointent/_dataset/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""File with Dataset definition."""

import json
from collections import defaultdict
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -48,12 +49,15 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #

self.intents = intents

self._encoded_labels = False

if self.multilabel:
self._encode_labels()

oos_split = self._create_oos_split()
if oos_split is not None:
self[Split.OOS] = oos_split

self._encoded_labels = False

@property
def multilabel(self) -> bool:
"""
Expand All @@ -71,31 +75,31 @@ def n_classes(self) -> int:

:return: Number of classes.
"""
return self.get_n_classes(Split.TRAIN)
return len(self.intents)

@classmethod
def from_json(cls, filepath: str | Path) -> "Dataset":
def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
"""
Load a dataset from a JSON file.
Load a dataset from a dictionary mapping.

:param filepath: Path to the JSON file.
:param mapping: Dictionary representing the dataset.
:return: Initialized Dataset object.
"""
from ._reader import JsonReader
from ._reader import DictReader

return JsonReader().read(filepath)
return DictReader().read(mapping)

@classmethod
def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
def from_json(cls, filepath: str | Path) -> "Dataset":
"""
Load a dataset from a dictionary mapping.
Load a dataset from a JSON file.

:param mapping: Dictionary representing the dataset.
:param filepath: Path to the JSON file.
:return: Initialized Dataset object.
"""
from ._reader import DictReader
from ._reader import JsonReader

return DictReader().read(mapping)
return JsonReader().read(filepath)

@classmethod
def from_hub(cls, repo_id: str) -> "Dataset":
Expand All @@ -113,34 +117,35 @@ def from_hub(cls, repo_id: str) -> "Dataset":
intents=[Intent.model_validate(intent) for intent in intents],
)

def dump(self) -> dict[str, list[dict[str, Any]]]:
def to_multilabel(self) -> "Dataset":
"""
Convert the dataset splits to a dictionary of lists.
Convert dataset labels to multilabel format.

:return: Dictionary containing dataset splits as lists.
:return: Self, with labels converted to multilabel.
"""
return {split_name: split.to_list() for split_name, split in self.items()}
for split_name, split in self.items():
self[split_name] = split.map(self._to_multilabel)
self._encode_labels()
return self

def encode_labels(self) -> "Dataset":
def to_dict(self) -> dict[str, list[dict[str, Any]]]:
"""
Encode dataset labels into one-hot or multilabel format.
Convert the dataset splits and intents to a dictionary of lists.

:return: Self, with labels encoded.
:return: A dictionary containing dataset splits and intents as lists of dictionaries.
"""
for split_name, split in self.items():
self[split_name] = split.map(self._encode_label)
self._encoded_labels = True
return self
mapping = {split_name: split.to_list() for split_name, split in self.items()}
mapping[Split.INTENTS] = [intent.model_dump() for intent in self.intents]
return mapping

def to_multilabel(self) -> "Dataset":
def to_json(self, filepath: str | Path) -> None:
"""
Convert dataset labels to multilabel format.
Save the dataset splits and intents to a JSON file.

:return: Self, with labels converted to multilabel.
:param filepath: The path to the file where the JSON data will be saved.
"""
for split_name, split in self.items():
self[split_name] = split.map(self._to_multilabel)
return self
with Path(filepath).open("w") as file:
json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)

def push_to_hub(self, repo_id: str, private: bool = False) -> None:
"""
Expand Down Expand Up @@ -188,6 +193,17 @@ def get_n_classes(self, split: str) -> int:
classes.add(idx)
return len(classes)

def _encode_labels(self) -> "Dataset":
    """
    Encode dataset labels into one-hot or multilabel format.

    Every split is replaced by the result of mapping ``self._encode_label``
    over it, after which the dataset is flagged as having encoded labels.

    :return: Self, with labels encoded.
    """
    # Snapshot the (name, split) pairs first so the loop never walks a live
    # view of the mapping it is writing back into.
    for split_name, split in list(self.items()):
        self[split_name] = split.map(self._encode_label)
    self._encoded_labels = True
    return self

def _is_oos(self, sample: Sample) -> bool:
"""
Check if a sample is out-of-scope.
Expand Down
77 changes: 70 additions & 7 deletions autointent/_dataset/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,22 @@ class DatasetReader(BaseModel):
"""
A class to represent a dataset reader for handling training, validation, and test data.

:param train: List of samples for training.
:param train: List of samples for training. Defaults to an empty list.
:param train_0: List of samples for scoring module training. Defaults to an empty list.
:param train_1: List of samples for decision module training. Defaults to an empty list.
:param validation: List of samples for validation. Defaults to an empty list.
:param validation_0: List of samples for scoring module validation. Defaults to an empty list.
:param validation_1: List of samples for decision module validation. Defaults to an empty list.
:param test: List of samples for testing. Defaults to an empty list.
:param intents: List of intents associated with the dataset.
"""

train: list[Sample]
train: list[Sample] = []
train_0: list[Sample] = []
train_1: list[Sample] = []
validation: list[Sample] = []
validation_0: list[Sample] = []
validation_1: list[Sample] = []
test: list[Sample] = []
intents: list[Intent] = []

Expand All @@ -27,19 +36,75 @@ def validate_dataset(self) -> "DatasetReader":
:raises ValueError: If intents or samples are not properly validated.
:return: The validated DatasetReader instance.
"""
self._validate_intents()
for split in [self.train, self.test]:
if self.train and (self.train_0 or self.train_1):
message = "If `train` is provided, `train_0` and `train_1` should be empty."
raise ValueError(message)
if not self.train and (not self.train_0 or not self.train_1):
message = "Both `train_0` and `train_1` must be provided if `train` is empty."
raise ValueError(message)

if self.validation and (self.validation_0 or self.validation_1):
message = "If `validation` is provided, `validation_0` and `validation_1` should be empty."
raise ValueError(message)
if not self.validation:
message = "Either both `validation_0` and `validation_1` must be provided, or neither of them."
if not self.validation_0 and self.validation_1:
raise ValueError(message)
if self.validation_0 and not self.validation_1:
raise ValueError(message)

splits = [
self.train, self.train_0, self.train_1, self.validation, self.validation_0, self.validation_1, self.test,
]
splits = [split for split in splits if split]

n_classes = [self._get_n_classes(split) for split in splits]
if len(set(n_classes)) != 1:
message = (
f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
"Ensure all splits have the same number of classes."
)
Comment on lines +62 to +66
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

здесь можно сделать более подробное сообщение в духе перечислить сплиты и сколько классов в каждом найдено

raise ValueError(message)
if not n_classes[0]:
message = (
"Number of classes is zero or undefined. "
"Ensure at least one class is present in the splits."
)
raise ValueError(message)

self._validate_intents(n_classes[0])

for split in splits:
self._validate_split(split)
return self

def _validate_intents(self) -> "DatasetReader":
def _get_n_classes(self, split: list[Sample]) -> int:
"""
Get the number of classes in a dataset split.

:param split: List of samples in a dataset split (train, validation, or test).
:return: The number of classes.
"""
classes = set()
for sample in split:
match sample.label:
case int():
classes.add(sample.label)
case list():
for label in sample.label:
classes.add(label)
return len(classes)

def _validate_intents(self, n_classes: int) -> "DatasetReader":
"""
Validate the intents by checking their IDs for sequential order.

:param n_classes: The number of classes in the dataset.
:raises ValueError: If intent IDs are not sequential starting from 0.
:return: The DatasetReader instance after validation.
"""
if not self.intents:
self.intents = [Intent(id=idx) for idx in range(n_classes)]
return self
self.intents = sorted(self.intents, key=lambda intent: intent.id)
intent_ids = [intent.id for intent in self.intents]
Expand All @@ -59,8 +124,6 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader":
:raises ValueError: If a sample references an invalid or non-existent intent ID.
:return: The DatasetReader instance after validation.
"""
if not split or not self.intents:
return self
intent_ids = {intent.id for intent in self.intents}
for sample in split:
message = (
Expand Down
4 changes: 1 addition & 3 deletions autointent/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def dump(self) -> None:
# self._logger.info(make_report(optimization_results, nodes=nodes))

# dump train and test data splits
dataset_path = logs_dir / "dataset.json"
with dataset_path.open("w") as file:
json.dump(self.data_handler.dump(), file, indent=4, ensure_ascii=False)
self.data_handler.dump(logs_dir / "dataset.json")

self._logger.info("logs and other assets are saved to %s", logs_dir)

Expand Down
2 changes: 1 addition & 1 deletion autointent/context/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def load_data(filepath: str | Path) -> Dataset:
if filepath == "default-multiclass":
return Dataset.from_hub("AutoIntent/clinc150_subset")
if filepath == "default-multilabel":
return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel().encode_labels()
return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel()
if not Path(filepath).exists():
return Dataset.from_hub(str(filepath))
return Dataset.from_json(filepath)
Loading
Loading