diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py
index 9868ab52..4b0491bb 100644
--- a/src/guidellm/dataset/synthetic.py
+++ b/src/guidellm/dataset/synthetic.py
@@ -138,6 +138,8 @@ def __init__(
         self.text_creator = EndlessTextCreator(
             data=config.source,
         )
+        # Counter used to build a unique prefix for each generated request
+        self.request_counter = 0
 
     def __iter__(
         self,
@@ -170,22 +172,42 @@
             output_tokens_sampler,
         ):
             start_index = rand.randint(0, len(self.text_creator.words))
+            # Increment the counter so each request gets a distinct prefix
+            self.request_counter += 1
             yield {
-                "prompt": self._create_prompt(prompt_tokens, start_index),
+                "prompt": self._create_prompt(prompt_tokens, start_index, self.request_counter),
                 "prompt_tokens_count": prompt_tokens,
                 "output_tokens_count": output_tokens,
             }
 
-    def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
+    def _create_prompt(self, prompt_tokens: int, start_index: int, request_id: int) -> str:
+        """
+        Create a prompt with a unique prefix to prevent vLLM prefix caching.
+        Args:
+            prompt_tokens: Target number of tokens for the prompt
+            start_index: Starting position in the text corpus
+            request_id: Unique identifier for this request (used as the prefix)
+        Returns:
+            Generated prompt string with a unique prefix
+        """
         if prompt_tokens <= 0:
-            return ""
+            return self._create_unique_prefix(request_id)
+
+        unique_prefix = self._create_unique_prefix(request_id)
+
+        # Calculate how many tokens the prefix uses
+        prefix_tokens = len(self.processor.tokenize(unique_prefix))
+
+        # Adjust the target to account for the prefix tokens
+        remaining_tokens = max(1, prompt_tokens - prefix_tokens)
 
         left = start_index
-        right = start_index + 4 * prompt_tokens
+        right = start_index + 4 * remaining_tokens
         while left < right:
             mid = (left + right) // 2
-            test_prompt = self.text_creator.create_text(start_index, mid - start_index)
+            base_text = self.text_creator.create_text(start_index, mid - start_index)
+            test_prompt = unique_prefix + base_text
             test_tokens = len(self.processor.tokenize(test_prompt))
 
             if test_tokens == prompt_tokens:
@@ -195,7 +217,16 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
         else:
             right = mid
 
-        return self.text_creator.create_text(start_index, left - start_index)
+        base_text = self.text_creator.create_text(start_index, left - start_index)
+        return unique_prefix + base_text
+
+    def _create_unique_prefix(self, request_id: int) -> str:
+        """
+        Return a short prefix that is unique to this request so that no two
+        prompts share a common prefix. Deriving it from the request counter
+        keeps generation reproducible for a fixed random seed.
+        """
+        return f"{request_id}: "
 
 
 class SyntheticDatasetCreator(DatasetCreator):
diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py
new file mode 100644
index 00000000..cbe130f4
--- /dev/null
+++ b/tests/unit/dataset/test_synthetic.py
@@ -0,0 +1,568 @@
+"""
+Unit tests for guidellm.dataset.synthetic module.
+""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +import yaml +from datasets import Dataset +from transformers import AutoTokenizer + +from guidellm.dataset.synthetic import ( + SyntheticDatasetConfig, + SyntheticDatasetCreator, + SyntheticTextItemsGenerator, +) + + +class TestSyntheticDatasetConfig: + + def test_config_creation_with_defaults(self): + config = SyntheticDatasetConfig( + prompt_tokens=50, + output_tokens=20 + ) + + assert config.prompt_tokens == 50 + assert config.output_tokens == 20 + assert config.samples == 1000 # default + assert config.source == "data:prideandprejudice.txt.gz" # default + assert config.prompt_tokens_stdev is None + assert config.prompt_tokens_min is None + assert config.prompt_tokens_max is None + + def test_config_creation_with_all_params(self): + config = SyntheticDatasetConfig( + prompt_tokens=100, + prompt_tokens_stdev=10, + prompt_tokens_min=50, + prompt_tokens_max=150, + output_tokens=30, + output_tokens_stdev=5, + output_tokens_min=20, + output_tokens_max=40, + samples=500, + source="custom_text.txt" + ) + + assert config.prompt_tokens == 100 + assert config.prompt_tokens_stdev == 10 + assert config.prompt_tokens_min == 50 + assert config.prompt_tokens_max == 150 + assert config.output_tokens == 30 + assert config.output_tokens_stdev == 5 + assert config.output_tokens_min == 20 + assert config.output_tokens_max == 40 + assert config.samples == 500 + assert config.source == "custom_text.txt" + + def test_parse_json_string(self): + json_str = json.dumps({ + "prompt_tokens": 75, + "output_tokens": 25, + "samples": 200, + "source": "test.txt" + }) + + config = SyntheticDatasetConfig.parse_str(json_str) + + assert config.prompt_tokens == 75 + assert config.output_tokens == 25 + assert config.samples == 200 + assert config.source == "test.txt" + + def test_parse_key_value_pairs(self): + kv_str = "prompt_tokens=80,output_tokens=30,samples=300,source=data.txt" + + config = SyntheticDatasetConfig.parse_str(kv_str) + + assert config.prompt_tokens == 80 + assert config.output_tokens == 30 + assert config.samples == 300 + assert config.source == "data.txt" + + def test_parse_yaml_file(self): + + config_data = { + "prompt_tokens": 60, + "output_tokens": 15, + "samples": 100, + "source": "yaml_test.txt" + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + yaml_path = f.name + + try: + config = SyntheticDatasetConfig.parse_str(yaml_path) + + assert config.prompt_tokens == 60 + assert config.output_tokens == 15 + assert config.samples == 100 + assert config.source == "yaml_test.txt" + finally: + Path(yaml_path).unlink() + + def test_parse_config_file(self): + config_data = { + "prompt_tokens": 90, + "output_tokens": 35, + "samples": 150 + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.config', delete=False) as f: + yaml.dump(config_data, f) + config_path = f.name + + try: + config = SyntheticDatasetConfig.parse_str(config_path) + + assert config.prompt_tokens == 90 + assert config.output_tokens == 35 + assert config.samples == 150 + finally: + Path(config_path).unlink() + + def test_parse_invalid_format(self): + with pytest.raises(ValueError, match="Unsupported data format"): + SyntheticDatasetConfig.parse_str("invalid_format_string") + + def test_validation_positive_values(self): + """Test that negative values are rejected.""" + with pytest.raises(ValueError): + SyntheticDatasetConfig(prompt_tokens=-1, output_tokens=20) + + 
with pytest.raises(ValueError): + SyntheticDatasetConfig(prompt_tokens=20, output_tokens=-1) + + with pytest.raises(ValueError): + SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, samples=-1) + + +class TestSyntheticTextItemsGenerator: + + @pytest.fixture + def tokenizer(self): + """Fixture to provide a tokenizer for testing.""" + tokenizer = AutoTokenizer.from_pretrained("gpt2") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + @pytest.fixture + def simple_config(self): + + return SyntheticDatasetConfig( + prompt_tokens=15, + output_tokens=10, + samples=5, + source="The quick brown fox jumps over the lazy dog. Machine learning models require diverse training data." + ) + + @pytest.fixture + def complex_config(self): + + return SyntheticDatasetConfig( + prompt_tokens=20, + prompt_tokens_stdev=5, + prompt_tokens_min=10, + prompt_tokens_max=30, + output_tokens=15, + output_tokens_stdev=3, + output_tokens_min=10, + output_tokens_max=20, + samples=10, + source="The quick brown fox jumps over the lazy dog. Machine learning models require diverse training data." + ) + + def test_generator_initialization(self, simple_config, tokenizer): + + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + assert generator.config == simple_config + assert generator.processor == tokenizer + assert generator.random_seed == 42 + assert generator.request_counter == 0 + assert generator.text_creator is not None + + def test_basic_prompt_generation(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + items = list(generator) + + # Verify we get the expected number of items + assert len(items) == simple_config.samples + + # Verify each item has the required keys + for item in items: + assert "prompt" in item + assert "prompt_tokens_count" in item + assert "output_tokens_count" in item + + # Verify types + assert isinstance(item["prompt"], str) + assert isinstance(item["prompt_tokens_count"], int) + assert isinstance(item["output_tokens_count"], int) + + # Verify non-empty prompt + assert len(item["prompt"]) > 0 + + def test_unique_prefix_generation(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + items = list(generator) + prompts = [item["prompt"] for item in items] + + # Verify each prompt starts with a unique request ID + for i, prompt in enumerate(prompts, 1): + assert prompt.startswith(f"{i}: "), f"Prompt {i} should start with '{i}: ', got '{prompt[:10]}...'" + + # Verify no two prompts are identical + assert len(set(prompts)) == len(prompts), "All prompts should be unique" + + def test_prefix_caching_prevention(self, simple_config, tokenizer): + """Test that prefix caching is effectively prevented.""" + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + items = list(generator) + prompts = [item["prompt"] for item in items] + + # Test that no prompt is a prefix of another + for i, prompt1 in enumerate(prompts): + for j, prompt2 in enumerate(prompts): + if i != j: + assert not prompt1.startswith(prompt2), f"Prompt {i} starts with prompt {j}" + assert not prompt2.startswith(prompt1), f"Prompt {j} starts with prompt {i}" + + # Test that first characters are all different + first_chars = [prompt[0] for prompt in prompts] + assert len(set(first_chars)) == len(first_chars), "First characters should all be different" + + def test_token_count_accuracy(self, 
simple_config, tokenizer): + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + items = list(generator) + + for item in items: + actual_tokens = len(tokenizer.tokenize(item["prompt"])) + target_tokens = item["prompt_tokens_count"] + + # Allow small variance due to tokenization differences + assert abs(actual_tokens - target_tokens) <= 2, \ + f"Token count mismatch: expected ~{target_tokens}, got {actual_tokens}" + + def test_variance_in_token_counts(self, complex_config, tokenizer): + generator = SyntheticTextItemsGenerator(complex_config, tokenizer, random_seed=42) + + items = list(generator) + + prompt_token_counts = [item["prompt_tokens_count"] for item in items] + output_token_counts = [item["output_tokens_count"] for item in items] + + # With variance, we should see different token counts + assert len(set(prompt_token_counts)) > 1, "Should have variance in prompt token counts" + assert len(set(output_token_counts)) > 1, "Should have variance in output token counts" + + # Verify bounds are respected + assert all(complex_config.prompt_tokens_min <= count <= complex_config.prompt_tokens_max + for count in prompt_token_counts), "Prompt tokens should be within bounds" + assert all(complex_config.output_tokens_min <= count <= complex_config.output_tokens_max + for count in output_token_counts), "Output tokens should be within bounds" + + def test_reproducibility_with_same_seed(self, simple_config, tokenizer): + generator1 = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + generator2 = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + items1 = list(generator1) + items2 = list(generator2) + + # Results should be identical with same seed + assert len(items1) == len(items2) + for item1, item2 in zip(items1, items2): + assert item1["prompt"] == item2["prompt"] + assert item1["prompt_tokens_count"] == item2["prompt_tokens_count"] + assert item1["output_tokens_count"] == item2["output_tokens_count"] + + def test_different_seeds_produce_different_results(self, simple_config, tokenizer): + """Test that different seeds produce different results.""" + generator1 = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + generator2 = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=123) + + items1 = list(generator1) + items2 = list(generator2) + + # Results should be different with different seeds + prompts1 = [item["prompt"] for item in items1] + prompts2 = [item["prompt"] for item in items2] + + different_content = False + for p1, p2 in zip(prompts1, prompts2): + # Remove the prefix and compare content + content1 = p1.split(": ", 1)[1] if ": " in p1 else p1 + content2 = p2.split(": ", 1)[1] if ": " in p2 else p2 + if content1 != content2: + different_content = True + break + + assert different_content, "Different seeds should produce different content" + + def test_create_prompt_method_directly(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + # Test normal prompt creation + prompt = generator._create_prompt(10, 0, 5) + assert prompt.startswith("5: "), "Prompt should start with request ID" + + actual_tokens = len(tokenizer.tokenize(prompt)) + assert abs(actual_tokens - 10) <= 1, "Token count should be approximately correct" + + # Test empty prompt + empty_prompt = generator._create_prompt(0, 0, 3) + assert empty_prompt == "3: ", "Empty prompt should just be the prefix" + + def 
test_request_counter_increments_correctly(self, simple_config, tokenizer): + + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + # Initially should be 0 + assert generator.request_counter == 0 + + # Get items one by one and check counter + items = [] + for i, item in enumerate(generator, 1): + items.append(item) + # Counter should increment for each item + assert generator.request_counter == i + if i >= 3: # Just test first 3 + break + + # Verify prompts have correct prefixes + for i, item in enumerate(items, 1): + assert item["prompt"].startswith(f"{i}: ") + + def test_prefix_format_consistency(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + items = list(generator) + + for i, item in enumerate(items, 1): + prompt = item["prompt"] + + # Should start with number followed by colon and space + assert prompt.startswith(f"{i}: "), f"Prompt should start with '{i}: '" + + # Should be able to split on ': ' to get request ID and content + parts = prompt.split(": ", 1) + assert len(parts) == 2, "Prompt should have exactly one ': ' separator" + assert parts[0] == str(i), f"First part should be request ID {i}" + + # Content part should not be empty (unless it's a zero-token prompt) + if item["prompt_tokens_count"] > 0: + assert len(parts[1]) > 0, "Content part should not be empty for non-zero token prompts" + + def test_binary_search_token_accuracy(self, simple_config, tokenizer): + + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + # Test various token counts + test_cases = [5, 10, 15, 20, 25] + + for target_tokens in test_cases: + prompt = generator._create_prompt(target_tokens, 0, 999) + actual_tokens = len(tokenizer.tokenize(prompt)) + + # Should be very close to target (allowing for small tokenization differences) + assert abs(actual_tokens - target_tokens) <= 1, \ + f"Target: {target_tokens}, Actual: {actual_tokens}, Prompt: '{prompt[:50]}...'" + + def test_vllm_cache_simulation_comprehensive(self, simple_config, tokenizer): + + # Use larger sample for more thorough testing + config = SyntheticDatasetConfig( + prompt_tokens=20, + output_tokens=10, + samples=20, + source=simple_config.source + ) + + generator = SyntheticTextItemsGenerator(config, tokenizer, random_seed=42) + items = list(generator) + prompts = [item["prompt"] for item in items] + + # Simulate vLLM cache with different granularities + cache_scenarios = [ + {"name": "Character-level", "granularity": 1}, + {"name": "Token-level", "granularity": 4}, + {"name": "Word-level", "granularity": 10}, + ] + + for scenario in cache_scenarios: + cache_hits = 0 + total_comparisons = 0 + + for i, prompt1 in enumerate(prompts): + for j, prompt2 in enumerate(prompts[i+1:], i+1): + total_comparisons += 1 + + # Check for common prefix at specified granularity + min_len = min(len(prompt1), len(prompt2)) + common_prefix_len = 0 + + for k in range(0, min_len, scenario["granularity"]): + chunk1 = prompt1[k:k+scenario["granularity"]] + chunk2 = prompt2[k:k+scenario["granularity"]] + if chunk1 == chunk2: + common_prefix_len += len(chunk1) + else: + break + + # If meaningful common prefix exists, it's a cache hit + if common_prefix_len > scenario["granularity"]: + cache_hits += 1 + + cache_hit_rate = (cache_hits / total_comparisons) * 100 if total_comparisons > 0 else 0 + + # All scenarios should have 0% cache hit rate + assert cache_hit_rate == 0.0, \ + f"{scenario['name']} caching: Expected 0% hit rate, got 
{cache_hit_rate:.1f}%" + + def test_edge_case_very_short_prompts(self, tokenizer): + config = SyntheticDatasetConfig( + prompt_tokens=1, + output_tokens=5, + samples=5, + source="A B C D E F G H I J K L M N O P Q R S T U V W X Y Z" + ) + + generator = SyntheticTextItemsGenerator(config, tokenizer, random_seed=42) + items = list(generator) + + for i, item in enumerate(items, 1): + # Even very short prompts should have unique prefixes + assert item["prompt"].startswith(f"{i}: ") + + # Should have at least the prefix + assert len(item["prompt"]) >= len(f"{i}: ") + + + def test_create_prompt_method_signature_and_documentation(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator(simple_config, tokenizer, random_seed=42) + + # Test method exists and is callable + assert hasattr(generator, '_create_prompt') + assert callable(generator._create_prompt) + + # Test method signature by calling with expected parameters + prompt = generator._create_prompt( + prompt_tokens=10, + start_index=0, + request_id=1 + ) + + # Should return a string + assert isinstance(prompt, str) + + # Should start with the request ID + assert prompt.startswith("1: ") + + # Test that docstring exists and mentions key concepts + docstring = generator._create_prompt.__doc__ + assert docstring is not None + assert "prefix" in docstring.lower() + assert "cache" in docstring.lower() or "caching" in docstring.lower() + assert "request_id" in docstring + + +class TestIntegration: + """Integration tests for the complete synthetic dataset workflow.""" + + @pytest.fixture + def tokenizer(self): + """Fixture to provide a tokenizer for testing.""" + tokenizer = AutoTokenizer.from_pretrained("gpt2") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def test_end_to_end_workflow(self, tokenizer): + """Test the complete workflow from config to dataset.""" + # Create configuration + config_dict = { + "prompt_tokens": 20, + "output_tokens": 15, + "samples": 10, + "source": "The quick brown fox jumps over the lazy dog. Machine learning models require diverse training data to perform well across different tasks and domains." + } + + config_str = json.dumps(config_dict) + + # Create dataset + dataset = SyntheticDatasetCreator.handle_create( + data=config_str, + data_args=None, + processor=tokenizer, + processor_args=None, + random_seed=42 + ) + + # Verify dataset properties + assert isinstance(dataset, Dataset) + assert len(dataset) == 10 + + # Verify all prompts are unique and have correct prefixes + prompts = dataset["prompt"] + for i, prompt in enumerate(prompts, 1): + assert prompt.startswith(f"{i}: "), f"Prompt {i} should start with '{i}: '" + + # Verify no cache hits would occur + for i, prompt1 in enumerate(prompts): + for j, prompt2 in enumerate(prompts): + if i != j: + assert not prompt1.startswith(prompt2) + assert not prompt2.startswith(prompt1) + + # Verify token counts are reasonable + for i, row in enumerate(dataset): + actual_tokens = len(tokenizer.tokenize(row["prompt"])) + target_tokens = row["prompt_tokens_count"] + assert abs(actual_tokens - target_tokens) <= 2, \ + f"Row {i}: token count mismatch" + + def test_cache_prevention_effectiveness(self, tokenizer): + """Test that the cache prevention is effective across larger datasets.""" + config = SyntheticDatasetConfig( + prompt_tokens=25, + output_tokens=20, + samples=50, + source="The quick brown fox jumps over the lazy dog. Machine learning models require diverse training data." 
+        )
+
+        generator = SyntheticTextItemsGenerator(config, tokenizer, random_seed=42)
+        items = list(generator)
+        prompts = [item["prompt"] for item in items]
+        # Simulate vLLM prefix-cache behavior
+        cache_hits = 0
+        for i, prompt1 in enumerate(prompts):
+            for prompt2 in prompts[i + 1:]:
+                # Length of the common character prefix of the two prompts
+                common_prefix_len = 0
+                for c1, c2 in zip(prompt1, prompt2):
+                    if c1 == c2:
+                        common_prefix_len += 1
+                    else:
+                        break
+
+                # More than a few shared characters counts as a potential cache hit
+                if common_prefix_len > 3:
+                    cache_hits += 1
+
+        # With proper prefix prevention, there should be no cache hits
+        total_comparisons = len(prompts) * (len(prompts) - 1) // 2
+        cache_hit_rate = cache_hits / total_comparisons if total_comparisons > 0 else 0
+
+        assert cache_hit_rate == 0.0, f"Expected 0% cache hit rate, got {cache_hit_rate:.2%}"