diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a9e08a0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.5 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/README.md b/README.md index 9a1c569..bd8c720 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ LLM Finetuning toolkit is a config-based CLI tool for launching a series of LLM See poetry documentation page for poetry [installation instructions](https://python-poetry.org/docs/#installation) ```shell - poetry install + poetry install --without dev ``` ### [Option 3] pip @@ -255,3 +255,10 @@ If you would like to contribute to this project, we recommend following the "for 5. Submit a **Pull request** so that we can review your changes NOTE: Be sure to merge the latest from "upstream" before making a pull request! + +### Setting Up the Repo for Development + +- We recommend using `poetry` to manage dependencies +- Install dependencies via `poetry install` +- Enter the virtual environment with `poetry shell` +- Install pre-commit hooks using `pre-commit install` diff --git a/poetry.lock b/poetry.lock index c65c484..52e69ef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -291,52 +291,6 @@ files = [ [package.dependencies] scipy = "*" -[[package]] -name = "black" -version = "24.3.0" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.8" -files = [ - {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"}, - {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"}, - {file = "black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"}, - {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"}, - {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"}, - {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"}, - {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"}, - {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"}, - {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"}, - {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"}, - {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"}, - {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"}, - {file = "black-24.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837"}, - {file = "black-24.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd"}, -
{file = "black-24.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213"}, - {file = "black-24.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959"}, - {file = "black-24.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb"}, - {file = "black-24.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7"}, - {file = "black-24.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7"}, - {file = "black-24.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f"}, - {file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"}, - {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "brotli" version = "1.1.0" @@ -543,6 +497,17 @@ files = [ [package.dependencies] pycparser = "*" +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -752,6 +717,17 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] +[[package]] +name = "distlib" +version = "0.3.8" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "distro" version = "1.9.0" @@ -1124,6 +1100,20 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jed torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] +[[package]] +name = "identify" +version = "2.5.35" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.5.35-py2.py3-none-any.whl", hash = "sha256:c4de0081837b211594f8e877a6b4fad7ca32bbfc1a9307fdd61c28bfe923f13e"}, + {file = "identify-2.5.35.tar.gz", hash = "sha256:10a7ca245cfcd756a554a7288159f72ff105ad233c7c4b9c6f0f4d108f5f6791"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.6" @@ -1774,6 +1764,20 @@ plot = ["matplotlib"] tgrep = ["pyparsing"] twitter = ["twython"] +[[package]] +name = "nodeenv" +version = "1.8.0" +description = "Node.js virtual environment builder" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +files = [ + {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, + {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, +] + +[package.dependencies] +setuptools = "*" + [[package]] name = "numpy" version = "1.26.4" @@ -2082,17 +2086,6 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] -[[package]] -name = "pathspec" -version = "0.12.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.8" -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "peft" version = "0.8.2" @@ -2151,6 +2144,24 @@ files = [ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +[[package]] +name = "pre-commit" +version = "3.7.0" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"}, + {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "prompt-toolkit" version = "3.0.43" @@ -3026,6 +3037,32 @@ nltk = "*" numpy = "*" six = ">=1.14.0" +[[package]] +name = "ruff" +version = "0.3.5" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:aef5bd3b89e657007e1be6b16553c8813b221ff6d92c7526b7e0227450981eac"}, + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:89b1e92b3bd9fca249153a97d23f29bed3992cff414b222fcd361d763fc53f12"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e55771559c89272c3ebab23326dc23e7f813e492052391fe7950c1a5a139d89"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dabc62195bf54b8a7876add6e789caae0268f34582333cda340497c886111c39"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a05f3793ba25f194f395578579c546ca5d83e0195f992edc32e5907d142bfa3"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:dfd3504e881082959b4160ab02f7a205f0fadc0a9619cc481982b6837b2fd4c0"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87258e0d4b04046cf1d6cc1c56fadbf7a880cc3de1f7294938e923234cf9e498"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:712e71283fc7d9f95047ed5f793bc019b0b0a29849b14664a60fd66c23b96da1"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532a90b4a18d3f722c124c513ffb5e5eaff0cc4f6d3aa4bda38e691b8600c9f"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:122de171a147c76ada00f76df533b54676f6e321e61bd8656ae54be326c10296"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d80a6b18a6c3b6ed25b71b05eba183f37d9bc8b16ace9e3d700997f00b74660b"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a7b6e63194c68bca8e71f81de30cfa6f58ff70393cf45aab4c20f158227d5936"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a759d33a20c72f2dfa54dae6e85e1225b8e302e8ac655773aff22e542a300985"}, + {file = "ruff-0.3.5-py3-none-win32.whl", hash = "sha256:9d8605aa990045517c911726d21293ef4baa64f87265896e491a05461cae078d"}, + {file = "ruff-0.3.5-py3-none-win_amd64.whl", hash = "sha256:dc56bb16a63c1303bd47563c60482a1512721053d93231cf7e9e1c6954395a0e"}, + {file = "ruff-0.3.5-py3-none-win_arm64.whl", hash = "sha256:faeeae9905446b975dcf6d4499dc93439b131f1443ee264055c5716dd947af55"}, + {file = "ruff-0.3.5.tar.gz", hash = "sha256:a067daaeb1dc2baf9b82a32dae67d154d95212080c80435eb052d95da647763d"}, +] + [[package]] name = "safetensors" version = "0.4.2" @@ -4102,6 +4139,26 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.25.1" +description = "Virtual Python Environment 
builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"}, + {file = "virtualenv-20.25.1.tar.gz", hash = "sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "wandb" version = "0.16.5" @@ -4378,4 +4435,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.9, <=3.12" -content-hash = "32b85b0200f6dab57cb9f4c57fb47cd28bbf41e6077f88428331b18239e0f7e1" +content-hash = "e3508701677147be5bbb94b2ce7e9196ce5ebb36bb2ba580b2db4fdde531ce7d" diff --git a/pyproject.toml b/pyproject.toml index 76e72a6..6c2d233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,31 @@ shellingham = "^1.5.4" [tool.poetry.group.dev.dependencies] -black = "^24.3.0" +pre-commit = "~3.7.0" +ruff = "~0.3.5" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + + +[tool.ruff] +lint.ignore = ["C901", "E501", "E741", "F402", "F823" ] +lint.select = ["C", "E", "F", "I", "W"] +line-length = 119 +exclude = [ + "llama2", + "mistral", +] + + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["llmtune"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + diff --git a/src/data/dataset_generator.py b/src/data/dataset_generator.py index 7a05c0d..51f6a5b 100644 --- a/src/data/dataset_generator.py +++ b/src/data/dataset_generator.py @@ -1,14 +1,11 @@ import os -from os.path import join, exists +import pickle +import re from functools import partial +from os.path import exists, join from typing import Tuple, Union -import pickle -import re from datasets import Dataset -from rich.console import Console -from rich.layout import Layout -from rich.panel import Panel from src.data.ingestor import Ingestor, get_ingestor @@ -64,12 +61,8 @@ def _format_one_prompt(self, example, is_test: bool = False): return example def _format_prompts(self): - self.dataset["train"] = self.dataset["train"].map( - partial(self._format_one_prompt, is_test=False) - ) - self.dataset["test"] = self.dataset["test"].map( - partial(self._format_one_prompt, is_test=True) - ) + self.dataset["train"] = self.dataset["train"].map(partial(self._format_one_prompt, is_test=False)) + self.dataset["test"] = self.dataset["test"].map(partial(self._format_one_prompt, is_test=True)) def get_dataset(self) -> Tuple[Dataset, Dataset]: self._train_test_split() diff --git a/src/data/ingestor.py b/src/data/ingestor.py index 227e4d7..3f06c33 100644 --- a/src/data/ingestor.py +++ b/src/data/ingestor.py @@ -1,9 +1,8 @@ +import csv from abc import ABC, abstractmethod -from functools import partial import ijson -import csv -from datasets import Dataset, load_dataset, concatenate_datasets +from datasets import Dataset, concatenate_datasets, load_dataset def 
get_ingestor(data_type: str): @@ -14,9 +13,7 @@ def get_ingestor(data_type: str): elif data_type == "huggingface": return HuggingfaceIngestor else: - raise ValueError( - f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}" - ) + raise ValueError(f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}") class Ingestor(ABC): diff --git a/src/finetune/finetune.py b/src/finetune/finetune.py index 77e053f..d5b122a 100644 --- a/src/finetune/finetune.py +++ b/src/finetune/finetune.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Union, List, Tuple, Dict class Finetune(ABC): diff --git a/src/finetune/lora.py b/src/finetune/lora.py index c980bdf..2fd766a 100644 --- a/src/finetune/lora.py +++ b/src/finetune/lora.py @@ -1,32 +1,26 @@ -from os.path import join, exists -from typing import Tuple - -import torch +from os.path import join import bitsandbytes as bnb +import torch from datasets import Dataset -from accelerate import Accelerator +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_kbit_training, +) from transformers import ( - AutoTokenizer, AutoModelForCausalLM, - BitsAndBytesConfig, - TrainingArguments, AutoTokenizer, + BitsAndBytesConfig, ProgressCallback, -) -from peft import ( - prepare_model_for_kbit_training, - get_peft_model, - LoraConfig, + TrainingArguments, ) from trl import SFTTrainer -from rich.console import Console - -from src.pydantic_models.config_model import Config -from src.utils.save_utils import DirectoryHelper from src.finetune.finetune import Finetune +from src.pydantic_models.config_model import Config from src.ui.rich_ui import RichUI +from src.utils.save_utils import DirectoryHelper class LoRAFinetune(Finetune): @@ -100,9 +94,7 @@ def _inject_lora(self): self.model = get_peft_model(self.model, self._lora_config) if not self.config.accelerate: - self.optimizer = bnb.optim.Adam8bit( - self.model.parameters(), lr=self._training_args.learning_rate - ) + self.optimizer = bnb.optim.Adam8bit(self.model.parameters(), lr=self._training_args.learning_rate) self.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer) if self.config.accelerate: self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( @@ -133,7 +125,7 @@ def finetune(self, train_dataset: Dataset): **self._sft_args.model_dump(), ) - trainer_stats = self._trainer.train() + self._trainer.train() def save_model(self) -> None: self._trainer.model.save_pretrained(self._weights_path) diff --git a/src/inference/inference.py b/src/inference/inference.py index be23a80..24a2bfb 100644 --- a/src/inference/inference.py +++ b/src/inference/inference.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Union, List, Tuple, Dict class Inference(ABC): diff --git a/src/inference/lora.py b/src/inference/lora.py index 4e2a631..fee42b9 100644 --- a/src/inference/lora.py +++ b/src/inference/lora.py @@ -1,23 +1,18 @@ +import csv import os from os.path import join from threading import Thread -import csv -from transformers import TextIteratorStreamer -from rich.console import Console -from rich.table import Table -from rich.live import Live -from rich.text import Text +import torch from datasets import Dataset -from transformers import AutoTokenizer, BitsAndBytesConfig from peft import AutoPeftModelForCausalLM -import torch - +from rich.text import Text +from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer -from src.pydantic_models.config_model import Config -from 
src.utils.save_utils import DirectoryHelper from src.inference.inference import Inference +from src.pydantic_models.config_model import Config from src.ui.rich_ui import RichUI +from src.utils.save_utils import DirectoryHelper # TODO: Add type hints please! @@ -38,9 +33,7 @@ def __init__( self.device_map = self.config.model.device_map self._weights_path = dir_helper.save_paths.weights - self.model, self.tokenizer = self._get_merged_model( - dir_helper.save_paths.weights - ) + self.model, self.tokenizer = self._get_merged_model(dir_helper.save_paths.weights) def _get_merged_model(self, weights_path: str): # purge VRAM @@ -50,20 +43,14 @@ def _get_merged_model(self, weights_path: str): dtype = ( torch.float16 if self.config.training.training_args.fp16 - else ( - torch.bfloat16 - if self.config.training.training_args.bf16 - else torch.float32 - ) + else (torch.bfloat16 if self.config.training.training_args.bf16 else torch.float32) ) self.model = AutoPeftModelForCausalLM.from_pretrained( weights_path, torch_dtype=dtype, device_map=self.device_map, - quantization_config=( - BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump()) - ), + quantization_config=(BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump())), ) """TODO: figure out multi-gpu @@ -73,9 +60,7 @@ def _get_merged_model(self, weights_path: str): model = self.model.merge_and_unload() - tokenizer = AutoTokenizer.from_pretrained( - self._weights_path, device_map=self.device_map - ) + tokenizer = AutoTokenizer.from_pretrained(self._weights_path, device_map=self.device_map) return model, tokenizer @@ -86,13 +71,11 @@ def infer_all(self): # inference loop for idx, (prompt, label) in enumerate(zip(prompts, labels)): - RichUI.inference_ground_truth_display( - f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label - ) + RichUI.inference_ground_truth_display(f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label) try: result = self.infer_one(prompt) - except: + except Exception: continue results.append((prompt, label, result)) @@ -106,9 +89,7 @@ def infer_all(self): writer.writerow(row) def infer_one(self, prompt: str) -> str: - input_ids = self.tokenizer( - prompt, return_tensors="pt", truncation=True - ).input_ids.cuda() + input_ids = self.tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda() # stream processor streamer = TextIteratorStreamer( @@ -118,9 +99,7 @@ def infer_one(self, prompt: str) -> str: timeout=60, # 60 sec timeout for generation; to handle OOM errors ) - generation_kwargs = dict( - input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump() - ) + generation_kwargs = dict(input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump()) thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() diff --git a/src/pydantic_models/config_model.py b/src/pydantic_models/config_model.py index e2f5617..755c60c 100644 --- a/src/pydantic_models/config_model.py +++ b/src/pydantic_models/config_model.py @@ -1,27 +1,21 @@ -from typing import Literal, Union, List, Dict, Optional -from pydantic import BaseModel, FilePath, validator, Field - -from huggingface_hub.utils import validate_repo_id +from typing import List, Literal, Optional, Union import torch +from pydantic import BaseModel, Field, FilePath, validator + # TODO: Refactor this into multiple files... 
HfModelPath = str + class QaConfig(BaseModel): - llm_tests: Optional[List[str]] = Field([], description = "list of tests that needs to be connected") - + llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected") + class DataConfig(BaseModel): - file_type: Literal["json", "csv", "huggingface"] = Field( - None, description="File type" - ) - path: Union[FilePath, HfModelPath] = Field( - None, description="Path to the file or HuggingFace model" - ) - prompt: str = Field( - None, description="Prompt for the model. Use {} brackets for column name" - ) + file_type: Literal["json", "csv", "huggingface"] = Field(None, description="File type") + path: Union[FilePath, HfModelPath] = Field(None, description="Path to the file or HuggingFace model") + prompt: str = Field(None, description="Prompt for the model. Use {} brackets for column name") prompt_stub: str = Field( None, description="Stub for the prompt; this is injected during training. Use {} brackets for column name", @@ -48,9 +42,7 @@ class DataConfig(BaseModel): class BitsAndBytesConfig(BaseModel): - load_in_8bit: Optional[bool] = Field( - False, description="Enable 8-bit quantization with LLM.int8()" - ) + load_in_8bit: Optional[bool] = Field(False, description="Enable 8-bit quantization with LLM.int8()") llm_int8_threshold: Optional[float] = Field( 6.0, description="Outlier threshold for outlier detection in 8-bit quantization" ) @@ -61,9 +53,7 @@ class BitsAndBytesConfig(BaseModel): False, description="Enable splitting model parts between int8 on GPU and fp32 on CPU", ) - llm_int8_has_fp16_weight: Optional[bool] = Field( - False, description="Run LLM.int8() with 16-bit main weights" - ) + llm_int8_has_fp16_weight: Optional[bool] = Field(False, description="Run LLM.int8() with 16-bit main weights") load_in_4bit: Optional[bool] = Field( True, @@ -86,14 +76,10 @@ class ModelConfig(BaseModel): "NousResearch/Llama-2-7b-hf", description="Path to the model (huggingface repo or local path)", ) - device_map: Optional[str] = Field( - "auto", description="device onto which to load the model" - ) + device_map: Optional[str] = Field("auto", description="device onto which to load the model") quantize: Optional[bool] = Field(False, description="Flag to enable quantization") - bitsandbytes: BitsAndBytesConfig = Field( - None, description="Bits and Bytes configuration" - ) + bitsandbytes: BitsAndBytesConfig = Field(None, description="Bits and Bytes configuration") # @validator("hf_model_ckpt") # def validate_model(cls, v, **kwargs): @@ -116,22 +102,12 @@ def set_device_map_to_none(cls, v, values, **kwargs): class LoraConfig(BaseModel): r: Optional[int] = Field(8, description="Lora rank") - task_type: Optional[str] = Field( - "CAUSAL_LM", description="Base Model task type during training" - ) + task_type: Optional[str] = Field("CAUSAL_LM", description="Base Model task type during training") - lora_alpha: Optional[int] = Field( - 16, description="The alpha parameter for Lora scaling" - ) - bias: Optional[str] = Field( - "none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'" - ) - lora_dropout: Optional[float] = Field( - 0.1, description="The dropout probability for Lora layers" - ) - target_modules: Optional[List[str]] = Field( - None, description="The names of the modules to apply Lora to" - ) + lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling") + bias: Optional[str] = Field("none", description="Bias type for Lora. 
Can be 'none', 'all' or 'lora_only'") + lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers") + target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to") fan_in_fan_out: Optional[bool] = Field( False, description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)", @@ -140,9 +116,7 @@ class LoraConfig(BaseModel): None, description="List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint", ) - layers_to_transform: Optional[Union[List[int], int]] = Field( - None, description="The layer indexes to transform" - ) + layers_to_transform: Optional[Union[List[int], int]] = Field(None, description="The layer indexes to transform") layers_pattern: Optional[str] = Field(None, description="The layer pattern name") # rank_pattern: Optional[Dict[str, int]] = Field( # {}, description="The mapping from layer names or regexp expression to ranks" @@ -155,15 +129,9 @@ class LoraConfig(BaseModel): # TODO: Get comprehensive Args! class TrainingArgs(BaseModel): num_train_epochs: Optional[int] = Field(1, description="Number of training epochs") - per_device_train_batch_size: Optional[int] = Field( - 1, description="Batch size per training device" - ) - gradient_accumulation_steps: Optional[int] = Field( - 1, description="Number of steps for gradient accumulation" - ) - gradient_checkpointing: Optional[bool] = Field( - True, description="Flag to enable gradient checkpointing" - ) + per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device") + gradient_accumulation_steps: Optional[int] = Field(1, description="Number of steps for gradient accumulation") + gradient_checkpointing: Optional[bool] = Field(True, description="Flag to enable gradient checkpointing") optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer") logging_steps: Optional[int] = Field(100, description="Number of logging steps") learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate") @@ -172,9 +140,7 @@ class TrainingArgs(BaseModel): fp16: Optional[bool] = Field(False, description="Flag to enable fp16") max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm") warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio") - lr_scheduler_type: Optional[str] = Field( - "constant", description="Learning rate scheduler type" - ) + lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type") # TODO: Get comprehensive Args! 
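Reviewer note on the models above: these pydantic classes are what the launcher validates the YAML config against. A minimal sketch of that round trip, assuming the top-level `Config` model accepts the parsed dict as keyword arguments (the exact call is not shown in this diff; see `toolkit.py`'s `run()` below):

```python
import yaml
from pydantic import ValidationError

from src.pydantic_models.config_model import Config

# Parse the YAML config, then validate it against the pydantic schema;
# field defaults (e.g. LoraConfig.r == 8) fill in any keys the user omits.
with open("./config.yml") as f:
    raw = yaml.safe_load(f)

try:
    config = Config(**raw)  # assumed constructor form
except ValidationError as err:
    # pydantic reports every failing field at once, e.g. a `file_type`
    # outside the Literal["json", "csv", "huggingface"] choices.
    print(err)
```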
diff --git a/src/qa/qa.py b/src/qa/qa.py index f39c986..a7a7169 100644 --- a/src/qa/qa.py +++ b/src/qa/qa.py @@ -1,9 +1,10 @@ +import statistics from abc import ABC, abstractmethod -from typing import Union, List, Tuple, Dict +from typing import Dict, List, Union + import pandas as pd -from src.ui.rich_ui import RichUI -import statistics -from src.qa.qa_tests import * + +from src.ui.rich_ui import RichUI class LLMQaTest(ABC): @@ -13,11 +14,10 @@ def test_name(self) -> str: pass @abstractmethod - def get_metric( - self, prompt: str, grount_truth: str, model_pred: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_pred: str) -> Union[float, int, bool]: pass + class QaTestRegistry: registry = {} @@ -27,19 +27,22 @@ def inner_wrapper(wrapped_class): for name in names: cls.registry[name] = wrapped_class return wrapped_class + return inner_wrapper - @classmethod - def create_tests_from_list(cls, test_name: str) -> List[LLMQaTest]: + @classmethod + def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]: return [cls.create_test(test) for test in test_names] -class LLMTestSuite(): - def __init__(self, - tests:List[LLMQaTest], - prompts:List[str], - ground_truths:List[str], - model_preds:List[str]) -> None: +class LLMTestSuite: + def __init__( + self, + tests: List[LLMQaTest], + prompts: List[str], + ground_truths: List[str], + model_preds: List[str], + ) -> None: self.tests = tests self.prompts = prompts self.ground_truths = ground_truths @@ -51,9 +54,7 @@ def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]: test_results = {} for test in zip(self.tests): metrics = [] - for prompt, ground_truth, model_pred in zip( - self.prompts, self.ground_truths, self.model_preds - ): + for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds): metrics.append(test.get_metric(prompt, ground_truth, model_pred)) test_results[test.test_name] = metrics @@ -66,19 +67,12 @@ def test_results(self): def print_test_results(self): result_dictionary = self.test_results() - column_data = { - key: [value for value in result_dictionary[key]] - for key in result_dictionary - } + column_data = {key: list(result_dictionary[key]) for key in result_dictionary} mean_values = {key: statistics.mean(column_data[key]) for key in column_data} - median_values = { - key: statistics.median(column_data[key]) for key in column_data - } + median_values = {key: statistics.median(column_data[key]) for key in column_data} stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data} # Use the RichUI class to display the table - RichUI.display_table( - result_dictionary, mean_values, median_values, stdev_values - ) + RichUI.display_table(result_dictionary, mean_values, median_values, stdev_values) def save_test_results(self, path: str): # TODO: save these!
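A note on the registry pattern this file sets up: `QaTestRegistry.register(*names)` is a decorator factory that maps string names to test classes, so configs can request tests by name via `create_tests_from_list`. A hedged sketch of registering a custom test against the `LLMQaTest` ABC; the `ExactMatchTest` class is illustrative and not part of this PR, and note that `qa_tests.py` currently imports the registry as `TestRegistry` while this file names it `QaTestRegistry`, so use whichever name the code finally settles on:

```python
from typing import Union

from src.qa.qa import LLMQaTest, QaTestRegistry


@QaTestRegistry.register("exact_match")  # key used to look the test up by name
class ExactMatchTest(LLMQaTest):
    @property
    def test_name(self) -> str:
        return "exact_match"

    def get_metric(self, prompt: str, ground_truth: str, model_pred: str) -> Union[float, int, bool]:
        # 1.0 when the prediction matches the reference exactly, else 0.0
        return float(ground_truth.strip() == model_pred.strip())
```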
diff --git a/src/qa/qa_tests.py b/src/qa/qa_tests.py index ce2e59f..df316f4 100644 --- a/src/qa/qa_tests.py +++ b/src/qa/qa_tests.py @@ -1,14 +1,16 @@ -from src.qa.qa import LLMQaTest -from typing import Union, List, Tuple, Dict -import torch -from transformers import DistilBertModel, DistilBertTokenizer +from typing import List, Union + import nltk import numpy as np -from rouge_score import rouge_scorer +import torch +from nltk import pos_tag from nltk.corpus import stopwords from nltk.tokenize import word_tokenize -from nltk import pos_tag -from src.qa.qa import TestRegistry +from rouge_score import rouge_scorer +from transformers import DistilBertModel, DistilBertTokenizer + +from src.qa.qa import LLMQaTest, TestRegistry + model_name = "distilbert-base-uncased" tokenizer = DistilBertTokenizer.from_pretrained(model_name) @@ -18,26 +20,24 @@ nltk.download("punkt") nltk.download("averaged_perceptron_tagger") + @TestRegistry.register("summary_length") class LengthTest(LLMQaTest): @property def test_name(self) -> str: return "summary_length" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: return abs(len(ground_truth) - len(model_prediction)) + @TestRegistry.register("jaccard_similarity") class JaccardSimilarityTest(LLMQaTest): @property def test_name(self) -> str: return "jaccard_similarity" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: set_ground_truth = set(ground_truth.lower()) set_model_prediction = set(model_prediction.lower()) @@ -47,6 +47,7 @@ def get_metric( similarity = intersection_size / union_size if union_size != 0 else 0 return similarity + @TestRegistry.register("dot_product") class DotProductSimilarityTest(LLMQaTest): @property @@ -59,29 +60,25 @@ def _encode_sentence(self, sentence): outputs = model(**tokens) return outputs.last_hidden_state.mean(dim=1).squeeze().numpy() - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: embedding_ground_truth = self._encode_sentence(ground_truth) embedding_model_prediction = self._encode_sentence(model_prediction) - dot_product_similarity = np.dot( - embedding_ground_truth, embedding_model_prediction - ) + dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction) return dot_product_similarity + @TestRegistry.register("rouge_score") class RougeScoreTest(LLMQaTest): @property def test_name(self) -> str: return "rouge_score" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) scores = scorer.score(model_prediction, ground_truth) return float(scores["rouge1"].precision) + @TestRegistry.register("word_overlap") class WordOverlapTest(LLMQaTest): @property @@ -94,9 +91,7 @@ def _remove_stopwords(self, text: str) -> str: filtered_text = [word for word in word_tokens if word.lower() not in stop_words] return " ".join(filtered_text) - def get_metric( - self, prompt: str, ground_truth: str, 
model_prediction: str - ) -> Union[float, int, bool]: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: cleaned_model_prediction = self._remove_stopwords(model_prediction) cleaned_ground_truth = self._remove_stopwords(ground_truth) @@ -116,18 +111,16 @@ def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float: total_words = len(text.split(" ")) return round(len(pos_words) / total_words, 2) + @TestRegistry.register("verb_percent") class VerbPercent(PosCompositionTest): @property def test_name(self) -> str: return "verb_percent" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> float: - return self._get_pos_percent( - model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"] - ) + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: + return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"]) + @TestRegistry.register("adjective_percent") class AdjectivePercent(PosCompositionTest): @@ -135,20 +128,17 @@ class AdjectivePercent(PosCompositionTest): def test_name(self) -> str: return "adjective_percent" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> float: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"]) + @TestRegistry.register("noun_percent") class NounPercent(PosCompositionTest): @property def test_name(self) -> str: return "noun_percent" - def get_metric( - self, prompt: str, ground_truth: str, model_prediction: str - ) -> float: + def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"]) diff --git a/src/ui/rich_ui.py b/src/ui/rich_ui.py index 10ab6ed..902d9ac 100644 --- a/src/ui/rich_ui.py +++ b/src/ui/rich_ui.py @@ -1,15 +1,15 @@ from datasets import Dataset - from rich.console import Console from rich.layout import Layout +from rich.live import Live from rich.panel import Panel from rich.table import Table -from rich.live import Live from rich.text import Text from src.ui.ui import UI from src.utils.rich_print_utils import inject_example_to_rich_layout + console = Console() @@ -25,9 +25,7 @@ def __enter__(self): return self # This allows you to use variables from this context if needed def __exit__(self, exc_type, exc_val, exc_tb): - self.task.__exit__( - exc_type, exc_val, exc_tb - ) # Cleanly exit the console status context + self.task.__exit__(exc_type, exc_val, exc_tb) # Cleanly exit the console status context class LiveContext: @@ -47,9 +45,7 @@ def __enter__(self): return self # This allows you to use variables from this context if needed def __exit__(self, exc_type, exc_val, exc_tb): - self.task.__exit__( - exc_type, exc_val, exc_tb - ) # Cleanly exit the console status context + self.task.__exit__(exc_type, exc_val, exc_tb) # Cleanly exit the console status context def update(self, new_text: Text): self.task.update(new_text) @@ -72,7 +68,7 @@ def during_dataset_creation(message: str, spinner: str): @staticmethod def after_dataset_creation(save_dir: str, train: Dataset, test: Dataset): console.print(f"Dataset Saved at {save_dir}") - console.print(f"Post-Split data size:") + console.print("Post-Split data size:") console.print(f"Train: {len(train)}") console.print(f"Test: {len(test)}") @@ -93,9 +89,7 @@ def dataset_display_one_example(train_row: dict, test_row: 
dict): ) inject_example_to_rich_layout(layout["train"], "Train Example", train_row) - inject_example_to_rich_layout( - layout["inference"], "Inference Example", test_row - ) + inject_example_to_rich_layout(layout["inference"], "Inference Example", test_row) console.print(layout) @@ -122,14 +116,14 @@ def during_finetune(): @staticmethod def after_finetune(): - console.print(f"Finetuning complete!") + console.print("Finetuning complete!") @staticmethod def finetune_found(weights_path: str): console.print(f"Fine-Tuned Model Found at {weights_path}... skipping training") """ - INFERENCE + INFERENCE """ # Lifecycle functions @@ -167,7 +161,7 @@ def inference_stream_display(text: Text): return LiveContext(text) """ - QA + QA """ # Lifecycle functions @@ -188,10 +182,7 @@ def qa_found(): pass @staticmethod - def qa_display_table( - self, result_dictionary, mean_values, median_values, stdev_values - ): - + def qa_display_table(self, result_dictionary, mean_values, median_values, stdev_values): # Create a table table = Table(show_header=True, header_style="bold", title="Test Results") diff --git a/src/ui/ui.py b/src/ui/ui.py index 59d5997..87d8fee 100644 --- a/src/ui/ui.py +++ b/src/ui/ui.py @@ -61,7 +61,7 @@ def finetune_found(weights_path: str): pass """ - INFERENCE + INFERENCE """ # Lifecycle functions @@ -91,7 +91,7 @@ def inference_stream_display(text: Text): pass """ - QA + QA """ # Lifecycle functions diff --git a/src/utils/ablation_utils.py b/src/utils/ablation_utils.py index 37d9d80..062892e 100644 --- a/src/utils/ablation_utils.py +++ b/src/utils/ablation_utils.py @@ -1,9 +1,6 @@ import copy import itertools -from typing import List, Type, Any, Dict, Optional, Union, Tuple -from typing import get_args, get_origin, get_type_hints - -import yaml +from typing import Dict, Tuple, Union, get_args, get_origin # TODO: organize this a little bit. It's a bit of a mess rn. @@ -14,17 +11,11 @@ """ -def get_types_from_dict( - source_dict: dict, root="", type_dict={} -) -> Dict[str, Tuple[type, type]]: +def get_types_from_dict(source_dict: dict, root="", type_dict={}) -> Dict[str, Tuple[type, type]]: for key, val in source_dict.items(): - if type(val) is not dict: + if not isinstance(val, dict): attr = f"{root}.{key}" if root else key - tp = ( - (type(val), None) - if type(val) is not list - else (type(val), type(val[0])) - ) + tp = (type(val), None) if not isinstance(val, list) else (type(val), type(val[0])) type_dict[attr] = tp else: join_array = [root, key] if root else [key] diff --git a/src/utils/rich_print_utils.py b/src/utils/rich_print_utils.py index 371f742..d39cc4c 100644 --- a/src/utils/rich_print_utils.py +++ b/src/utils/rich_print_utils.py @@ -1,7 +1,7 @@ -from rich.panel import Panel from rich.layout import Layout -from rich.text import Text +from rich.panel import Panel from rich.table import Table +from rich.text import Text def inject_example_to_rich_layout(layout: Layout, layout_name: str, example: dict): diff --git a/src/utils/save_utils.py b/src/utils/save_utils.py index 493b3c9..e84ed09 100644 --- a/src/utils/save_utils.py +++ b/src/utils/save_utils.py @@ -4,20 +4,19 @@ 2. 
Check if files are present at various experiment stages """ -import shutil +import hashlib import os -from os.path import exists -import yaml - import re -import hashlib -from functools import cached_property from dataclasses import dataclass +from functools import cached_property +from os.path import exists +import yaml from sqids import Sqids from src.pydantic_models.config_model import Config + NUM_MD5_DIGITS_FOR_SQIDS = 5 # TODO: maybe move consts to a dedicated folder diff --git a/toolkit.py b/toolkit.py index 830fc14..4a46798 100644 --- a/toolkit.py +++ b/toolkit.py @@ -1,20 +1,21 @@ -from os import listdir -from os.path import join, exists -import yaml import logging +from os import listdir +from os.path import exists, join -from transformers import utils as hf_utils -from pydantic import ValidationError import torch import typer +import yaml +from pydantic import ValidationError +from transformers import utils as hf_utils -from src.pydantic_models.config_model import Config from src.data.dataset_generator import DatasetGenerator -from src.utils.save_utils import DirectoryHelper -from src.utils.ablation_utils import generate_permutations from src.finetune.lora import LoRAFinetune from src.inference.lora import LoRAInference +from src.pydantic_models.config_model import Config from src.ui.rich_ui import RichUI +from src.utils.ablation_utils import generate_permutations +from src.utils.save_utils import DirectoryHelper + hf_utils.logging.set_verbosity_error() torch._logging.set_logs(all=logging.CRITICAL) @@ -32,7 +33,7 @@ def run_one_experiment(config: Config, config_path: str) -> None: with RichUI.during_dataset_creation("Injecting Values into Prompt", "monkey"): dataset_generator = DatasetGenerator(**config.data.model_dump()) - train_columns = dataset_generator.train_columns + _ = dataset_generator.train_columns test_column = dataset_generator.test_column dataset_path = dir_helper.save_paths.dataset @@ -66,9 +67,8 @@ def run_one_experiment(config: Config, config_path: str) -> None: results_path = dir_helper.save_paths.results results_file_path = join(dir_helper.save_paths.results, "results.csv") if not exists(results_path) or exists(results_file_path): - inference_runner = LoRAInference( - test, test_column, config, dir_helper - ).infer_all() + inference_runner = LoRAInference(test, test_column, config, dir_helper) + inference_runner.infer_all() RichUI.after_inference(results_path) else: RichUI.inference_found(results_path) @@ -90,9 +90,7 @@ def run(config_path: str = "./config.yml") -> None: with open(config_path, "r") as file: config = yaml.safe_load(file) configs = ( - generate_permutations(config, Config) - if config.get("ablation", {}).get("use_ablate", False) - else [config] + generate_permutations(config, Config) if config.get("ablation", {}).get("use_ablate", False) else [config] ) for config in configs: # validate data with pydantic
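For context on where this truncated hunk is headed: `run()` expands an ablation config into one run per permutation and validates each against the pydantic schema before launching. A condensed sketch of that control flow, assuming `Config(**cfg)` is the validation step the diff cuts off:

```python
import yaml

from src.pydantic_models.config_model import Config
from src.utils.ablation_utils import generate_permutations

with open("./config.yml") as file:
    config = yaml.safe_load(file)

# Mirrors run(): fan the config out into one dict per ablation permutation
# when ablation.use_ablate is set; otherwise run the single config as-is.
configs = (
    generate_permutations(config, Config) if config.get("ablation", {}).get("use_ablate", False) else [config]
)

for cfg in configs:
    validated = Config(**cfg)  # assumed: each permutation is validated before launch
```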