Skip to content

Commit

Permalink
update utterance synthesizer
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Jan 28, 2025
1 parent 21957b4 commit 1681e77
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 167 deletions.
32 changes: 24 additions & 8 deletions autointent/generation/utterances/basic/chat_template.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
"""Chat template for evolution augmentation via abstractization."""

import random
from abc import ABC, abstractmethod
from typing import ClassVar

from autointent import Dataset
from autointent.generation.utterances.schemas import Message, Role
from autointent.schemas import Intent


class ExampleGenerator:
class BaseSynthesizer(ABC):
"""Base class."""

@abstractmethod
def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]:
"""Generate examples for this intent."""


class SynthesizerChatTemplate(BaseSynthesizer):
"""Chat template for generating additional examples for a given intent class."""

_messages: ClassVar[list[Message]] = [
Expand Down Expand Up @@ -86,7 +96,13 @@ class ExampleGenerator:
),
]

def __init__(self, dataset: Dataset, split: str, extra_instructions: str | None = None) -> None:
def __init__(
self,
dataset: Dataset,
split: str,
extra_instructions: str | None = None,
max_sample_utterances: int | None = None,
) -> None:
"""Initialize."""
if extra_instructions is None:
extra_instructions = ""
Expand All @@ -96,19 +112,19 @@ def __init__(self, dataset: Dataset, split: str, extra_instructions: str | None

self.dataset = dataset
self.split = split
self.max_sample_utterances = max_sample_utterances

def __call__(self, intent_id: int, n_examples: int, max_sample_utterances: int | None = None) -> list[Message]:
def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]:
"""Generate additional examples for the provided intent class."""
filtered_split = self.dataset[self.split].filter(lambda sample: sample[Dataset.label_feature] == intent_id)
filtered_split = self.dataset[self.split].filter(lambda sample: sample[Dataset.label_feature] == intent_data.id)
sample_utterances = filtered_split[Dataset.utterance_feature]
intent = next(i for i in self.dataset.intents if i.id == intent_id)
if max_sample_utterances is not None:
sample_utterances = random.sample(sample_utterances, k=max_sample_utterances)
if self.max_sample_utterances is not None:
sample_utterances = random.sample(sample_utterances, k=self.max_sample_utterances)
return [
*self._messages,
Message(
role=Role.USER,
content=f"Intent name: {intent.name}\n\n"
content=f"Intent name: {intent_data.name}\n\n"
f"Example Utterances:\n{sample_utterances}\n\n"
f"Please generate {n_examples} more examples for the provided intent class.\n",
),
Expand Down
35 changes: 6 additions & 29 deletions autointent/generation/utterances/basic/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from argparse import ArgumentParser

from autointent import load_dataset
from autointent.generation.utterances.basic.utterance_generator import LengthType, StyleType, UtteranceGenerator
from autointent.generation.utterances.basic.utterance_generator import UtteranceGenerator
from autointent.generation.utterances.generator import Generator

from .chat_template import SynthesizerChatTemplate


def main() -> None:
"""ClI endpoint."""
Expand Down Expand Up @@ -41,37 +43,12 @@ def main() -> None:
default=5,
help="Number of utterances to use as an example for augmentation",
)
parser.add_argument(
"--custom-instruction",
type=str,
action="append",
help="Add extra instructions to default prompt."
"You can use this argument multiple times to add multiple instructions",
)
parser.add_argument(
"--length",
choices=LengthType.__args__, # type: ignore[attr-defined]
default="none",
help="How to extend the prompt with length instruction",
)
parser.add_argument(
"--style",
choices=StyleType.__args__, # type: ignore[attr-defined]
default="none",
help="How to extend the prompt with style instruction",
)
parser.add_argument(
"--same-punctuation",
action="store_true",
help="Whether to extend the prompt with punctuation instruction",
)
args = parser.parse_args()

dataset = load_dataset(args.input_path)
generator = UtteranceGenerator(
Generator(), args.custom_instruction or [], args.length, args.style, args.same_punctuation
)
generator.augment(dataset, n_generations=args.n_generations, max_sample_utterances=args.n_sample_utterances)
template = SynthesizerChatTemplate(dataset, "train", max_sample_utterances=args.n_sample_utterances)
generator = UtteranceGenerator(Generator(), template)
generator.augment(dataset, n_generations=args.n_generations)

dataset.to_json(args.output_path)

Expand Down
14 changes: 0 additions & 14 deletions autointent/generation/utterances/basic/extra_instructions.json

This file was deleted.

97 changes: 8 additions & 89 deletions autointent/generation/utterances/basic/utterance_generator.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
"""Basic generation of new utterances from existing ones."""

import importlib.resources as ires
import json
import random
from typing import Any, Literal
from collections.abc import Callable

import yaml
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets

from autointent import Dataset
from autointent.custom_types import Split
from autointent.generation.utterances.generator import Generator
from autointent.generation.utterances.utils import safe_format # type: ignore[attr-defined]
from autointent.schemas import Sample

LengthType = Literal["none", "same", "longer", "shorter"]
StyleType = Literal["none", "formal", "informal", "playful"]
from autointent.generation.utterances.schemas import Message
from autointent.schemas import Intent, Sample


class UtteranceGenerator:
Expand All @@ -28,34 +21,14 @@ class UtteranceGenerator:
punctuation and length of the desired generations.
"""

def __init__(
self,
generator: Generator,
custom_instruction: list[str],
length: LengthType,
style: StyleType,
same_punctuation: bool,
) -> None:
def __init__(self, generator: Generator, prompt_maker: Callable[[Intent, int], list[Message]]) -> None:
"""Initialize."""
self.generator = generator
prompt_template_yaml = _load_prompt()
self.prompt_template_yaml = _add_extra_instructions(
prompt_template_yaml,
custom_instruction,
length,
style,
same_punctuation,
)
self.prompt_maker = prompt_maker

def __call__(self, intent_name: str, example_utterances: list[str], n_generations: int) -> list[str]:
def __call__(self, intent_data: Intent, n_generations: int) -> list[str]:
"""Generate new utterances."""
messages_yaml = safe_format(
self.prompt_template_yaml,
intent_name=intent_name,
example_utterances=_format_utterances(example_utterances),
n_examples=n_generations,
)
messages = yaml.safe_load(messages_yaml)
messages = self.prompt_maker(intent_data, n_generations)
response_text = self.generator.get_chat_completion(messages)
return _extract_utterances(response_text)

Expand All @@ -64,7 +37,6 @@ def augment(
dataset: Dataset,
split_name: str = Split.TRAIN,
n_generations: int = 5,
max_sample_utterances: int = 5,
update_split: bool = True,
) -> list[Sample]:
"""
Expand All @@ -75,13 +47,8 @@ def augment(
original_split = dataset[split_name]
new_samples = []
for intent in dataset.intents:
filtered_split = original_split.filter(lambda sample, id=intent.id: sample[Dataset.label_feature] == id)
sample_utterances = filtered_split[Dataset.utterance_feature]
if max_sample_utterances is not None:
sample_utterances = random.sample(sample_utterances, k=max_sample_utterances)
generated_utterances = self(
intent_name=intent.name or "",
example_utterances=sample_utterances,
intent_data=intent,
n_generations=n_generations,
)
new_samples.extend(
Expand All @@ -93,54 +60,6 @@ def augment(
return [Sample(**sample) for sample in new_samples]


def _load_prompt() -> str:
with ires.files("autointent.generation.utterances.basic").joinpath("chat_template.yaml").open() as file:
return file.read()


def _load_extra_instructions() -> dict[str, Any]:
with ires.files("autointent.generation.utterances.basic").joinpath("extra_instructions.json").open() as file:
return json.load(file) # type: ignore[no-any-return]


def _add_extra_instructions(
prompt_template_yaml: str,
custom_instruction: list[str],
length: LengthType,
style: StyleType,
same_punctuation: bool,
) -> str:
instructions = _load_extra_instructions()

extra_instructions = []
if length != "none":
extra_instructions.append(instructions["length"][length])
if style != "none":
extra_instructions.append(instructions["style"][style])
if same_punctuation:
extra_instructions.append(instructions["punctuation"])

extra_instructions.extend(custom_instruction)

parsed_extra_instructions = "\n ".join([f"- {s}" for s in extra_instructions])
return safe_format(prompt_template_yaml, extra_instructions=parsed_extra_instructions) # type: ignore[no-any-return]


def _format_utterances(utterances: list[str]) -> str:
"""
Convert given utterances into string that is ready to insert into prompt.
Given list of utterances, the output string is returned in the following format:
.. code-block::
1. I want to order a large pepperoni pizza.
2. Can I get a medium cheese pizza with extra olives?
3. Please deliver a small veggie pizza to my address.
Note that tab is inserted before each line because of how yaml processes multi-line fields.
"""
return "\n ".join(f"{i}. {ut}" for i, ut in enumerate(utterances))


def _extract_utterances(response_text: str) -> list[str]:
"""
Parse LLM output.
Expand Down
27 changes: 0 additions & 27 deletions autointent/generation/utterances/utils.py

This file was deleted.

0 comments on commit 1681e77

Please sign in to comment.