Refactor datasets logic (#43)
Co-authored-by: voorhs <[email protected]>
truff4ut and voorhs authored Nov 25, 2024
1 parent 1cb4760 commit 4e1d43f
Showing 55 changed files with 1,623 additions and 1,781 deletions.
autointent/configs/optimization_cli.py (4 changes: 1 addition & 3 deletions)
@@ -123,8 +123,6 @@ class OptimizationConfig:
     """Configuration for the logging"""
     vector_index: VectorIndexConfig = field(default_factory=VectorIndexConfig)
     """Configuration for the vector index"""
-    augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
-    """Configuration for the augmentation"""
     embedder: EmbedderConfig = field(default_factory=EmbedderConfig)
     """Configuration for the embedder"""

@@ -133,7 +131,7 @@
             "_self_",
             {"override hydra/job_logging": "autointent_standard_job_logger"},
             {"override hydra/help": "autointent_help"},
-        ]
+        ],
     )
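For orientation, here is a minimal sketch (not part of the commit) of constructing the slimmed-down config. The field names come from the hunk above; default construction of the remaining fields is an assumption.

```python
# Hedged sketch: OptimizationConfig after the `augmentation` field was removed.
# Assumes the remaining fields all have default factories, as the visible ones do.
from autointent.configs.optimization_cli import (
    EmbedderConfig,
    OptimizationConfig,
    VectorIndexConfig,
)

config = OptimizationConfig(
    vector_index=VectorIndexConfig(),
    embedder=EmbedderConfig(),
)
# OptimizationConfig(augmentation=...) would now raise a TypeError, since
# dataclasses reject keyword arguments for fields that no longer exist.
```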


autointent/context/context.py (48 changes: 16 additions & 32 deletions)
@@ -9,14 +9,13 @@
 import yaml

 from autointent.configs.optimization_cli import (
-    AugmentationConfig,
     DataConfig,
     EmbedderConfig,
     LoggingConfig,
     VectorIndexConfig,
 )

-from .data_handler import DataAugmenter, DataHandler, Dataset
+from .data_handler import DataHandler, Dataset
 from .optimization_info import OptimizationInfo
 from .utils import NumpyEncoder, load_data
 from .vector_index_client import VectorIndex, VectorIndexClient
@@ -71,43 +70,29 @@ def configure_vector_index(self, config: VectorIndexConfig, embedder_config: Emb
             self.embedder_config.max_length,
         )

-    def configure_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
+    def configure_data(self, config: DataConfig) -> None:
         """
-        Configure data handling and augmentation.
+        Configure data handling.

         :param config: Configuration for the data handling process.
-        :param augmentation_config: Configuration for data augmentation. If None, no augmentation is applied.
-        """
-        if augmentation_config is not None:
-            self.augmentation_config = AugmentationConfig()
-            augmenter = DataAugmenter(
-                self.augmentation_config.multilabel_generation_config,
-                self.augmentation_config.regex_sampling,
-                self.seed,
-            )
-        else:
-            augmenter = None
-
+        """
         self.data_handler = DataHandler(
             dataset=load_data(config.train_path),
             test_dataset=None if config.test_path is None else load_data(config.test_path),
             random_seed=self.seed,
             force_multilabel=config.force_multilabel,
-            augmenter=augmenter,
         )

-    def set_datasets(
-        self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
-    ) -> None:
+    def set_dataset(self, dataset: Dataset, force_multilabel: bool = False) -> None:
         """
-        Set the datasets for training and validation.
+        Set the datasets for training, validation and testing.

-        :param train_data: Training dataset.
-        :param val_data: Validation dataset. If None, only training data is used.
+        :param dataset: Dataset.
         :param force_multilabel: Whether to force multilabel classification.
         """
         self.data_handler = DataHandler(
-            dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
+            dataset=dataset,
+            force_multilabel=force_multilabel,
+            random_seed=self.seed,
         )

     def get_best_index(self) -> VectorIndex:
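To make the API change concrete, a hypothetical before/after call site; only the method signatures are taken from the diff, while the `context`, `dataset`, and `data_config` objects are assumed to exist elsewhere.

```python
# Before this commit: train and validation splits were passed separately.
# context.set_datasets(train_data, val_data, force_multilabel=False)

# After this commit: a single Dataset object carries all splits.
context.set_dataset(dataset, force_multilabel=False)

# configure_data() likewise drops its augmentation_config parameter:
context.configure_data(data_config)  # data_config is a DataConfig instance
```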
@@ -159,13 +144,12 @@ def dump(self) -> None:
         with logs_path.open("w") as file:
             json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)

-        train_data, test_data = self.data_handler.dump()
-        train_path = logs_dir / "train_data.json"
-        test_path = logs_dir / "test_data.json"
-        with train_path.open("w") as file:
-            json.dump(train_data, file, indent=4, ensure_ascii=False)
-        with test_path.open("w") as file:
-            json.dump(test_data, file, indent=4, ensure_ascii=False)
         # self._logger.info(make_report(optimization_results, nodes=nodes))

+        # dump train and test data splits
+        dataset_path = logs_dir / "dataset.json"
+        with dataset_path.open("w") as file:
+            json.dump(self.data_handler.dump(), file, indent=4, ensure_ascii=False)

         self._logger.info("logs and other assets are saved to %s", logs_dir)
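A short sketch of reading the consolidated dump back (the directory path is hypothetical); before this change the splits were written to separate train_data.json and test_data.json files.

```python
import json
from pathlib import Path

logs_dir = Path("runs/example")  # hypothetical dump directory
with (logs_dir / "dataset.json").open() as file:
    dataset_splits = json.load(file)  # whatever DataHandler.dump() serialized
```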

autointent/context/data_handler/__init__.py (8 changes: 4 additions & 4 deletions)
@@ -1,5 +1,5 @@
-from .data_handler import DataAugmenter, DataHandler
-from .schemas import Dataset
-from .tags import Tag
+from .data_handler import DataHandler
+from .dataset import Dataset
+from .schemas import Intent, Sample, Tag

-__all__ = ["DataAugmenter", "DataHandler", "Dataset", "Tag"]
+__all__ = ["DataHandler", "Dataset", "Intent", "Sample", "Tag"]
(Diffs for the remaining 52 changed files are not shown.)
