diff --git a/config.yml b/config.yml new file mode 100644 index 0000000..9133422 --- /dev/null +++ b/config.yml @@ -0,0 +1,72 @@ +save_dir: "./experiment/" + +ablation: + use_ablate: false + +# Data Ingestion ------------------- +data: + file_type: "huggingface" # one of 'json', 'csv', 'huggingface' + path: "yahma/alpaca-cleaned" + prompt: + >- # prompt, make sure column inputs are enclosed in {} brackets and that they match your data + Below is an instruction that describes a task. + Write a response that appropriately completes the request. + ### Instruction: {instruction} + ### Input: {input} + ### Output: + prompt_stub: + >- # Stub to add for training at the end of prompt, for test set or inference, this is omitted; make sure only one variable is present + {output} + test_size: 0.1 # Proportion of test as % of total; if integer then # of samples + train_size: 0.9 # Proportion of train as % of total; if integer then # of samples + train_test_split_seed: 42 + +# Model Definition ------------------- +model: + hf_model_ckpt: "NousResearch/Llama-2-7b-hf" + quantize: true + bitsandbytes: + load_in_4bit: true + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_quant_type: "nf4" + +# LoRA Params ------------------- +lora: + task_type: "CAUSAL_LM" + r: 32 + lora_dropout: 0.1 + target_modules: + - q_proj + - v_proj + - k_proj + - o_proj + - up_proj + - down_proj + - gate_proj + +# Training ------------------- +training: + training_args: + num_train_epochs: 5 + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + gradient_checkpointing: True + optim: "paged_adamw_32bit" + logging_steps: 100 + learning_rate: 2.0e-4 + bf16: true # Set to true for mixed precision training on Newer GPUs + tf32: true + # fp16: false # Set to true for mixed precision training on Older GPUs + max_grad_norm: 0.3 + warmup_ratio: 0.03 + lr_scheduler_type: "constant" + sft_args: + max_seq_length: 5000 + # neftune_noise_alpha: None + +inference: + max_new_tokens: 1024 + use_cache: True + do_sample: True + top_p: 0.9 + temperature: 0.8 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/dataset_generator.py b/src/data/dataset_generator.py new file mode 100644 index 0000000..7a05c0d --- /dev/null +++ b/src/data/dataset_generator.py @@ -0,0 +1,95 @@ +import os +from os.path import join, exists +from functools import partial +from typing import Tuple, Union +import pickle + +import re +from datasets import Dataset +from rich.console import Console +from rich.layout import Layout +from rich.panel import Panel + +from src.data.ingestor import Ingestor, get_ingestor + + +class DatasetGenerator: + def __init__( + self, + file_type: str, + path: str, + prompt: str, + prompt_stub: str, + test_size: Union[float, int], + train_size: Union[float, int], + train_test_split_seed: int, + ): + self.ingestor: Ingestor = get_ingestor(file_type) + self.ingestor: Ingestor = self.ingestor(path) + + self.dataset: Dataset = self.ingestor.to_dataset() + self.prompt: str = prompt + self.prompt_stub: str = prompt_stub + self.test_size = test_size + self.train_size = train_size + self.train_test_split_seed: int = train_test_split_seed + + self.train_columns: list = self._get_train_columns() + self.test_column: str = self._get_test_column() + + def _get_train_columns(self): + pattern = r"\{([^}]*)\}" + return re.findall(pattern, self.prompt) + + def 
_get_test_column(self): + pattern = r"\{([^}]*)\}" + return re.findall(pattern, self.prompt_stub)[0] + + # TODO: stratify_by_column + def _train_test_split(self): + self.dataset = self.dataset.train_test_split( + test_size=self.test_size, + train_size=self.train_size, + seed=self.train_test_split_seed, + ) + + def _format_one_prompt(self, example, is_test: bool = False): + train_mapping = {var_name: example[var_name] for var_name in self.train_columns} + example["formatted_prompt"] = self.prompt.format(**train_mapping) + + if not is_test: + test_mapping = {self.test_column: example[self.test_column]} + example["formatted_prompt"] += self.prompt_stub.format(**test_mapping) + + return example + + def _format_prompts(self): + self.dataset["train"] = self.dataset["train"].map( + partial(self._format_one_prompt, is_test=False) + ) + self.dataset["test"] = self.dataset["test"].map( + partial(self._format_one_prompt, is_test=True) + ) + + def get_dataset(self) -> Tuple[Dataset, Dataset]: + self._train_test_split() + self._format_prompts() + + return self.dataset["train"], self.dataset["test"] + + def save_dataset(self, save_dir: str): + os.makedirs(save_dir, exist_ok=True) + with open(join(save_dir, "dataset.pkl"), "wb") as f: + pickle.dump(self.dataset, f) + + def load_dataset_from_pickle(self, save_dir: str): + data_path = join(save_dir, "dataset.pkl") + + if not exists(data_path): + raise FileNotFoundError(f"Train set pickle not found at {save_dir}") + + with open(data_path, "rb") as f: + data = pickle.load(f) + self.dataset = data + + return self.dataset["train"], self.dataset["test"] diff --git a/src/data/ingestor.py b/src/data/ingestor.py new file mode 100644 index 0000000..227e4d7 --- /dev/null +++ b/src/data/ingestor.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from functools import partial + +import ijson +import csv +from datasets import Dataset, load_dataset, concatenate_datasets + + +def get_ingestor(data_type: str): + if data_type == "json": + return JsonIngestor + elif data_type == "csv": + return CsvIngestor + elif data_type == "huggingface": + return HuggingfaceIngestor + else: + raise ValueError( + f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}" + ) + + +class Ingestor(ABC): + @abstractmethod + def to_dataset(self) -> Dataset: + pass + + +class JsonIngestor(Ingestor): + def __init__(self, path: str): + self.path = path + + def _json_generator(self): + with open(self.path, "rb") as f: + for item in ijson.items(f, "item"): + yield item + + def to_dataset(self) -> Dataset: + return Dataset.from_generator(self._json_generator) + + +class CsvIngestor(Ingestor): + def __init__(self, path: str): + self.path = path + + def _csv_generator(self): + with open(self.path) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + yield row + + def to_dataset(self) -> Dataset: + return Dataset.from_generator(self._csv_generator) + + +class HuggingfaceIngestor(Ingestor): + def __init__(self, path: str): + self.path = path + + def to_dataset(self) -> Dataset: + ds = load_dataset(self.path) + return concatenate_datasets(ds.values()) diff --git a/src/finetune/__init__.py b/src/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/finetune/finetune.py b/src/finetune/finetune.py new file mode 100644 index 0000000..77e053f --- /dev/null +++ b/src/finetune/finetune.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from typing import Union, List, Tuple, Dict + + +class Finetune(ABC): + @abstractmethod + def 
finetune(self): + pass + + @abstractmethod + def save_model(self): + pass diff --git a/src/finetune/lora.py b/src/finetune/lora.py new file mode 100644 index 0000000..c980bdf --- /dev/null +++ b/src/finetune/lora.py @@ -0,0 +1,140 @@ +from os.path import join, exists +from typing import Tuple + +import torch + +import bitsandbytes as bnb +from datasets import Dataset +from accelerate import Accelerator +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + BitsAndBytesConfig, + TrainingArguments, + AutoTokenizer, + ProgressCallback, +) +from peft import ( + prepare_model_for_kbit_training, + get_peft_model, + LoraConfig, +) +from trl import SFTTrainer +from rich.console import Console + + +from src.pydantic_models.config_model import Config +from src.utils.save_utils import DirectoryHelper +from src.finetune.finetune import Finetune +from src.ui.rich_ui import RichUI + + +class LoRAFinetune(Finetune): + def __init__(self, config: Config, directory_helper: DirectoryHelper): + self.config = config + + self._model_config = config.model + self._training_args = config.training.training_args + self._sft_args = config.training.sft_args + self._lora_config = LoraConfig(**config.lora.model_dump()) + self._directory_helper = directory_helper + self._weights_path = self._directory_helper.save_paths.weights + self._trainer = None + + self.model = None + self.tokenizer = None + + """ TODO: Figure out how to handle multi-gpu + if config.accelerate: + self.accelerator = Accelerator() + self.accelerator.state.deepspeed_plugin.deepspeed_config[ + "train_micro_batch_size_per_gpu" + ] = self.config.training.training_args.per_device_train_batch_size + + if config.accelerate: + # device_index = Accelerator().process_index + self.device_map = None #{"": device_index} + else: + """ + self.device_map = self._model_config.device_map + + self._load_model_and_tokenizer() + + def _load_model_and_tokenizer(self): + ckpt = self._model_config.hf_model_ckpt + RichUI.on_basemodel_load(ckpt) + model = self._get_model() + tokenizer = self._get_tokenizer() + RichUI.after_basemodel_load(ckpt) + + self.model = model + self.tokenizer = tokenizer + + def _get_model(self): + model = AutoModelForCausalLM.from_pretrained( + self._model_config.hf_model_ckpt, + quantization_config=( + BitsAndBytesConfig(**self._model_config.bitsandbytes.model_dump()) + if not self.config.accelerate + else None + ), + use_cache=False, + device_map=self.device_map, + ) + + model.config.pretraining_tp = 1 + + return model + + def _get_tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained(self._model_config.hf_model_ckpt) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + + return tokenizer + + def _inject_lora(self): + if not self.config.accelerate: + self.model.gradient_checkpointing_enable() + self.model = prepare_model_for_kbit_training(self.model) + self.model = get_peft_model(self.model, self._lora_config) + + if not self.config.accelerate: + self.optimizer = bnb.optim.Adam8bit( + self.model.parameters(), lr=self._training_args.learning_rate + ) + self.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer) + if self.config.accelerate: + self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + def finetune(self, train_dataset: Dataset): + logging_dir = join(self._weights_path, "/logs") + training_args = TrainingArguments( + output_dir=self._weights_path, + logging_dir=logging_dir, + report_to="none", + 
**self._training_args.model_dump(), + ) + + progress_callback = ProgressCallback() + + self._trainer = SFTTrainer( + model=self.model, + train_dataset=train_dataset, + peft_config=self._lora_config, + tokenizer=self.tokenizer, + packing=True, + args=training_args, + dataset_text_field="formatted_prompt", # TODO: maybe move consts to a dedicated folder + callbacks=[progress_callback], + # optimizers=[self.optimizer, self.lr_scheduler], + **self._sft_args.model_dump(), + ) + + trainer_stats = self._trainer.train() + + def save_model(self) -> None: + self._trainer.model.save_pretrained(self._weights_path) + self.tokenizer.save_pretrained(self._weights_path) diff --git a/src/inference/__init__.py b/src/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/inference/inference.py b/src/inference/inference.py new file mode 100644 index 0000000..be23a80 --- /dev/null +++ b/src/inference/inference.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from typing import Union, List, Tuple, Dict + + +class Inference(ABC): + @abstractmethod + def infer_one(self, prompt: str): + pass + + @abstractmethod + def infer_all(self): + pass diff --git a/src/inference/lora.py b/src/inference/lora.py new file mode 100644 index 0000000..4e2a631 --- /dev/null +++ b/src/inference/lora.py @@ -0,0 +1,134 @@ +import os +from os.path import join +from threading import Thread +import csv + +from transformers import TextIteratorStreamer +from rich.console import Console +from rich.table import Table +from rich.live import Live +from rich.text import Text +from datasets import Dataset +from transformers import AutoTokenizer, BitsAndBytesConfig +from peft import AutoPeftModelForCausalLM +import torch + + +from src.pydantic_models.config_model import Config +from src.utils.save_utils import DirectoryHelper +from src.inference.inference import Inference +from src.ui.rich_ui import RichUI + + +# TODO: Add type hints please! 
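# Illustrative sketch (editorial note, not part of this change): LoRAInference
# below streams generations by running model.generate() in a worker thread and
# reading decoded text chunks off a TextIteratorStreamer in the main thread. A
# minimal, self-contained version of that pattern, using a tiny placeholder
# checkpoint instead of the fine-tuned weights:
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tiny_ckpt = "sshleifer/tiny-gpt2"  # placeholder model, for illustration only
tokenizer = AutoTokenizer.from_pretrained(tiny_ckpt)
model = AutoModelForCausalLM.from_pretrained(tiny_ckpt)

inputs = tokenizer("Below is an instruction that describes a task.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs in a background thread while
# the streamer is consumed as an iterator (LoRAInference renders each chunk
# live through RichUI.inference_stream_display instead of joining them).
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=16))
thread.start()
text = "".join(chunk for chunk in streamer)
thread.join()
print(text)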
+class LoRAInference(Inference): + def __init__( + self, + test_dataset: Dataset, + label_column_name: str, + config: Config, + dir_helper: DirectoryHelper, + ): + self.test_dataset = test_dataset + self.label_column = label_column_name + self.config = config + + self.save_dir = dir_helper.save_paths.results + self.save_path = join(self.save_dir, "results.csv") + self.device_map = self.config.model.device_map + self._weights_path = dir_helper.save_paths.weights + + self.model, self.tokenizer = self._get_merged_model( + dir_helper.save_paths.weights + ) + + def _get_merged_model(self, weights_path: str): + # purge VRAM + torch.cuda.empty_cache() + + # Load from path + dtype = ( + torch.float16 + if self.config.training.training_args.fp16 + else ( + torch.bfloat16 + if self.config.training.training_args.bf16 + else torch.float32 + ) + ) + + self.model = AutoPeftModelForCausalLM.from_pretrained( + weights_path, + torch_dtype=dtype, + device_map=self.device_map, + quantization_config=( + BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump()) + ), + ) + + """TODO: figure out multi-gpu + if self.config.accelerate: + self.model = self.accelerator.prepare(self.model) + """ + + model = self.model.merge_and_unload() + + tokenizer = AutoTokenizer.from_pretrained( + self._weights_path, device_map=self.device_map + ) + + return model, tokenizer + + def infer_all(self): + results = [] + prompts = self.test_dataset["formatted_prompt"] + labels = self.test_dataset[self.label_column] + + # inference loop + for idx, (prompt, label) in enumerate(zip(prompts, labels)): + RichUI.inference_ground_truth_display( + f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label + ) + + try: + result = self.infer_one(prompt) + except: + continue + results.append((prompt, label, result)) + + # TODO: seperate this into another method + header = ["Prompt", "Ground Truth", "Predicted"] + os.makedirs(self.save_dir, exist_ok=True) + with open(self.save_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in results: + writer.writerow(row) + + def infer_one(self, prompt: str) -> str: + input_ids = self.tokenizer( + prompt, return_tensors="pt", truncation=True + ).input_ids.cuda() + + # stream processor + streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + decode_kwargs={"skip_special_tokens": True}, + timeout=60, # 60 sec timeout for generation; to handle OOM errors + ) + + generation_kwargs = dict( + input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump() + ) + + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + + result = Text() + with RichUI.inference_stream_display(result) as live: + for new_text in streamer: + result.append(new_text) + live.update(result) + + return str(result) diff --git a/src/pydantic_models/__init__.py b/src/pydantic_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pydantic_models/config_model.py b/src/pydantic_models/config_model.py new file mode 100644 index 0000000..375dc45 --- /dev/null +++ b/src/pydantic_models/config_model.py @@ -0,0 +1,219 @@ +from typing import Literal, Union, List, Dict, Optional +from pydantic import BaseModel, FilePath, validator, Field + +from huggingface_hub.utils import validate_repo_id + +import torch + +# TODO: Refactor this into multiple files... 
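# Illustrative sketch (editorial note, not part of this change): the models in
# this file mirror config.yml section by section, and toolkit.py validates the
# raw YAML by unpacking it into the top-level Config defined at the bottom of
# this file. A minimal version of that round trip, assuming the repository's
# default config path:
import yaml

from src.pydantic_models.config_model import Config

with open("./config.yml") as f:
    raw = yaml.safe_load(f)

config = Config(**raw)                               # pydantic validation happens here
print(config.model.hf_model_ckpt)                    # NousResearch/Llama-2-7b-hf
print(config.training.training_args.learning_rate)   # 0.0002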
+HfModelPath = str + + +class DataConfig(BaseModel): + file_type: Literal["json", "csv", "huggingface"] = Field( + None, description="File type" + ) + path: Union[FilePath, HfModelPath] = Field( + None, description="Path to the file or HuggingFace model" + ) + prompt: str = Field( + None, description="Prompt for the model. Use {} brackets for column name" + ) + prompt_stub: str = Field( + None, + description="Stub for the prompt; this is injected during training. Use {} brackets for column name", + ) + train_size: Optional[Union[float, int]] = Field( + 0.9, + description="Size of the training set; float for proportion and int for # of examples", + ) + test_size: Optional[Union[float, int]] = Field( + 0.1, + description="Size of the test set; float for proportion and int for # of examples", + ) + train_test_split_seed: int = Field( + 42, + description="Seed used in the train test split. This is used to ensure that the train and test sets are the same across runs", + ) + + # @validator("path") + # def validate_path(cls, v, values, **kwargs): + # if "file_type" in values and values["file_type"] == "huggingface": + # if not validate_repo_id(v): + # raise ValueError("Invalid HuggingFace dataset path") + # return v + + +class BitsAndBytesConfig(BaseModel): + load_in_8bit: Optional[bool] = Field( + False, description="Enable 8-bit quantization with LLM.int8()" + ) + llm_int8_threshold: Optional[float] = Field( + 6.0, description="Outlier threshold for outlier detection in 8-bit quantization" + ) + llm_int8_skip_modules: Optional[List[str]] = Field( + None, description="List of modules that we do not want to convert in 8-bit" + ) + llm_int8_enable_fp32_cpu_offload: Optional[bool] = Field( + False, + description="Enable splitting model parts between int8 on GPU and fp32 on CPU", + ) + llm_int8_has_fp16_weight: Optional[bool] = Field( + False, description="Run LLM.int8() with 16-bit main weights" + ) + + load_in_4bit: Optional[bool] = Field( + True, + description="Enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from bitsandbytes", + ) + bnb_4bit_compute_dtype: Optional[str] = Field( + torch.bfloat16, description="Computational type for 4-bit quantization" + ) + bnb_4bit_quant_type: Optional[str] = Field( + "nf4", description="Quantization data type in the bnb.nn.Linear4Bit layers" + ) + bnb_4bit_use_double_quant: Optional[bool] = Field( + True, + description="Enable nested quantization where the quantization constants from the first quantization are quantized again", + ) + + +class ModelConfig(BaseModel): + hf_model_ckpt: Optional[str] = Field( + "NousResearch/Llama-2-7b-hf", + description="Path to the model (huggingface repo or local path)", + ) + device_map: Optional[str] = Field( + "auto", description="device onto which to load the model" + ) + + quantize: Optional[bool] = Field(False, description="Flag to enable quantization") + bitsandbytes: BitsAndBytesConfig = Field( + None, description="Bits and Bytes configuration" + ) + + # @validator("hf_model_ckpt") + # def validate_model(cls, v, **kwargs): + # if not validate_repo_id(v): + # raise ValueError("Invalid HuggingFace dataset path") + # return v + + @validator("quantize") + def set_bitsandbytes_to_none_if_no_quantization(cls, v, values, **kwargs): + if v is False: + values["bitsandbytes"] = None + return v + + @validator("device_map") + def set_device_map_to_none(cls, v, values, **kwargs): + if v.lower() == "none": + return None + return v + + +class LoraConfig(BaseModel): + r: Optional[int] = Field(8, 
description="Lora rank") + task_type: Optional[str] = Field( + "CAUSAL_LM", description="Base Model task type during training" + ) + + lora_alpha: Optional[int] = Field( + 16, description="The alpha parameter for Lora scaling" + ) + bias: Optional[str] = Field( + "none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'" + ) + lora_dropout: Optional[float] = Field( + 0.1, description="The dropout probability for Lora layers" + ) + target_modules: Optional[List[str]] = Field( + None, description="The names of the modules to apply Lora to" + ) + fan_in_fan_out: Optional[bool] = Field( + False, + description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)", + ) + modules_to_save: Optional[List[str]] = Field( + None, + description="List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint", + ) + layers_to_transform: Optional[Union[List[int], int]] = Field( + None, description="The layer indexes to transform" + ) + layers_pattern: Optional[str] = Field(None, description="The layer pattern name") + # rank_pattern: Optional[Dict[str, int]] = Field( + # {}, description="The mapping from layer names or regexp expression to ranks" + # ) + # alpha_pattern: Optional[Dict[str, int]] = Field( + # {}, description="The mapping from layer names or regexp expression to alphas" + # ) + + +# TODO: Get comprehensive Args! +class TrainingArgs(BaseModel): + num_train_epochs: Optional[int] = Field(1, description="Number of training epochs") + per_device_train_batch_size: Optional[int] = Field( + 1, description="Batch size per training device" + ) + gradient_accumulation_steps: Optional[int] = Field( + 1, description="Number of steps for gradient accumulation" + ) + gradient_checkpointing: Optional[bool] = Field( + True, description="Flag to enable gradient checkpointing" + ) + optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer") + logging_steps: Optional[int] = Field(100, description="Number of logging steps") + learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate") + bf16: Optional[bool] = Field(False, description="Flag to enable bf16") + tf32: Optional[bool] = Field(False, description="Flag to enable tf32") + fp16: Optional[bool] = Field(False, description="Flag to enable fp16") + max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm") + warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio") + lr_scheduler_type: Optional[str] = Field( + "constant", description="Learning rate scheduler type" + ) + + +# TODO: Get comprehensive Args! +class SftArgs(BaseModel): + max_seq_length: Optional[int] = Field(None, description="Maximum sequence length") + neftune_noise_alpha: Optional[float] = Field( + None, + description="If not None, this will activate NEFTune noise embeddings. This can drastically improve model performance for instruction fine-tuning.", + ) + + +class TrainingConfig(BaseModel): + training_args: TrainingArgs + sft_args: SftArgs + + +# TODO: Get comprehensive Args! 
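# Illustrative sketch (editorial note, not part of this change): LoRAFinetune
# consumes the LoraConfig and TrainingArgs pydantic models defined above by
# dumping them straight into the corresponding peft / transformers
# constructors, e.g. LoraConfig(**config.lora.model_dump()). A minimal version
# of that plumbing, with a placeholder output directory:
from peft import LoraConfig as PeftLoraConfig
from transformers import TrainingArguments as HfTrainingArguments

lora_cfg = LoraConfig(r=32, target_modules=["q_proj", "v_proj"])
peft_cfg = PeftLoraConfig(**lora_cfg.model_dump())  # pydantic fields become peft kwargs

train_cfg = TrainingArgs(num_train_epochs=5, learning_rate=2.0e-4)
hf_args = HfTrainingArguments(
    output_dir="./experiment/weights",  # supplied separately in LoRAFinetune.finetune()
    report_to="none",
    **train_cfg.model_dump(),           # the rest comes from the training_args section of config.yml
)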
+class InferenceConfig(BaseModel): + max_new_tokens: Optional[int] = Field(None, description="Maximum new tokens") + use_cache: Optional[bool] = Field(True, description="Flag to enable cache usage") + do_sample: Optional[bool] = Field(True, description="Flag to enable sampling") + top_p: Optional[float] = Field(1.0, description="Top p value") + temperature: Optional[float] = Field(0.1, description="Temperature value") + epsilon_cutoff: Optional[float] = Field(0.0, description="epsilon cutoff value") + eta_cutoff: Optional[float] = Field(0.0, description="eta cutoff value") + top_k: Optional[int] = Field(50, description="top-k sampling") + + +class AblationConfig(BaseModel): + use_ablate: Optional[bool] = Field(False, description="Flag to enable ablation") + study_name: Optional[str] = Field("ablation", description="Name of the study") + + +class Config(BaseModel): + save_dir: Optional[str] = Field("./experiments", description="Folder to save to") + ablation: AblationConfig + accelerate: Optional[bool] = Field( + False, + description="set to True if you want to use multi-gpu training; then launch with `accelerate launch --config_file ./accelerate_config toolkit.py`", + ) + data: DataConfig + model: ModelConfig + lora: LoraConfig + training: TrainingConfig + inference: InferenceConfig diff --git a/src/qa/__init__.py b/src/qa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/qa/qa.py b/src/qa/qa.py new file mode 100644 index 0000000..ce9275f --- /dev/null +++ b/src/qa/qa.py @@ -0,0 +1,73 @@ +from abc import ABC, abstractmethod +from typing import Union, List, Tuple, Dict +import pandas as pd +from toolkit.src.ui.rich_ui import RichUI +import statistics + + +class LLMQaTest(ABC): + @property + @abstractmethod + def test_name(self) -> str: + pass + + @abstractmethod + def get_metric( + self, prompt: str, grount_truth: str, model_pred: str + ) -> Union[float, int, bool]: + pass + + +class LLMTestSuite: + def __init__( + self, + tests: List[LLMQaTest], + prompts: List[str], + ground_truths: List[str], + model_preds: List[str], + ) -> None: + + self.tests = tests + self.prompts = prompts + self.ground_truths = ground_truths + self.model_preds = model_preds + + self.test_results = {} + + def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]: + test_results = {} + for test in zip(self.tests): + metrics = [] + for prompt, ground_truth, model_pred in zip( + self.prompts, self.ground_truths, self.model_preds + ): + metrics.append(test.get_metric(prompt, ground_truth, model_pred)) + test_results[test.test_name] = metrics + + self.test_results = test_results + return test_results + + @property + def test_results(self): + return self.test_results if self.test_results else self.run_tests() + + def print_test_results(self): + result_dictionary = self.test_results() + column_data = { + key: [value for value in result_dictionary[key]] + for key in result_dictionary + } + mean_values = {key: statistics.mean(column_data[key]) for key in column_data} + median_values = { + key: statistics.median(column_data[key]) for key in column_data + } + stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data} + # Use the RichUI class to display the table + RichUI.display_table( + result_dictionary, mean_values, median_values, stdev_values + ) + + def save_test_results(self, path: str): + # TODO: save these! 
+ resultant_dataframe = pd.DataFrame(self.test_results()) + resultant_dataframe.to_csv(path, index=False) diff --git a/src/qa/qa_tests.py b/src/qa/qa_tests.py new file mode 100644 index 0000000..a22ece7 --- /dev/null +++ b/src/qa/qa_tests.py @@ -0,0 +1,162 @@ +from src.qa.qa import LLMQaTest +from typing import Union, List, Tuple, Dict +import torch +from transformers import DistilBertModel, DistilBertTokenizer +import nltk +import numpy as np +from rouge_score import rouge_scorer +from nltk.corpus import stopwords +from nltk.tokenize import word_tokenize +from nltk import pos_tag + +model_name = "distilbert-base-uncased" +tokenizer = DistilBertTokenizer.from_pretrained(model_name) +model = DistilBertModel.from_pretrained(model_name) + +nltk.download("stopwords") +nltk.download("punkt") +nltk.download("averaged_perceptron_tagger") + + +class LengthTest(LLMQaTest): + @property + def test_name(self) -> str: + return "Summary Length Test" + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> Union[float, int, bool]: + return abs(len(ground_truth) - len(model_prediction)) + + +class JaccardSimilarityTest(LLMQaTest): + @property + def test_name(self) -> str: + return "Jaccard Similarity" + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> Union[float, int, bool]: + set_ground_truth = set(ground_truth.lower()) + set_model_prediction = set(model_prediction.lower()) + + intersection_size = len(set_ground_truth.intersection(set_model_prediction)) + union_size = len(set_ground_truth.union(set_model_prediction)) + + similarity = intersection_size / union_size if union_size != 0 else 0 + return similarity + + +class DotProductSimilarityTest(LLMQaTest): + @property + def test_name(self) -> str: + return "Semantic Similarity" + + def _encode_sentence(self, sentence): + tokens = tokenizer(sentence, return_tensors="pt") + with torch.no_grad(): + outputs = model(**tokens) + return outputs.last_hidden_state.mean(dim=1).squeeze().numpy() + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> Union[float, int, bool]: + embedding_ground_truth = self._encode_sentence(ground_truth) + embedding_model_prediction = self._encode_sentence(model_prediction) + dot_product_similarity = np.dot( + embedding_ground_truth, embedding_model_prediction + ) + return dot_product_similarity + + +class RougeScoreTest(LLMQaTest): + @property + def test_name(self) -> str: + return "Rouge Score" + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> Union[float, int, bool]: + scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) + scores = scorer.score(model_prediction, ground_truth) + return float(scores["rouge1"].precision) + + +class WordOverlapTest(LLMQaTest): + @property + def test_name(self) -> str: + return "Word Overlap Test" + + def _remove_stopwords(self, text: str) -> str: + stop_words = set(stopwords.words("english")) + word_tokens = word_tokenize(text) + filtered_text = [word for word in word_tokens if word.lower() not in stop_words] + return " ".join(filtered_text) + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> Union[float, int, bool]: + cleaned_model_prediction = self._remove_stopwords(model_prediction) + cleaned_ground_truth = self._remove_stopwords(ground_truth) + + words_model_prediction = set(cleaned_model_prediction.split()) + words_ground_truth = set(cleaned_ground_truth.split()) + + common_words = 
words_model_prediction.intersection(words_ground_truth) + overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100 + return overlap_percentage + + +class PosCompositionTest(LLMQaTest): + def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float: + words = word_tokenize(text) + tags = pos_tag(words) + pos_words = [word for word, tag in tags if tag in pos_tags] + total_words = len(text.split(" ")) + return round(len(pos_words) / total_words, 2) + + +class VerbPercent(PosCompositionTest): + @property + def test_name(self) -> str: + return "Verb Composition" + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> float: + return self._get_pos_percent( + model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"] + ) + + +class AdjectivePercent(PosCompositionTest): + @property + def test_name(self) -> str: + return "Adjective Composition" + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> float: + return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"]) + + +class NounPercent(PosCompositionTest): + @property + def test_name(self) -> str: + return "Noun Composition" + + def get_metric( + self, prompt: str, ground_truth: str, model_prediction: str + ) -> float: + return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"]) + + +# Instantiate tests +# length_test = LengthTest() +# jaccard_similarity_test = JaccardSimilarityTest() +# dot_product_similarity_test = DotProductSimilarityTest() +# rouge_score_test = RougeScoreTest() +# word_overlap_test = WordOverlapTest() +# verb_percent_test = VerbPercent() +# adjective_percent_test = AdjectivePercent() +# noun_percent_test = NounPercent() diff --git a/src/ui/__init__.py b/src/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ui/rich_ui.py b/src/ui/rich_ui.py new file mode 100644 index 0000000..10ab6ed --- /dev/null +++ b/src/ui/rich_ui.py @@ -0,0 +1,214 @@ +from datasets import Dataset + +from rich.console import Console +from rich.layout import Layout +from rich.panel import Panel +from rich.table import Table +from rich.live import Live +from rich.text import Text + +from src.ui.ui import UI +from src.utils.rich_print_utils import inject_example_to_rich_layout + +console = Console() + + +class StatusContext: + def __init__(self, console, message, spinner): + self.console = console + self.message = message + self.spinner = spinner + + def __enter__(self): + self.task = self.console.status(self.message, spinner=self.spinner) + self.task.__enter__() # Manually enter the console status context + return self # This allows you to use variables from this context if needed + + def __exit__(self, exc_type, exc_val, exc_tb): + self.task.__exit__( + exc_type, exc_val, exc_tb + ) # Cleanly exit the console status context + + +class LiveContext: + def __init__(self, text: Text, refresh_per_second=4, vertical_overflow="visible"): + self.console = console + self.text = text + self.refresh_per_second = refresh_per_second + self.vertical_overflow = vertical_overflow + + def __enter__(self): + self.task = Live( + self.text, + refresh_per_second=self.refresh_per_second, + vertical_overflow=self.vertical_overflow, + ) + self.task.__enter__() # Manually enter the console status context + return self # This allows you to use variables from this context if needed + + def __exit__(self, exc_type, exc_val, exc_tb): + self.task.__exit__( + exc_type, exc_val, exc_tb + ) # Cleanly exit the console status context + + def 
update(self, new_text: Text): + self.task.update(new_text) + + +class RichUI(UI): + """ + DATASET + """ + + # Lifecycle functions + @staticmethod + def before_dataset_creation(): + console.rule("[bold green]Loading Data") + + @staticmethod + def during_dataset_creation(message: str, spinner: str): + return StatusContext(console, message, spinner) + + @staticmethod + def after_dataset_creation(save_dir: str, train: Dataset, test: Dataset): + console.print(f"Dataset Saved at {save_dir}") + console.print(f"Post-Split data size:") + console.print(f"Train: {len(train)}") + console.print(f"Test: {len(test)}") + + @staticmethod + def dataset_found(save_dir: str): + console.print(f"Loading formatted dataset from directory {save_dir}") + + # Display functions + @staticmethod + def dataset_display_one_example(train_row: dict, test_row: dict): + layout = Layout() + layout.split_row( + Layout(Panel("Train Sample"), name="train"), + Layout( + Panel("Inference Sample"), + name="inference", + ), + ) + + inject_example_to_rich_layout(layout["train"], "Train Example", train_row) + inject_example_to_rich_layout( + layout["inference"], "Inference Example", test_row + ) + + console.print(layout) + + """ + FINETUNING + """ + + # Lifecycle functions + @staticmethod + def before_finetune(): + console.rule("[bold yellow]:smiley: Finetuning") + + @staticmethod + def on_basemodel_load(checkpoint: str): + console.print(f"Loading {checkpoint}...") + + @staticmethod + def after_basemodel_load(checkpoint: str): + console.print(f"{checkpoint} Loaded :smile:") + + @staticmethod + def during_finetune(): + return StatusContext(console, "Finetuning Model...", "runner") + + @staticmethod + def after_finetune(): + console.print(f"Finetuning complete!") + + @staticmethod + def finetune_found(weights_path: str): + console.print(f"Fine-Tuned Model Found at {weights_path}... 
skipping training") + + """ + INFERENCE + """ + + # Lifecycle functions + @staticmethod + def before_inference(): + console.rule("[bold pink]:face_with_monocle: Testing") + + @staticmethod + def during_inference(): + pass + + @staticmethod + def after_inference(results_path: str): + console.print(f"Inference Results Saved at {results_path}") + + @staticmethod + def results_found(results_path: str): + console.print(f"Inference Results Found at {results_path}") + + # Display functions + @staticmethod + def inference_ground_truth_display(title: str, prompt: str, label: str): + prompt = prompt.replace("[INST]", "").replace("[/INST]", "") + label = label.replace("[INST]", "").replace("[/INST]", "") + + table = Table(title=title, show_lines=True) + table.add_column("prompt") + table.add_column("ground truth") + table.add_row(prompt, label) + console.print(table) + + @staticmethod + def inference_stream_display(text: Text): + console.print("[bold red]Prediction >") + return LiveContext(text) + + """ + QA + """ + + # Lifecycle functions + @staticmethod + def before_qa(): + pass + + @staticmethod + def during_qa(): + pass + + @staticmethod + def after_qa(): + pass + + @staticmethod + def qa_found(): + pass + + @staticmethod + def qa_display_table( + self, result_dictionary, mean_values, median_values, stdev_values + ): + + # Create a table + table = Table(show_header=True, header_style="bold", title="Test Results") + + # Add columns to the table + table.add_column("Metric", style="cyan") + table.add_column("Mean", style="magenta") + table.add_column("Median", style="green") + table.add_column("Standard Deviation", style="yellow") + + # Add data rows to the table + for key in result_dictionary: + table.add_row( + key, + f"{mean_values[key]:.4f}", + f"{median_values[key]:.4f}", + f"{stdev_values[key]:.4f}", + ) + + # Print the table + console.print(table) diff --git a/src/ui/ui.py b/src/ui/ui.py new file mode 100644 index 0000000..59d5997 --- /dev/null +++ b/src/ui/ui.py @@ -0,0 +1,116 @@ +from abc import ABC, abstractstaticmethod + +from datasets import Dataset +from rich.text import Text + + +class UI(ABC): + """ + DATASET + """ + + # Lifecycle functions + @abstractstaticmethod + def before_dataset_creation(): + pass + + @abstractstaticmethod + def during_dataset_creation(message: str, spinner: str): + pass + + @abstractstaticmethod + def after_dataset_creation(save_dir: str, train: Dataset, test: Dataset): + pass + + @abstractstaticmethod + def dataset_found(save_dir: str): + pass + + # Display functions + @abstractstaticmethod + def dataset_display_one_example(train_row: dict, test_row: dict): + pass + + """ + FINETUNING + """ + + # Lifecycle functions + @abstractstaticmethod + def before_finetune(): + pass + + @abstractstaticmethod + def on_basemodel_load(checkpoint: str): + pass + + @abstractstaticmethod + def after_basemodel_load(checkpoint: str): + pass + + @abstractstaticmethod + def during_finetune(): + pass + + @abstractstaticmethod + def after_finetune(): + pass + + @abstractstaticmethod + def finetune_found(weights_path: str): + pass + + """ + INFERENCE + """ + + # Lifecycle functions + @abstractstaticmethod + def before_inference(): + pass + + @abstractstaticmethod + def during_inference(): + pass + + @abstractstaticmethod + def after_inference(results_path: str): + pass + + @abstractstaticmethod + def results_found(results_path: str): + pass + + # Display functions + @abstractstaticmethod + def inference_ground_truth_display(title: str, prompt: str, label: str): + pass + + 
@abstractstaticmethod + def inference_stream_display(text: Text): + pass + + """ + QA + """ + + # Lifecycle functions + @abstractstaticmethod + def before_qa(cls): + pass + + @abstractstaticmethod + def during_qa(cls): + pass + + @abstractstaticmethod + def after_qa(cls): + pass + + @abstractstaticmethod + def qa_found(cls): + pass + + @abstractstaticmethod + def qa_display_table(cls): + pass diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/ablation_utils.py b/src/utils/ablation_utils.py new file mode 100644 index 0000000..37d9d80 --- /dev/null +++ b/src/utils/ablation_utils.py @@ -0,0 +1,119 @@ +import copy +import itertools +from typing import List, Type, Any, Dict, Optional, Union, Tuple +from typing import get_args, get_origin, get_type_hints + +import yaml + + +# TODO: organize this a little bit. It's a bit of a mess rn. + +""" +Helper functions to create multiple valid configs based on ablation (i.e. list of values) +fron a single config yaml +""" + + +def get_types_from_dict( + source_dict: dict, root="", type_dict={} +) -> Dict[str, Tuple[type, type]]: + for key, val in source_dict.items(): + if type(val) is not dict: + attr = f"{root}.{key}" if root else key + tp = ( + (type(val), None) + if type(val) is not list + else (type(val), type(val[0])) + ) + type_dict[attr] = tp + else: + join_array = [root, key] if root else [key] + new_root = ".".join(join_array) + get_types_from_dict(val, new_root, type_dict) + + return type_dict + + +def get_annotation(key: str, base_model): + keys = key.split(".") + model = base_model + for key in keys: + model = model.__annotations__[key] + + return model + + +def get_model_field_type(annotation): + origin = get_origin(annotation) + if not origin: + return annotation + if origin is Union: + annotations = get_args(annotation)[0] + return get_model_field_type(annotations) + if origin is list: + return list + + +def get_data_with_key(key, data): + keys = key.split(".") + for key in keys: + data = data[key] + return data + + +def validate_and_get_ablations(type_dict, data, base_model): + ablations = {} + for key, (tp, subtype) in type_dict.items(): + annotation = get_annotation(key, base_model) + model_field_type = get_model_field_type(annotation) + if (model_field_type is list) and (tp is list) and (subtype is list): + # Handle both list and list of lists + ablations[key] = get_data_with_key(key, data) + elif model_field_type is not list and tp is list: + # Handle single-level lists + ablations[key] = get_data_with_key(key, data) + + return ablations + + +def patch_with_permutation(old_dict, permutation_dict): + # Create a deep copy of the old dictionary to avoid modifying the original + updated_dict = copy.deepcopy(old_dict) + + # Iterate over each item in the permutation dictionary + for dot_key, new_value in permutation_dict.items(): + # Split the dot-joined key into individual keys + keys = dot_key.split(".") + + # Start from the root of the updated dictionary + current_level = updated_dict + + # Traverse to the second-to-last key in the nested dictionary + for key in keys[:-1]: + current_level = current_level[key] + + # Update the value at the final key + current_level[keys[-1]] = new_value + + return updated_dict + + +def generate_permutations(yaml_dict, model): + type_dict = get_types_from_dict(yaml_dict) + + ablations = validate_and_get_ablations(type_dict, yaml_dict, model) + + # get permutations + lists = list(ablations.values()) + permutations = 
list(itertools.product(*lists)) + + permutation_dicts = [] + for perm in permutations: + new_dict = dict(zip(ablations.keys(), perm)) + permutation_dicts.append(new_dict) + + new_dicts = [] + for perm in permutation_dicts: + new_dicts.append(patch_with_permutation(yaml_dict, perm)) + + return new_dicts diff --git a/src/utils/rich_print_utils.py b/src/utils/rich_print_utils.py new file mode 100644 index 0000000..371f742 --- /dev/null +++ b/src/utils/rich_print_utils.py @@ -0,0 +1,55 @@ +from rich.panel import Panel +from rich.layout import Layout +from rich.text import Text +from rich.table import Table + + +def inject_example_to_rich_layout(layout: Layout, layout_name: str, example: dict): + example = example.copy() + + # Crate Table + table = Table(expand=True) + colors = [ + "navy_blue", + "dark_green", + "spring_green3", + "turquoise2", + "cyan", + "blue_violet", + "royal_blue1", + "steel_blue1", + "chartreuse1", + "deep_pink4", + "plum2", + "red", + ] + + # Crate Formatted Text + formatted = example.pop("formatted_prompt", None) + formatted_text = Text(formatted) + + for key, c in zip(example.keys(), colors): + table.add_column(key, style=c) + + tgt_text = example[key] + start_idx = formatted.find(tgt_text) + formatted_text.stylize(f"bold {c}", start_idx, start_idx + len(tgt_text)) + + table.add_row(*example.values()) + + layout.split_column( + Layout( + Panel( + table, + title=f"{layout_name} - Raw", + title_align="left", + ) + ), + Layout( + Panel( + formatted_text, + title=f"{layout_name} - Formatted", + title_align="left", + ) + ), + ) diff --git a/src/utils/save_utils.py b/src/utils/save_utils.py new file mode 100644 index 0000000..493b3c9 --- /dev/null +++ b/src/utils/save_utils.py @@ -0,0 +1,85 @@ +""" +Helper functions to help managing saving and loading of experiments: + 1. Generate save directory name + 2. 
Check if files are present at various experiment stages +""" + +import shutil +import os +from os.path import exists +import yaml + +import re +import hashlib +from functools import cached_property +from dataclasses import dataclass + +from sqids import Sqids + +from src.pydantic_models.config_model import Config + +NUM_MD5_DIGITS_FOR_SQIDS = 5 # TODO: maybe move consts to a dedicated folder + + +@dataclass +class DirectoryList: + save_dir: str + config_hash: str + + @property + def experiment(self) -> str: + return os.path.join(self.save_dir, self.config_hash) + + @property + def config(self) -> str: + return os.path.join(self.experiment, "config") + + @property + def dataset(self) -> str: + return os.path.join(self.experiment, "dataset") + + @property + def weights(self) -> str: + return os.path.join(self.experiment, "weights") + + @property + def results(self) -> str: + return os.path.join(self.experiment, "results") + + @property + def qa(self) -> str: + return os.path.join(self.experiment, "qa") + + +class DirectoryHelper: + def __init__(self, config_path: str, config: Config): + self.config_path: str = config_path + self.config: Config = config + self.sqids: Sqids = Sqids() + self.save_paths: DirectoryList = self._get_directory_state() + + os.makedirs(self.save_paths.experiment, exist_ok=True) + if not exists(self.save_paths.config): + self.save_config() + + @cached_property + def config_hash(self) -> str: + config_str = self.config.model_dump_json() + config_str = re.sub(r"\s", "", config_str) + hash = hashlib.md5(config_str.encode()).digest() + return self.sqids.encode(hash[:NUM_MD5_DIGITS_FOR_SQIDS]) + + def _get_directory_state(self) -> DirectoryList: + save_dir = ( + self.config.save_dir + if not self.config.ablation.use_ablate + else os.path.join(self.config.save_dir, self.config.ablation.study_name) + ) + return DirectoryList(save_dir, self.config_hash) + + def save_config(self) -> None: + os.makedirs(self.save_paths.config, exist_ok=True) + model_dict = self.config.model_dump() + + with open(os.path.join(self.save_paths.config, "config.yml"), "w") as file: + yaml.dump(model_dict, file) diff --git a/toolkit.py b/toolkit.py new file mode 100644 index 0000000..de996b8 --- /dev/null +++ b/toolkit.py @@ -0,0 +1,108 @@ +from os import listdir +from os.path import join, exists +import yaml +import logging + +from transformers import utils as hf_utils +from pydantic import ValidationError +import torch + +from src.pydantic_models.config_model import Config +from src.data.dataset_generator import DatasetGenerator +from src.utils.save_utils import DirectoryHelper +from src.utils.ablation_utils import generate_permutations +from src.finetune.lora import LoRAFinetune +from src.inference.lora import LoRAInference +from src.ui.rich_ui import RichUI + +hf_utils.logging.set_verbosity_error() +torch._logging.set_logs(all=logging.CRITICAL) + + +def run_one_experiment(config: Config) -> None: + dir_helper = DirectoryHelper(config_path, config) + + # Loading Data ------------------------------- + RichUI.before_dataset_creation() + + with RichUI.during_dataset_creation("Injecting Values into Prompt", "monkey"): + dataset_generator = DatasetGenerator(**config.data.model_dump()) + + train_columns = dataset_generator.train_columns + test_column = dataset_generator.test_column + + dataset_path = dir_helper.save_paths.dataset + if not exists(dataset_path): + train, test = dataset_generator.get_dataset() + dataset_generator.save_dataset(dataset_path) + else: + RichUI.dataset_found(dataset_path) + 
train, test = dataset_generator.load_dataset_from_pickle(dataset_path) + + RichUI.dataset_display_one_example(train[0], test[0]) + RichUI.after_dataset_creation(dataset_path, train, test) + + # Loading Model ------------------------------- + RichUI.before_finetune() + + weights_path = dir_helper.save_paths.weights + + # model_loader = ModelLoader(config, console, dir_helper) + if not exists(weights_path) or not listdir(weights_path): + finetuner = LoRAFinetune(config, dir_helper) + with RichUI.during_finetune(): + finetuner.finetune(train) + finetuner.save_model() + RichUI.after_finetune() + else: + RichUI.finetune_found(weights_path) + + # Inference ------------------------------- + RichUI.before_inference() + results_path = dir_helper.save_paths.results + results_file_path = join(dir_helper.save_paths.results, "results.csv") + if not exists(results_path) or exists(results_file_path): + inference_runner = LoRAInference( + test, test_column, config, dir_helper + ).infer_all() + RichUI.after_inference(results_path) + else: + RichUI.inference_found(results_path) + + # QA ------------------------------- + # console.rule("[bold blue]:thinking_face: Running LLM Unit Tests") + # qa_path = dir_helper.save_paths.qa + # if not exists(qa_path) or not listdir(qa_path): + # # TODO: Instantiate unit test classes + # # TODO: Load results.csv + # # TODO: Run Unit Tests + # # TODO: Save Unit Test Results + # pass + + +if __name__ == "__main__": + config_path = "./config.yml" # TODO: parameterize this + + # Load YAML config + with open(config_path, "r") as file: + config = yaml.safe_load(file) + configs = ( + generate_permutations(config, Config) + if config.get("ablation", {}).get("use_ablate", False) + else [config] + ) + for config in configs: + try: + config = Config(**config) + # validate data with pydantic + except ValidationError as e: + print(e.json()) + + dir_helper = DirectoryHelper(config_path, config) + + # Reload config from saved config + with open(join(dir_helper.save_paths.config, "config.yml"), "r") as file: + config = yaml.safe_load(file) + config = Config(**config) + + run_one_experiment(config)
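# Illustrative sketch (editorial note, not part of this change): the ablation
# path above works by giving config.yml keys a list of values; toolkit.py then
# calls generate_permutations() to expand the raw YAML into one concrete config
# per combination before validating each with pydantic and running
# run_one_experiment on it. The two learning rates and LoRA ranks below are
# made-up example values.
import yaml

from src.pydantic_models.config_model import Config
from src.utils.ablation_utils import generate_permutations

with open("./config.yml") as f:
    raw = yaml.safe_load(f)

raw["ablation"]["use_ablate"] = True
raw["training"]["training_args"]["learning_rate"] = [2.0e-4, 1.0e-4]  # ablate learning rate
raw["lora"]["r"] = [16, 32]                                           # ablate LoRA rank

configs = generate_permutations(raw, Config)
print(len(configs))  # 4 concrete configs: 2 learning rates x 2 ranks
for c in configs:
    Config(**c)      # each permutation still validates against the same schema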