diff --git a/fast_llm/config.py b/fast_llm/config.py
index f1c88965..ba7ce47e 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -9,7 +9,16 @@
 import yaml
 
-from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value
+from fast_llm.utils import (
+    Assert,
+    Registry,
+    Tag,
+    get_type_name,
+    header,
+    log,
+    pop_nested_dict_value,
+    set_nested_dict_value,
+)
 
 logger = logging.getLogger(__name__)
@@ -634,17 +643,17 @@ def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None:
         value = str(value)
         return value
 
-    def to_copy[
-        T
-    ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T:
+    def to_copy[T](
+        self: T,
+        *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
+        strict: bool = True,
+    ) -> T:
         return self.from_dict(self, *updates, strict=strict)
 
     def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]:
         return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True)
 
-    def to_logs[
-        T
-    ](
+    def to_logs[T](
         self,
         verbose: int | None = FieldVerboseLevel.core,
         log_fn: typing.Callable[[str], T] = logger.info,
diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 2c4311c3..2f1a24b3 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -1,4 +1,3 @@
-import os
 import pathlib
 import typing
 
@@ -8,6 +7,9 @@
 from fast_llm.engine.config_utils.data_type import DataType
 from fast_llm.utils import Assert
 
+from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
+from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import HFProcessorConfig, ProcessorsConfig
+
 if typing.TYPE_CHECKING:
     from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
 
 MEMMAP_DTYPES = {
@@ -77,39 +79,6 @@ class GPTHuggingfaceDatasetConfig(Config):
     )
 
 
-@config_class
-class DatasetPreparatorDistributedConfig(Config):
-    # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
-
-    default_world_size: typing.ClassVar[int] = int(os.environ.get("WORLD_SIZE", 1))
-    default_rank: typing.ClassVar[int] = int(os.environ.get("RANK", 0))
-    world_size: int = Field(
-        default=None,
-        desc="Size of the world group. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.",
-        hint=FieldHint.expert,
-        valid=check_field(Assert.gt, 0),
-    )
-    rank: int = Field(
-        default=None,
-        desc="Rank of the local process. Typically provided by torchrun or equivalent through the `RANK` environment variable.",
-        hint=FieldHint.expert,
-        valid=check_field(Assert.geq, 0),
-    )
-    backend: str = Field(
-        default="gloo",
-        desc="Distributed backend to use.",
-        hint=FieldHint.optional,
-    )
-
-    def _validate(self) -> None:
-        if self.world_size is None:
-            self.world_size = self.default_world_size
-        if self.rank is None:
-            self.rank = self.default_rank
-        super()._validate()
-        Assert.in_range(self.rank, 0, self.world_size)
-
-
 @config_class()
 class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
     preparator_name: typing.ClassVar[str] = "gpt_memmap"
@@ -165,12 +134,23 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
         hint=FieldHint.optional,
     )
 
+    # TODO: Add desc and hint.
+    processors: ProcessorsConfig = Field(default=ProcessorsConfig)
+
     def _validate(self) -> None:
         assert self.tokenizer.path is not None
         if self.dataset.data_type is not None:
             Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
         super()._validate()
 
+        # Propagate the dataset field name and worker count to processor configs that do not set them.
+        for processor_config_field_name in self.processors.get_processor_types_map().keys():
+            config: HFProcessorConfig = getattr(self.processors, processor_config_field_name)
+            if config.field is None:
+                config.field = self.dataset.field
+            if config.num_proc is None:
+                config.num_proc = self.tokenize_workers
+
     @classmethod
     def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]:
         from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
diff --git a/fast_llm/data/preparator/gpt_memmap/distributed_config.py b/fast_llm/data/preparator/gpt_memmap/distributed_config.py
new file mode 100644
index 00000000..7c653a4b
--- /dev/null
+++ b/fast_llm/data/preparator/gpt_memmap/distributed_config.py
@@ -0,0 +1,38 @@
+import os
+import typing
+
+from fast_llm.config import Config, Field, FieldHint, check_field, config_class
+from fast_llm.utils import Assert
+
+
+@config_class
+class DatasetPreparatorDistributedConfig(Config):
+    # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
+
+    default_world_size: typing.ClassVar[int] = int(os.environ.get("WORLD_SIZE", 1))
+    default_rank: typing.ClassVar[int] = int(os.environ.get("RANK", 0))
+    world_size: int = Field(
+        default=None,
+        desc="Size of the world group. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.",
+        hint=FieldHint.expert,
+        valid=check_field(Assert.gt, 0),
+    )
+    rank: int = Field(
+        default=None,
+        desc="Rank of the local process. Typically provided by torchrun or equivalent through the `RANK` environment variable.",
+        hint=FieldHint.expert,
+        valid=check_field(Assert.geq, 0),
+    )
+    backend: str = Field(
+        default="gloo",
+        desc="Distributed backend to use.",
+        hint=FieldHint.optional,
+    )
+
+    def _validate(self) -> None:
+        if self.world_size is None:
+            self.world_size = self.default_world_size
+        if self.rank is None:
+            self.rank = self.default_rank
+        super()._validate()
+        Assert.in_range(self.rank, 0, self.world_size)
\ No newline at end of file
diff --git a/fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py b/fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py
new file mode 100644
index 00000000..e8743817
--- /dev/null
+++ b/fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py
@@ -0,0 +1,172 @@
+import abc
+import typing
+
+import datasets
+
+from fast_llm.config import Config, Configurable, Field, FieldUpdate, config_class
+from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
+
+
+# TODO: Add desc and hint to all fields.
+
+
+@config_class
+class HFProcessorConfig(Config):
+    use_processor: bool = Field(default=True)
+    human_readable_name: str = Field(default="")
+    batch_size: int | None = Field(default=None)
+    num_proc: int | None = Field(default=None)
+    field: str | None = Field(default=None)
+
+
+class HFProcessor[ConfigType: HFProcessorConfig](Configurable[ConfigType], abc.ABC):
+    config_class: typing.ClassVar[type[HFProcessorConfig]] = HFProcessorConfig
+
+    def __init__(self, config: ConfigType, distributed_config: DatasetPreparatorDistributedConfig, *args, **kwargs):
+        super().__init__(config, *args, **kwargs)
+
+        self._distributed_config = distributed_config
+
+    @abc.abstractmethod
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        raise NotImplementedError
+
+
+@config_class
+class DocLengthFilterProcessorConfig(HFProcessorConfig):
+    human_readable_name: str = FieldUpdate(default="Document Length Filter")
+    min_length_chars: int = Field(default=0)
+    max_length_chars: int = Field(default=1_000_000)
+
+
+class DocLengthFilterProcessor[ConfigType: DocLengthFilterProcessorConfig](HFProcessor[ConfigType]):
+    config_class: typing.ClassVar[type[DocLengthFilterProcessorConfig]] = DocLengthFilterProcessorConfig
+
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_doc_length_filter_processor
+
+        return apply_doc_length_filter_processor(self._config, dataset)
+
+
+@config_class
+class NGramRepetitionFilterProcessorConfig(HFProcessorConfig):
+    human_readable_name: str = FieldUpdate(default="N-Gram Repetition Filter")
+    n: int = Field(default=5)
+    max_repetitions: int = Field(default=32)
+
+
+class NGramRepetitionFilterProcessor[ConfigType: NGramRepetitionFilterProcessorConfig](HFProcessor[ConfigType]):
+    config_class: typing.ClassVar[type[NGramRepetitionFilterProcessorConfig]] = NGramRepetitionFilterProcessorConfig
+
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import (
+            apply_ngram_repetition_filter_processor,
+        )
+
+        return apply_ngram_repetition_filter_processor(self._config, dataset)
+
+
+@config_class
+class FrequencyBasedFilterProcessorConfig(HFProcessorConfig):
+    human_readable_name: str = FieldUpdate(default="Frequency-Based Filter")
+    max_single_word_ratio: float = Field(default=0.3)
+    max_top_two_word_ratio: float = Field(default=0.5)
+
+
+class FrequencyBasedFilterProcessor[ConfigType: FrequencyBasedFilterProcessorConfig](HFProcessor[ConfigType]):
+    config_class: typing.ClassVar[type[FrequencyBasedFilterProcessorConfig]] = FrequencyBasedFilterProcessorConfig
+
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_frequency_based_filter_processor
+
+        return apply_frequency_based_filter_processor(self._config, dataset)
+
+
+@config_class
+class BinaryContentFilterProcessorConfig(HFProcessorConfig):
+    human_readable_name: str = FieldUpdate(default="Binary Content Filter")
+    max_bin_ratio: float = Field(default=0.5)
+
+
+class BinaryContentFilterProcessor[ConfigType: BinaryContentFilterProcessorConfig](HFProcessor[ConfigType]):
+    config_class: typing.ClassVar[type[BinaryContentFilterProcessorConfig]] = BinaryContentFilterProcessorConfig
+
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_binary_content_filter_processor
+
+        return apply_binary_content_filter_processor(self._config, dataset)
+
+
+@config_class
+class NumericalContentFilterProcessorConfig(HFProcessorConfig):
+    human_readable_name: str = FieldUpdate(default="Numerical Content Filter")
+    max_numeric_token_ratio: float = Field(default=0.5)
+
+
+class NumericalContentFilterProcessor[ConfigType: NumericalContentFilterProcessorConfig](HFProcessor[ConfigType]):
+    config_class: typing.ClassVar[type[NumericalContentFilterProcessorConfig]] = NumericalContentFilterProcessorConfig
+
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import (
+            apply_numerical_content_filter_processor,
+        )
+
+        return apply_numerical_content_filter_processor(self._config, dataset)
+
+
+@config_class
+class PiiRedactionProcessorConfig(HFProcessorConfig):
+    use_processor: bool = FieldUpdate(default=False)
+    human_readable_name: str = FieldUpdate(default="PII Redaction Processor")
+    # TODO: make enum
+    redaction_method: str = Field(default="remove")  # Options: 'remove', 'mask'
+
+
+class PiiRedactionProcessor[ConfigType: PiiRedactionProcessorConfig](HFProcessor[ConfigType]):
+    config_class: typing.ClassVar[type[PiiRedactionProcessorConfig]] = PiiRedactionProcessorConfig
+
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_pii_redaction_processor
+
+        return apply_pii_redaction_processor(self._config, self._distributed_config, dataset)
+
+
+@config_class
+class MalwareRemovalProcessorConfig(HFProcessorConfig):
+    use_processor: bool = FieldUpdate(default=False)
+    human_readable_name: str = FieldUpdate(default="Malware Removal Processor")
+
+
+class MalwareRemovalProcessor[ConfigType: MalwareRemovalProcessorConfig](HFProcessor[ConfigType]):
+    config_class: typing.ClassVar[type[MalwareRemovalProcessorConfig]] = MalwareRemovalProcessorConfig
+
+    def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
+        from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_malware_removal_processor
+
+        return apply_malware_removal_processor(self._config, dataset)
+
+
+@config_class
+class ProcessorsConfig(Config):
+    doc_length: DocLengthFilterProcessorConfig = Field(default=DocLengthFilterProcessorConfig)
+    n_grams: NGramRepetitionFilterProcessorConfig = Field(default=NGramRepetitionFilterProcessorConfig)
+    frequency: FrequencyBasedFilterProcessorConfig = Field(default=FrequencyBasedFilterProcessorConfig)
+    binary: BinaryContentFilterProcessorConfig = Field(default=BinaryContentFilterProcessorConfig)
+    numerical: NumericalContentFilterProcessorConfig = Field(default=NumericalContentFilterProcessorConfig)
+    pii: PiiRedactionProcessorConfig = Field(default=PiiRedactionProcessorConfig)
+    malware: MalwareRemovalProcessorConfig = Field(default=MalwareRemovalProcessorConfig)
+
+    # TODO: add validation so all steps are actual field names
+    order: list[str] = Field(
+        default_factory=lambda: ["doc_length", "n_grams", "frequency", "binary", "numerical", "pii", "malware"]
+    )
+
+    def get_processor_types_map(self) -> dict[str, type[HFProcessor]]:
+        return {
+            "doc_length": DocLengthFilterProcessor,
+            "n_grams": NGramRepetitionFilterProcessor,
+            "frequency": FrequencyBasedFilterProcessor,
+            "binary": BinaryContentFilterProcessor,
+            "numerical": NumericalContentFilterProcessor,
+            "pii": PiiRedactionProcessor,
+            "malware": MalwareRemovalProcessor,
+        }
diff --git a/fast_llm/data/preparator/gpt_memmap/hf_processors/processor_metrics_logger.py b/fast_llm/data/preparator/gpt_memmap/hf_processors/processor_metrics_logger.py
new file mode 100644
index 00000000..79e8cc20
--- /dev/null
+++ b/fast_llm/data/preparator/gpt_memmap/hf_processors/processor_metrics_logger.py
@@ -0,0 +1,86 @@
+import pathlib
+import time
+import typing
+
+import datasets
+import torch
+import torch.distributed
+import yaml
+
+from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
+
+
+class ProcessorMetricsLogger:
+    def __init__(
+        self,
+        distributed_config: DatasetPreparatorDistributedConfig,
+        field: str,
+        num_proc: int | None,
+        batch_size: int | None,
+    ):
+        self.start_time = None
+        self.distributed_config = distributed_config
+        self.field = field
+        self.num_proc = num_proc
+        self.batch_size = batch_size
+        self.local_times = []
+        self.local_doc_lengths = []
+        self.local_chars = []
+
+    def start(self):
+        self.start_time = time.time()
+
+    def stop(self, dataset: datasets.Dataset, step_name: str):
+        elapsed_time = time.time() - self.start_time
+        num_rows = len(dataset)
+
+        def compute_doc_lengths(batch):
+            return {"doc_lengths": [len(doc) for doc in batch[self.field]]}
+
+        # The batched map emits one length per row, so the resulting column is already flat.
+        doc_lengths = dataset.map(
+            compute_doc_lengths, batched=True, batch_size=self.batch_size, num_proc=self.num_proc
+        )["doc_lengths"]
+        num_chars = sum(doc_lengths)
+
+        self.local_times.append(elapsed_time)
+        self.local_doc_lengths.extend(doc_lengths)
+        self.local_chars.append(num_chars)
+
+        local_stats = torch.tensor(
+            [num_rows, num_chars, min(doc_lengths, default=0), max(doc_lengths, default=0)], dtype=torch.long
+        )
+        local_time = torch.tensor([elapsed_time], dtype=torch.float64)
+        # Gather per-shard stats on every rank; fall back to the local values when not distributed.
+        if torch.distributed.is_initialized():
+            world_size = torch.distributed.get_world_size()
+            all_stats = [torch.zeros_like(local_stats) for _ in range(world_size)]
+            all_times = [torch.zeros_like(local_time) for _ in range(world_size)]
+            torch.distributed.all_gather(all_stats, local_stats)
+            torch.distributed.all_gather(all_times, local_time)
+        else:
+            all_stats = [local_stats]
+            all_times = [local_time]
+
+        # Only rank 0 aggregates and returns the metrics; other ranks return None.
+        if self.distributed_config.rank == 0:
+            times = [t.item() for t in all_times]
+            min_time, max_time, avg_time = min(times), max(times), sum(times) / len(times)
+            total_rows = sum(stat[0].item() for stat in all_stats)
+            total_chars = sum(stat[1].item() for stat in all_stats)
+            min_doc_length = min(stat[2].item() for stat in all_stats)
+            max_doc_length = max(stat[3].item() for stat in all_stats)
+
+            return {
+                "step_name": step_name,
+                "elapsed_time": {"min": min_time, "max": max_time, "avg": avg_time},
+                "document_length": {"min": min_doc_length, "max": max_doc_length, "total": total_chars},
+                "total_rows": total_rows,
+            }
+        return None
+
+    @classmethod
+    def format(cls, metrics: dict[str, typing.Any]):
+        return (
+            f"Processor '{metrics['step_name']}' applied, max shard processing time {metrics['elapsed_time']['max']:.2f}s,"
+            f" rows remaining in the dataset: {metrics['total_rows']},"
+            f" characters remaining in the dataset: {metrics['document_length']['total']}"
+        )
+
+    @classmethod
+    def save_as_yaml(cls, file_name: pathlib.Path, metrics: list[dict[str, typing.Any]]):
+        with file_name.with_suffix(".yaml").open("wt") as f:
+            yaml.safe_dump(metrics, f)
diff --git a/fast_llm/data/preparator/gpt_memmap/hf_processors/processors.py b/fast_llm/data/preparator/gpt_memmap/hf_processors/processors.py
new file mode 100644
index 00000000..82cc1ca1
--- /dev/null
+++ b/fast_llm/data/preparator/gpt_memmap/hf_processors/processors.py
@@ -0,0 +1,163 @@
+import collections
+import logging
+import re
+
+import datasets
+
+from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
+from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import (
+    BinaryContentFilterProcessorConfig,
+    DocLengthFilterProcessorConfig,
+    FrequencyBasedFilterProcessorConfig,
+    MalwareRemovalProcessorConfig,
+    NGramRepetitionFilterProcessorConfig,
+    NumericalContentFilterProcessorConfig,
+    PiiRedactionProcessorConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+WORD_PATTERN = r"\b\w+(?:'\w+)?\b"
+NUMBER_PATTERN = r"\b\d+\b"
+
+
+def apply_doc_length_filter_processor(
+    config: DocLengthFilterProcessorConfig, dataset: datasets.Dataset
+) -> datasets.Dataset:
+    return dataset.filter(
+        lambda batch: [
+            config.min_length_chars <= len(text) <= config.max_length_chars for text in batch[config.field]
+        ],
+        num_proc=config.num_proc,
+        batched=True,
+        batch_size=config.batch_size,
+    )
+
+
+def apply_ngram_repetition_filter_processor(
+    config: NGramRepetitionFilterProcessorConfig, dataset: datasets.Dataset
+) -> datasets.Dataset:
+    def within_repetition_limit(batch):
+        # True keeps the document: its most frequent n-gram occurs at most `max_repetitions` times.
+        results = []
+        word_pattern = re.compile(WORD_PATTERN)
+        for text in batch[config.field]:
+            words = word_pattern.findall(text)
+            ngrams = [tuple(words[i : i + config.n]) for i in range(len(words) - config.n + 1)]
+            ngram_counts = collections.Counter(ngrams)
+            results.append(max(ngram_counts.values(), default=0) <= config.max_repetitions)
+        return results
+
+    return dataset.filter(
+        within_repetition_limit,
+        num_proc=config.num_proc,
+        batched=True,
+        batch_size=config.batch_size,
+    )
+
+
+def apply_frequency_based_filter_processor(
+    config: FrequencyBasedFilterProcessorConfig, dataset: datasets.Dataset
+) -> datasets.Dataset:
+    def within_word_frequency_limits(batch):
+        # True keeps the document: neither the most common word nor the top two words dominate it.
+        results = []
+        word_pattern = re.compile(WORD_PATTERN)
+        for text in batch[config.field]:
+            words = word_pattern.findall(text)
+            total_words = len(words)
+            word_counts = collections.Counter(words)
+            most_common = word_counts.most_common(2)
+
+            if most_common and (most_common[0][1] / total_words) > config.max_single_word_ratio:
+                results.append(False)
+            elif (
+                len(most_common) > 1
+                and ((most_common[0][1] + most_common[1][1]) / total_words) > config.max_top_two_word_ratio
+            ):
+                results.append(False)
+            else:
+                results.append(True)
+        return results
+
+    return dataset.filter(
+        within_word_frequency_limits,
+        num_proc=config.num_proc,
+        batched=True,
+        batch_size=config.batch_size,
+    )
+
+
+def apply_binary_content_filter_processor(
+    config: BinaryContentFilterProcessorConfig, dataset: datasets.Dataset
+) -> datasets.Dataset:
+    def is_not_binary(batch):
+        # True keeps the document: the ratio of non-printable characters stays within `max_bin_ratio`.
+        # `max(1, len(text))` guards against division by zero on empty documents.
+        return [
+            sum(1 for char in text if not char.isprintable()) / max(1, len(text)) <= config.max_bin_ratio
+            for text in batch[config.field]
+        ]
+
+    return dataset.filter(is_not_binary, num_proc=config.num_proc, batched=True, batch_size=config.batch_size)
+
+
+def apply_numerical_content_filter_processor(
+    config: NumericalContentFilterProcessorConfig, dataset: datasets.Dataset
+) -> datasets.Dataset:
+    def within_numeric_limit(batch):
+        # True keeps the document: the ratio of numeric tokens stays within `max_numeric_token_ratio`.
+        results = []
+        number_pattern = re.compile(NUMBER_PATTERN)
+        for text in batch[config.field]:
+            tokens = number_pattern.findall(text)
+            results.append((len(tokens) / max(1, len(text.split()))) <= config.max_numeric_token_ratio)
+        return results
+
+    return dataset.filter(
+        within_numeric_limit, num_proc=config.num_proc, batched=True, batch_size=config.batch_size
+    )
+
+
+def apply_pii_redaction_processor(
+    config: PiiRedactionProcessorConfig,
+    distributed_config: DatasetPreparatorDistributedConfig,
+    dataset: datasets.Dataset,
+) -> datasets.Dataset:
+    # TODO: check if multiprocessing is possible
+    # TODO: manage explicit model download and loading: the analyzer currently installs
+    #  a python package internally, which is not transferable to workers.
+    from presidio_analyzer import AnalyzerEngine
+    from presidio_anonymizer import AnonymizerEngine
+
+    analyzer = AnalyzerEngine()
+    anonymizer = AnonymizerEngine()
+
+    def redact_pii(batch):
+        results = []
+        for text in batch[config.field]:
+            entities = analyzer.analyze(
+                text=text, entities=["PERSON", "EMAIL_ADDRESS", "PHONE_NUMBER", "CREDIT_CARD"], language="en"
+            )
+            if config.redaction_method == "remove":
+                # Cut out the detected spans back to front so earlier offsets stay valid.
+                for result in reversed(entities):
+                    text = text[: result.start] + text[result.end :]
+            elif config.redaction_method == "mask":
+                text = anonymizer.anonymize(text, entities).text
+            else:
+                raise ValueError(f"Unknown redaction method: {config.redaction_method}")
+            results.append(text)
+        return {config.field: results}
+
+    return dataset.map(redact_pii, num_proc=None, batched=True, batch_size=config.batch_size)
+
+
+def apply_malware_removal_processor(
+    config: MalwareRemovalProcessorConfig, dataset: datasets.Dataset
+) -> datasets.Dataset:
+    # TODO: this is not working, `scan_bytes` does not exist.
+    #  Rewrite by downloading a virus definition file and scanning the dataset rows,
+    #  or call clamav directly.
+    import clamav
+
+    def is_not_malicious(batch):
+        return [not clamav.scan_bytes(text.encode()) for text in batch[config.field]]
+
+    return dataset.filter(is_not_malicious, num_proc=config.num_proc, batched=True, batch_size=config.batch_size)
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index b3dae1df..a643a412 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -28,6 +28,8 @@
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
+from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import HFProcessor, HFProcessorConfig
+from fast_llm.data.preparator.gpt_memmap.hf_processors.processor_metrics_logger import ProcessorMetricsLogger
 
 logger = logging.getLogger(__name__)
@@ -221,6 +223,24 @@ def run(self) -> None:
         else:
             tokenize_fn = self._tokenize_batch
 
+        # Process the dataset before tokenizing, skipping disabled processors.
+        metrics = []
+        for processor_config_field_name in self._config.processors.order:
+            processor_config: HFProcessorConfig = getattr(self._config.processors, processor_config_field_name)
+            if not processor_config.use_processor:
+                continue
+            processor: HFProcessor = self._config.processors.get_processor_types_map()[processor_config_field_name](
+                config=processor_config,
+                distributed_config=self._config.distributed,
+            )
+            pml = ProcessorMetricsLogger(
+                self._config.distributed, processor_config.field, processor_config.num_proc, processor_config.batch_size
+            )
+            pml.start()
+            dataset = processor.apply(dataset)
+            processor_metrics = pml.stop(dataset, processor_config.human_readable_name)
+            # `stop` returns the aggregated metrics on rank 0 only.
+            if processor_metrics is not None:
+                metrics.append(processor_metrics)
+                logger.info(ProcessorMetricsLogger.format(processor_metrics))
+
+        if self._config.distributed.rank == 0:
+            ProcessorMetricsLogger.save_as_yaml(pathlib.Path(self._config.output_path) / "processors_log", metrics)
+
         # Tokenize the dataset in parallel
         tokenized_dataset = dataset.map(
             tokenize_fn,
diff --git a/tests/data/test_prepare_hf_processors.py b/tests/data/test_prepare_hf_processors.py
new file mode 100644
index 00000000..bfaf3424
--- /dev/null
+++ b/tests/data/test_prepare_hf_processors.py
@@ -0,0 +1,86 @@
+import datasets
+
+from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
+from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import (
+    DocLengthFilterProcessorConfig,
+    NGramRepetitionFilterProcessorConfig,
+    FrequencyBasedFilterProcessorConfig,
+    BinaryContentFilterProcessorConfig,
+    NumericalContentFilterProcessorConfig,
+    PiiRedactionProcessorConfig,
+    MalwareRemovalProcessorConfig,
+    DocLengthFilterProcessor,
+    NGramRepetitionFilterProcessor,
+    FrequencyBasedFilterProcessor,
+    BinaryContentFilterProcessor,
+    NumericalContentFilterProcessor,
+    PiiRedactionProcessor,
+    MalwareRemovalProcessor,
+)
+
+
+def create_test_dataset(data):
+    return datasets.Dataset.from_dict({"text": data})
+
+
+def test_doc_length_filter_processor():
+    dataset = create_test_dataset(["short", "this is a medium length sentence", "this is a very long text" * 100])
+    config = DocLengthFilterProcessorConfig(min_length_chars=10, max_length_chars=50, field="text")
+    processor = DocLengthFilterProcessor(config, DatasetPreparatorDistributedConfig())
+    filtered_dataset = processor.apply(dataset)
+    assert len(filtered_dataset) == 1  # Only the medium-length sentence fits within [10, 50] characters.
+
+
+def test_ngram_repetition_filter_processor():
+    dataset = create_test_dataset(
+        ["word word word", "word word word word", "unique words here", "repeat repeat repeat repeat repeat"]
+    )
+    config = NGramRepetitionFilterProcessorConfig(n=2, max_repetitions=2, field="text")
+    processor = NGramRepetitionFilterProcessor(config, DatasetPreparatorDistributedConfig())
+    filtered_dataset = processor.apply(dataset)
+    assert len(filtered_dataset) == 2  # Only "word word word" and "unique words here" should remain.
+
+
+def test_frequency_based_filter_processor():
+    dataset = create_test_dataset(["hello hello hello world", "this is fine just because", "spam spam spam spam spam"])
+    config = FrequencyBasedFilterProcessorConfig(max_single_word_ratio=0.4, max_top_two_word_ratio=0.6, field="text")
+    processor = FrequencyBasedFilterProcessor(config, DatasetPreparatorDistributedConfig())
+    filtered_dataset = processor.apply(dataset)
+    assert len(filtered_dataset) == 1  # Only "this is fine just because" should remain.
+
+
+def test_binary_content_filter_processor():
+    dataset = create_test_dataset(["hello world", b"\x00\x00\x01\x02bin".decode("utf8"), "normal text"])
+    config = BinaryContentFilterProcessorConfig(max_bin_ratio=0.5, field="text")
+    processor = BinaryContentFilterProcessor(config, DatasetPreparatorDistributedConfig())
+    filtered_dataset = processor.apply(dataset)
+    assert len(filtered_dataset) == 2  # The document with mostly non-printable bytes should be removed.
+
+
+def test_numerical_content_filter_processor():
+    dataset = create_test_dataset(
+        ["123 456 789", "some words and 123", "almost all numbers 123 456 789 101112 131415"]
+    )
+    config = NumericalContentFilterProcessorConfig(max_numeric_token_ratio=0.5, field="text")
+    processor = NumericalContentFilterProcessor(config, DatasetPreparatorDistributedConfig())
+    filtered_dataset = processor.apply(dataset)
+    assert len(filtered_dataset) == 1  # Only "some words and 123" should remain.
+
+
+# TODO: Make optional, conditioned on presidio being installed.
+def test_pii_redaction_processor():
+    dataset = create_test_dataset(["My name is John Doe", "Contact me at john@example.com", "This is safe text"])
+    config = PiiRedactionProcessorConfig(redaction_method="remove", field="text")
+    processor = PiiRedactionProcessor(config, DatasetPreparatorDistributedConfig())
+    processed_dataset = processor.apply(dataset)
+    # Check substrings per document; `not in processed_dataset["text"]` would only test exact list membership.
+    assert all("John Doe" not in text for text in processed_dataset["text"])
+    assert all("john@example.com" not in text for text in processed_dataset["text"])
+
+
+# TODO: Make optional, conditioned on clamav being installed.
+# def test_malware_removal_processor():
+#     dataset = create_test_dataset(["malicious_code();", "safe text", "virus_payload();"])
+#     config = MalwareRemovalProcessorConfig(field="text")
+#     processor = MalwareRemovalProcessor(config, DatasetPreparatorDistributedConfig())
+#     filtered_dataset = processor.apply(dataset)
+#     assert len(filtered_dataset) == 1  # Only "safe text" should remain.
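For reviewers, a minimal sketch of how the new processor pipeline is driven end to end, mirroring the loop added to GPTMemmapDatasetPreparator.run(). This is not part of the diff: the inline dataset, the default single-process DatasetPreparatorDistributedConfig(), and the manual `field` assignment are illustrative assumptions; in the preparator itself, `field` and `num_proc` are propagated by GPTMemmapDatasetPreparatorConfig._validate().

    import datasets

    from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
    from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import ProcessorsConfig

    # Illustrative inputs: a tiny in-memory dataset and a single-process distributed config.
    dataset = datasets.Dataset.from_dict({"text": ["hello world", "123 456 789", "spam spam spam spam spam"]})
    distributed_config = DatasetPreparatorDistributedConfig()
    processors_config = ProcessorsConfig()

    for name in processors_config.order:
        processor_config = getattr(processors_config, name)
        if not processor_config.use_processor:
            continue  # e.g. pii and malware are disabled by default
        if processor_config.field is None:
            # Normally propagated from the dataset config during validation.
            processor_config.field = "text"
        processor = processors_config.get_processor_types_map()[name](
            config=processor_config, distributed_config=distributed_config
        )
        dataset = processor.apply(dataset)

    print(f"{len(dataset)} documents kept")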