diff --git a/autointent/_dataset/_dataset.py b/autointent/_dataset/_dataset.py
index 9b7ccb04..fb775b07 100644
--- a/autointent/_dataset/_dataset.py
+++ b/autointent/_dataset/_dataset.py
@@ -1,5 +1,6 @@
 """File with Dataset definition."""
 
+import json
 from collections import defaultdict
 from functools import cached_property
 from pathlib import Path
@@ -48,12 +49,15 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #
         self.intents = intents
 
+        self._encoded_labels = False
+
+        if self.multilabel:
+            self._encode_labels()
+
         oos_split = self._create_oos_split()
         if oos_split is not None:
             self[Split.OOS] = oos_split
 
-        self._encoded_labels = False
-
     @property
     def multilabel(self) -> bool:
         """
@@ -71,31 +75,31 @@ def n_classes(self) -> int:
         """
 
         :return: Number of classes.
         """
-        return self.get_n_classes(Split.TRAIN)
+        return len(self.intents)
 
     @classmethod
-    def from_json(cls, filepath: str | Path) -> "Dataset":
+    def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
         """
-        Load a dataset from a JSON file.
+        Load a dataset from a dictionary mapping.
 
-        :param filepath: Path to the JSON file.
+        :param mapping: Dictionary representing the dataset.
         :return: Initialized Dataset object.
         """
-        from ._reader import JsonReader
+        from ._reader import DictReader
 
-        return JsonReader().read(filepath)
+        return DictReader().read(mapping)
 
     @classmethod
-    def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
+    def from_json(cls, filepath: str | Path) -> "Dataset":
         """
-        Load a dataset from a dictionary mapping.
+        Load a dataset from a JSON file.
 
-        :param mapping: Dictionary representing the dataset.
+        :param filepath: Path to the JSON file.
         :return: Initialized Dataset object.
         """
-        from ._reader import DictReader
+        from ._reader import JsonReader
 
-        return DictReader().read(mapping)
+        return JsonReader().read(filepath)
 
     @classmethod
     def from_hub(cls, repo_id: str) -> "Dataset":
@@ -113,34 +117,35 @@ def from_hub(cls, repo_id: str) -> "Dataset":
             intents=[Intent.model_validate(intent) for intent in intents],
         )
 
-    def dump(self) -> dict[str, list[dict[str, Any]]]:
+    def to_multilabel(self) -> "Dataset":
         """
-        Convert the dataset splits to a dictionary of lists.
+        Convert dataset labels to multilabel format.
 
-        :return: Dictionary containing dataset splits as lists.
+        :return: Self, with labels converted to multilabel.
         """
-        return {split_name: split.to_list() for split_name, split in self.items()}
+        for split_name, split in self.items():
+            self[split_name] = split.map(self._to_multilabel)
+        self._encode_labels()
+        return self
 
-    def encode_labels(self) -> "Dataset":
+    def to_dict(self) -> dict[str, list[dict[str, Any]]]:
         """
-        Encode dataset labels into one-hot or multilabel format.
+        Convert the dataset splits and intents to a dictionary of lists.
 
-        :return: Self, with labels encoded.
+        :return: A dictionary containing dataset splits and intents as lists of dictionaries.
         """
-        for split_name, split in self.items():
-            self[split_name] = split.map(self._encode_label)
-        self._encoded_labels = True
-        return self
+        mapping = {split_name: split.to_list() for split_name, split in self.items()}
+        mapping[Split.INTENTS] = [intent.model_dump() for intent in self.intents]
+        return mapping
 
-    def to_multilabel(self) -> "Dataset":
+    def to_json(self, filepath: str | Path) -> None:
         """
-        Convert dataset labels to multilabel format.
+        Save the dataset splits and intents to a JSON file.
 
-        :return: Self, with labels converted to multilabel.
+        :param filepath: The path to the file where the JSON data will be saved.
         """
-        for split_name, split in self.items():
-            self[split_name] = split.map(self._to_multilabel)
-        return self
+        with Path(filepath).open("w") as file:
+            json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)
 
     def push_to_hub(self, repo_id: str, private: bool = False) -> None:
         """
@@ -188,6 +193,17 @@ def get_n_classes(self, split: str) -> int:
                     classes.add(idx)
         return len(classes)
 
+    def _encode_labels(self) -> "Dataset":
+        """
+        Encode dataset labels into one-hot or multilabel format.
+
+        :return: Self, with labels encoded.
+        """
+        for split_name, split in self.items():
+            self[split_name] = split.map(self._encode_label)
+        self._encoded_labels = True
+        return self
+
     def _is_oos(self, sample: Sample) -> bool:
         """
         Check if a sample is out-of-scope.
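The public `dump()`/`encode_labels()` pair on `Dataset` is replaced here by `to_dict()`/`to_json()`, which serialize the splits together with the intents, while `to_multilabel()` now encodes labels itself. A minimal round trip under the new API (the payload below is hypothetical; the split and intent keys follow the `DatasetReader` schema validated in the next file):

    from autointent import Dataset

    mapping = {
        "train": [
            {"utterance": "hello", "label": 0},
            {"utterance": "goodbye", "label": 1},
        ],
        "test": [
            {"utterance": "hi there", "label": 0},
            {"utterance": "farewell", "label": 1},
        ],
        "intents": [{"id": 0}, {"id": 1}],  # optional; synthesized from labels when omitted
    }

    dataset = Dataset.from_dict(mapping)
    dataset = dataset.to_multilabel()  # labels are encoded automatically now
    dataset.to_json("dataset.json")    # dumps the splits plus the intents entry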
diff --git a/autointent/_dataset/_validation.py b/autointent/_dataset/_validation.py
index ae92f0ad..ba282406 100644
--- a/autointent/_dataset/_validation.py
+++ b/autointent/_dataset/_validation.py
@@ -9,13 +9,22 @@ class DatasetReader(BaseModel):
     """
     A class to represent a dataset reader for handling training, validation, and test data.
 
-    :param train: List of samples for training.
+    :param train: List of samples for training. Defaults to an empty list.
+    :param train_0: List of samples for scoring module training. Defaults to an empty list.
+    :param train_1: List of samples for decision module training. Defaults to an empty list.
     :param validation: List of samples for validation. Defaults to an empty list.
+    :param validation_0: List of samples for scoring module validation. Defaults to an empty list.
+    :param validation_1: List of samples for decision module validation. Defaults to an empty list.
     :param test: List of samples for testing. Defaults to an empty list.
     :param intents: List of intents associated with the dataset.
     """
 
-    train: list[Sample]
+    train: list[Sample] = []
+    train_0: list[Sample] = []
+    train_1: list[Sample] = []
+    validation: list[Sample] = []
+    validation_0: list[Sample] = []
+    validation_1: list[Sample] = []
     test: list[Sample] = []
     intents: list[Intent] = []
 
@@ -27,19 +36,75 @@ def validate_dataset(self) -> "DatasetReader":
         :raises ValueError: If intents or samples are not properly validated.
         :return: The validated DatasetReader instance.
         """
-        self._validate_intents()
-        for split in [self.train, self.test]:
+        if self.train and (self.train_0 or self.train_1):
+            message = "If `train` is provided, `train_0` and `train_1` should be empty."
+            raise ValueError(message)
+        if not self.train and (not self.train_0 or not self.train_1):
+            message = "Both `train_0` and `train_1` must be provided if `train` is empty."
+            raise ValueError(message)
+
+        if self.validation and (self.validation_0 or self.validation_1):
+            message = "If `validation` is provided, `validation_0` and `validation_1` should be empty."
+            raise ValueError(message)
+        if not self.validation:
+            message = "Either both `validation_0` and `validation_1` must be provided, or neither of them."
+            if not self.validation_0 and self.validation_1:
+                raise ValueError(message)
+            if self.validation_0 and not self.validation_1:
+                raise ValueError(message)
+
+        splits = [
+            self.train, self.train_0, self.train_1, self.validation, self.validation_0, self.validation_1, self.test,
+        ]
+        splits = [split for split in splits if split]
+
+        n_classes = [self._get_n_classes(split) for split in splits]
+        if len(set(n_classes)) != 1:
+            message = (
+                f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
+                "Ensure all splits have the same number of classes."
+            )
+            raise ValueError(message)
+        if not n_classes[0]:
+            message = (
+                "Number of classes is zero or undefined. "
+                "Ensure at least one class is present in the splits."
+            )
+            raise ValueError(message)
+
+        self._validate_intents(n_classes[0])
+
+        for split in splits:
             self._validate_split(split)
         return self
 
-    def _validate_intents(self) -> "DatasetReader":
+    def _get_n_classes(self, split: list[Sample]) -> int:
+        """
+        Get the number of classes in a dataset split.
+
+        :param split: List of samples in a dataset split (train, validation, or test).
+        :return: The number of classes.
+        """
+        classes = set()
+        for sample in split:
+            match sample.label:
+                case int():
+                    classes.add(sample.label)
+                case list():
+                    for label in sample.label:
+                        classes.add(label)
+        return len(classes)
+
+    def _validate_intents(self, n_classes: int) -> "DatasetReader":
         """
         Validate the intents by checking their IDs for sequential order.
 
+        :param n_classes: The number of classes in the dataset.
         :raises ValueError: If intent IDs are not sequential starting from 0.
         :return: The DatasetReader instance after validation.
         """
         if not self.intents:
+            self.intents = [Intent(id=idx) for idx in range(n_classes)]
             return self
         self.intents = sorted(self.intents, key=lambda intent: intent.id)
         intent_ids = [intent.id for intent in self.intents]
@@ -59,8 +124,6 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader":
         :raises ValueError: If a sample references an invalid or non-existent intent ID.
         :return: The DatasetReader instance after validation.
         """
-        if not split or not self.intents:
-            return self
         intent_ids = {intent.id for intent in self.intents}
         for sample in split:
             message = (
diff --git a/autointent/context/_context.py b/autointent/context/_context.py
index 452d3bc4..cfc2a140 100644
--- a/autointent/context/_context.py
+++ b/autointent/context/_context.py
@@ -137,9 +137,7 @@ def dump(self) -> None:
         # self._logger.info(make_report(optimization_results, nodes=nodes))
 
         # dump train and test data splits
-        dataset_path = logs_dir / "dataset.json"
-        with dataset_path.open("w") as file:
-            json.dump(self.data_handler.dump(), file, indent=4, ensure_ascii=False)
+        self.data_handler.dump(logs_dir / "dataset.json")
 
         self._logger.info("logs and other assets are saved to %s", logs_dir)
diff --git a/autointent/context/_utils.py b/autointent/context/_utils.py
index 5f8d7e3b..fd2ac99f 100644
--- a/autointent/context/_utils.py
+++ b/autointent/context/_utils.py
@@ -56,7 +56,7 @@ def load_data(filepath: str | Path) -> Dataset:
     if filepath == "default-multiclass":
         return Dataset.from_hub("AutoIntent/clinc150_subset")
     if filepath == "default-multilabel":
-        return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel().encode_labels()
+        return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel()
     if not Path(filepath).exists():
         return Dataset.from_hub(str(filepath))
     return Dataset.from_json(filepath)
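The validator now enforces an either/or contract: supply `train` or the pre-split `train_0`/`train_1` pair, never both (likewise for validation), and every non-empty split must agree on the number of classes. A sketch of the contract, assuming `Sample` carries `utterance` and `label` fields as in the test fixtures further down:

    from autointent._dataset._validation import DatasetReader

    # Valid: `train` is empty and both pre-split halves are given; since no
    # intents are passed, they are synthesized as Intent(id=0), Intent(id=1).
    DatasetReader(
        train_0=[{"utterance": "hi", "label": 0}, {"utterance": "bye", "label": 1}],
        train_1=[{"utterance": "hey", "label": 0}, {"utterance": "ciao", "label": 1}],
    )

    # Invalid: mixing the aggregate and pre-split forms raises
    # ValueError("If `train` is provided, `train_0` and `train_1` should be empty.")
    DatasetReader(
        train=[{"utterance": "hi", "label": 0}, {"utterance": "bye", "label": 1}],
        train_0=[{"utterance": "hey", "label": 0}],
    )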
diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py
index d3ff2492..491fd312 100644
--- a/autointent/context/data_handler/_data_handler.py
+++ b/autointent/context/data_handler/_data_handler.py
@@ -1,8 +1,10 @@
 """Data Handler file."""
 
 import logging
-from typing import Any, TypedDict, cast
+from pathlib import Path
+from typing import TypedDict, cast
 
+from datasets import concatenate_datasets
 from transformers import set_seed
 
 from autointent import Dataset
@@ -45,8 +47,6 @@ def __init__(
         self.dataset = dataset
         if force_multilabel:
             self.dataset = self.dataset.to_multilabel()
-        if self.dataset.multilabel:
-            self.dataset = self.dataset.encode_labels()
 
         self.n_classes = self.dataset.n_classes
@@ -183,23 +183,49 @@ def has_oos_samples(self) -> bool:
         """
         return any(split.startswith(Split.OOS) for split in self.dataset)
 
-    def dump(self) -> dict[str, list[dict[str, Any]]]:
+    def dump(self, filepath: str | Path) -> None:
         """
-        Dump the dataset splits.
+        Save the dataset splits and intents to a JSON file.
 
-        :return: Dataset dump.
+        :param filepath: The path to the file where the JSON data will be saved.
         """
-        return self.dataset.dump()
+        self.dataset.to_json(filepath)
 
     def _split(self, random_seed: int) -> None:
+        has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
+        has_test_split = any(split.startswith(Split.TEST) for split in self.dataset)
+
+        if Split.TRAIN in self.dataset:
+            self._split_train(random_seed)
+
         if Split.TEST not in self.dataset:
-            self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
-                self.dataset,
-                split=Split.TRAIN,
-                test_size=0.2,
-                random_seed=random_seed,
-            )
+            test_size = 0.1 if has_validation_split else 0.2
+            self._split_test(test_size, random_seed)
+
+        if not has_validation_split:
+            if not has_test_split:
+                self._split_validation_from_test(random_seed)
+                self._split_validation(random_seed)
+            else:
+                self._split_validation_from_train(random_seed)
+        elif Split.VALIDATION in self.dataset:
+            self._split_validation(random_seed)
+
+        if self.has_oos_samples():
+            self._split_oos(random_seed)
+
+        for split in self.dataset:
+            if split.startswith(Split.OOS):
+                continue
+            n_classes_split = self.dataset.get_n_classes(split)
+            if n_classes_split != self.n_classes:
+                message = (
+                    f"Number of classes in split '{split}' doesn't match initial number of classes "
+                    f"({n_classes_split} != {self.n_classes})"
+                )
+                raise ValueError(message)
+
+    def _split_train(self, random_seed: int) -> None:
         self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
             self.dataset,
             split=Split.TRAIN,
@@ -208,6 +234,24 @@ def _split(self, random_seed: int) -> None:
         )
         self.dataset.pop(Split.TRAIN)
 
+    def _split_validation(self, random_seed: int) -> None:
+        self.dataset[f"{Split.VALIDATION}_0"], self.dataset[f"{Split.VALIDATION}_1"] = split_dataset(
+            self.dataset,
+            split=Split.VALIDATION,
+            test_size=0.5,
+            random_seed=random_seed,
+        )
+        self.dataset.pop(Split.VALIDATION)
+
+    def _split_validation_from_test(self, random_seed: int) -> None:
+        self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset(
+            self.dataset,
+            split=Split.TEST,
+            test_size=0.5,
+            random_seed=random_seed,
+        )
+
+    def _split_validation_from_train(self, random_seed: int) -> None:
         for idx in range(2):
             self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
                 self.dataset,
@@ -216,34 +260,42 @@ def _split(self, random_seed: int) -> None:
                 random_seed=random_seed,
             )
 
-        if self.has_oos_samples():
-            self.dataset[f"{Split.OOS}_0"], self.dataset[f"{Split.OOS}_1"] = (
-                self.dataset[Split.OOS]
-                .train_test_split(
-                    test_size=0.2,
-                    shuffle=True,
-                    seed=random_seed,
-                )
-                .values()
+    def _split_test(self, test_size: float, random_seed: int) -> None:
+        self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
+            self.dataset,
+            split=f"{Split.TRAIN}_0",
+            test_size=test_size,
+            random_seed=random_seed,
+        )
+        self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset(
+            self.dataset,
+            split=f"{Split.TRAIN}_1",
+            test_size=test_size,
+            random_seed=random_seed,
+        )
+        self.dataset[Split.TEST] = concatenate_datasets(
+            [self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]],
+        )
+        self.dataset.pop(f"{Split.TEST}_0")
+        self.dataset.pop(f"{Split.TEST}_1")
+
+    def _split_oos(self, random_seed: int) -> None:
+        self.dataset[f"{Split.OOS}_0"], self.dataset[f"{Split.OOS}_1"] = (
+            self.dataset[Split.OOS]
+            .train_test_split(
+                test_size=0.2,
+                shuffle=True,
+                seed=random_seed,
             )
-        self.dataset[f"{Split.OOS}_1"], self.dataset[f"{Split.OOS}_2"] = (
-            self.dataset[f"{Split.OOS}_1"]
-            .train_test_split(
-                test_size=0.5,
-                shuffle=True,
-                seed=random_seed,
-            )
-            .values()
+            .values()
+        )
+        self.dataset[f"{Split.OOS}_1"], self.dataset[f"{Split.OOS}_2"] = (
+            self.dataset[f"{Split.OOS}_1"]
+            .train_test_split(
+                test_size=0.5,
+                shuffle=True,
+                seed=random_seed,
            )
-        self.dataset.pop(Split.OOS)
-
-        for split in self.dataset:
-            if split.startswith(Split.OOS):
-                continue
-            n_classes_split = self.dataset.get_n_classes(split)
-            if n_classes_split != self.n_classes:
-                message = (
-                    f"Number of classes in split '{split}' doesn't match initial number of classes "
-                    f"({n_classes_split} != {self.n_classes})"
-                )
-                raise ValueError(message)
+            .values()
+        )
+        self.dataset.pop(Split.OOS)
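The monolithic `_split` becomes a dispatcher over the new helpers; which ones run depends on the splits the dataset arrived with. Roughly, with ratios taken from this diff where visible (the `_split_train` and `_split_validation_from_train` ratios sit outside the shown context):

    # train               -> train_0 / train_1                      (_split_train)
    # test missing        -> carved from each train half, then the two
    #                        pieces are concatenated into one test   (_split_test,
    #                        test_size 0.1 if validation exists, else 0.2)
    # validation missing  -> halved out of test when test was also missing
    #                        (_split_validation_from_test + _split_validation),
    #                        otherwise carved from each train half
    # validation present  -> split into validation_0 / validation_1
    # oos                 -> oos_0 (80%), oos_1 (10%), oos_2 (10%)   (_split_oos)
    # finally, every non-OOS split must still contain all n_classes classes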
""" - if embedder_name is None: - embedder_name = context.optimization_info.get_best_embedder() - precomputed_embeddings = True - else: - precomputed_embeddings = context.vector_index_client.exists(embedder_name) - - instance = cls( + return cls( temperature=temperature, embedder_device=context.get_device(), - embedder_name=embedder_name, + embedder_name=embedder_name if embedder_name is not None else context.optimization_info.get_best_embedder(), embedder_use_cache=context.get_use_cache(), ) - instance.precomputed_embeddings = precomputed_embeddings - instance.db_dir = str(context.get_db_dir()) - return instance def get_embedder_name(self) -> str: """ @@ -136,30 +117,13 @@ def fit( self.n_classes = len(set(labels)) self.multilabel = False - if self.precomputed_embeddings: - # this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization - vector_index_client = VectorIndexClient( - self.embedder_device, - self.db_dir, - self.batch_size, - self.max_length, - self.embedder_use_cache, - ) - vector_index = vector_index_client.get_index(self.embedder_name) - features = vector_index.get_all_embeddings() - if len(features) != len(utterances): - msg = "Vector index mismatches provided utterances" - raise ValueError(msg) - embedder = vector_index.embedder - else: - embedder = Embedder( - device=self.embedder_device, - model_name=self.embedder_name, - batch_size=self.batch_size, - max_length=self.max_length, - use_cache=self.embedder_use_cache, - ) - features = embedder.embed(utterances) + embedder = Embedder( + device=self.embedder_device, + model_name=self.embedder_name, + batch_size=self.batch_size, + max_length=self.max_length, + use_cache=self.embedder_use_cache, + ) if any(description is None for description in descriptions): error_text = ( @@ -169,7 +133,7 @@ def fit( raise ValueError(error_text) self.description_vectors = embedder.embed([desc for desc in descriptions if desc]) - self.embedder = embedder + self._embedder = embedder def predict(self, utterances: list[str]) -> NDArray[np.float64]: """ @@ -178,7 +142,7 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]: :param utterances: List of utterances to score. :return: Array of probabilities for each utterance. """ - utterance_vectors = self.embedder.embed(utterances) + utterance_vectors = self._embedder.embed(utterances) similarities: NDArray[np.float64] = cosine_similarity(utterance_vectors, self.description_vectors) if self.multilabel: @@ -189,7 +153,7 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]: def clear_cache(self) -> None: """Clear cached data in memory used by the embedder.""" - self.embedder.clear_ram() + self._embedder.clear_ram() def dump(self, path: str) -> None: """ @@ -198,7 +162,6 @@ def dump(self, path: str) -> None: :param path: Path to the directory where assets will be dumped. 
""" self.metadata = DescriptionScorerDumpMetadata( - db_dir=str(self.db_dir), n_classes=self.n_classes, multilabel=self.multilabel, batch_size=self.batch_size, @@ -210,7 +173,7 @@ def dump(self, path: str) -> None: json.dump(self.metadata, file, indent=4) np.save(dump_dir / self.weights_file_name, self.description_vectors) - self.embedder.dump(dump_dir / self.embedding_model_subdir) + self._embedder.dump(dump_dir / self.embedding_model_subdir) def load(self, path: str) -> None: """ @@ -229,7 +192,7 @@ def load(self, path: str) -> None: self.multilabel = self.metadata["multilabel"] embedder_dir = dump_dir / self.embedding_model_subdir - self.embedder = Embedder( + self._embedder = Embedder( device=self.embedder_device, model_name=embedder_dir, batch_size=self.metadata["batch_size"], diff --git a/autointent/modules/scoring/_linear.py b/autointent/modules/scoring/_linear.py index 74b630f0..fd366d84 100644 --- a/autointent/modules/scoring/_linear.py +++ b/autointent/modules/scoring/_linear.py @@ -11,7 +11,6 @@ from sklearn.multioutput import MultiOutputClassifier from autointent import Context, Embedder -from autointent.context.vector_index_client import VectorIndexClient from autointent.custom_types import BaseMetadataDict, LabelType from autointent.modules.abc import ScoringModule @@ -39,8 +38,6 @@ class LinearScorer(ScoringModule): :ivar classifier_file_name: Filename for saving the classifier to disk. :ivar embedding_model_subdir: Directory for saving the embedding model. - :ivar precomputed_embeddings: Flag indicating if embeddings are precomputed. - :ivar db_dir: Path to the database directory. :ivar name: Name of the scorer, defaults to "linear". Example @@ -67,8 +64,6 @@ class LinearScorer(ScoringModule): classifier_file_name: str = "classifier.joblib" embedding_model_subdir: str = "embedding_model" - precomputed_embeddings: bool = False - db_dir: str name = "linear" def __init__( @@ -116,23 +111,14 @@ def from_context( :param embedder_name: Name of the embedder, or None to use the best embedder. :return: Initialized LinearScorer instance. 
""" - if embedder_name is None: - embedder_name = context.optimization_info.get_best_embedder() - precomputed_embeddings = True - else: - precomputed_embeddings = context.vector_index_client.exists(embedder_name) - - instance = cls( - embedder_name=embedder_name, + return cls( + embedder_name=embedder_name if embedder_name else context.optimization_info.get_best_embedder(), embedder_device=context.get_device(), seed=context.seed, batch_size=context.get_batch_size(), max_length=context.get_max_length(), embedder_use_cache=context.get_use_cache(), ) - instance.precomputed_embeddings = precomputed_embeddings - instance.db_dir = str(context.get_db_dir()) - return instance def get_embedder_name(self) -> str: """ @@ -156,30 +142,14 @@ def fit( """ self._multilabel = isinstance(labels[0], list) - if self.precomputed_embeddings: - # this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization - vector_index_client = VectorIndexClient( - self.embedder_device, - self.db_dir, - self.batch_size, - self.max_length, - self.embedder_use_cache, - ) - vector_index = vector_index_client.get_index(self.embedder_name) - features = vector_index.get_all_embeddings() - if len(features) != len(utterances): - msg = "Vector index mismatches provided utterances" - raise ValueError(msg) - embedder = vector_index.embedder - else: - embedder = Embedder( - device=self.embedder_device, - model_name=self.embedder_name, - batch_size=self.batch_size, - max_length=self.max_length, - use_cache=self.embedder_use_cache, - ) - features = embedder.embed(utterances) + embedder = Embedder( + device=self.embedder_device, + model_name=self.embedder_name, + batch_size=self.batch_size, + max_length=self.max_length, + use_cache=self.embedder_use_cache, + ) + features = embedder.embed(utterances) if self._multilabel: base_clf = LogisticRegression() diff --git a/tests/context/datahandler/test_data_handler.py b/tests/context/datahandler/test_data_handler.py index 44d16e83..38f898af 100644 --- a/tests/context/datahandler/test_data_handler.py +++ b/tests/context/datahandler/test_data_handler.py @@ -91,26 +91,6 @@ def test_data_handler_multilabel_mode(sample_multilabel_data): assert handler.test_labels() == [[1, 0], [0, 1]] -def test_dump_method(sample_multiclass_data): - handler = DataHandler(dataset=Dataset.from_dict(sample_multiclass_data), random_seed=42) - - dump = handler.dump() - - for split in ["train_0", "validation_0", "train_1", "validation_1", "test"]: - assert split in dump - - assert dump["train_0"] == [ - {"utterance": "hello", "label": 0}, - {"utterance": "bye", "label": 1}, - {"utterance": "hi", "label": 0}, - {"utterance": "take care", "label": 1}, - ] - assert dump["test"] == [ - {"utterance": "greetings", "label": 0}, - {"utterance": "farewell", "label": 1}, - ] - - @pytest.mark.skip("All data validations will be refactored later") def test_error_handling( sample_multiclass_intent_records, diff --git a/tests/context/datahandler/test_stratificaiton.py b/tests/context/datahandler/test_stratificaiton.py index 2bf4cd0b..5773da2f 100644 --- a/tests/context/datahandler/test_stratificaiton.py +++ b/tests/context/datahandler/test_stratificaiton.py @@ -18,7 +18,7 @@ def test_train_test_split(dataset): def test_multilabel_train_test_split(dataset): - dataset = dataset.to_multilabel().encode_labels() + dataset = dataset.to_multilabel() dataset[Split.TRAIN], dataset[Split.TEST] = split_dataset( dataset, split=Split.TRAIN,