diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 65af62d..676b7d2 100755
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -10,10 +10,25 @@

 	// Features to add to the dev container. More info: https://containers.dev/features.
 	"features": {
-		"ghcr.io/devcontainers/features/common-utils:2": {},
-		"ghcr.io/devcontainers-contrib/features/poetry:2": {},
-		"ghcr.io/devcontainers-contrib/features/nox:2": {},
-		"ghcr.io/devcontainers-contrib/features/pre-commit:2": {}
+		"ghcr.io/devcontainers/features/common-utils:2": {
+			"installZsh": true,
+			"installOhMyZsh": true,
+			"installOhMyZshConfig": true,
+			"upgradePackages": true,
+			"username": "devcontainer"
+		},
+		"ghcr.io/devcontainers-contrib/features/poetry:2": {
+			"version": "latest"
+		},
+		"ghcr.io/devcontainers-contrib/features/nox:2": {
+			"version": "latest"
+		},
+		"ghcr.io/devcontainers-contrib/features/pre-commit:2": {
+			"version": "latest"
+		},
+		"ghcr.io/devcontainers-contrib/features/mypy:2": {
+			"version": "latest"
+		}
 	},

 	// Use 'forwardPorts' to make a list of ports inside the container available locally.
diff --git a/noxfile.py b/noxfile.py
index ece230c..b3a4112 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -40,10 +40,18 @@ def ruff(session):
 @nox.session
 def test(session):
     # Not certain this is a good approach. But it currently works.
-    # session.install("pytest")
-    # session.install("pytest-cov")
-
     session.run("pytest", "--cov=quke", "tests/")
-    # test_files = session.posargs if session.posargs else []
-    # session.run("pytest", "--cov=quke", *test_files)
+    # TODO: test_files = session.posargs if session.posargs else []
+    # TODO: session.run("pytest", "--cov=quke", *test_files)
+
+
+@nox.session
+def mypy(session):
+    session.install("mypy")
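+    # Run mypy against the Poetry virtualenv interpreter so the project's
+    # third-party dependencies resolve. NOTE: the hash-named path below is
+    # machine-specific and breaks if the virtualenv is recreated.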
authors = ["Erik Oosterop"] maintainers = ["Erik Oosterop"] @@ -36,14 +36,10 @@ replicate = "^0.9.0" mdutils = "^1.6.0" rich = "^13.5.2" - -[tool.poetry.dev-dependencies] -pytest = "^7.4.0" -requests-mock = "^1.11.0" - - [tool.poetry.group.dev.dependencies] pytest-cov = "^4.1.0" +pytest = "^7.4.0" +requests-mock = "^1.11.0" [build-system] requires = ["poetry-core"] @@ -61,22 +57,45 @@ line-length = 119 select = [ # https://beta.ruff.rs/docs/rules/ "A", # prevent using keywords that clobber python builtins "ANN", # type annotation + "ARG", # unused arguments "B", # bugbear: security warnings + # "BLE", # blind exceptions "C", - "C90", + # "COM", # commas + "C4", # comprehension + "C90", # McCabe complexity "D", # pydocstyle # "DAR", # darglint, but does not seem to be implemented at the moment "DTZ", # date timezone "E", # pycodestyle + "EM", # error messages + "ERA", # eradicate + "EXE", # executables "F", # pyflakes + "FLY", # f-strings + # "G", # logging format (no f-string) "I", # isort + "ICN", # import conventions + "INT", # gettext "ISC", # implicit string concatenation + "N", # pep8 naming + "PERF", # performance lint + "PIE", # "PT", # pytest style "PTH", # use pathlib - "Q", + "Q", # quotes + "RET", # return values + "RSE", # error parenthesis + "RUF", # ruff rules "S", # Bandit "SIM", # simplify + "TCH", # type checking + # "TD", # TODO + "TID", # tidy imports + "TRY", # tryceratops + "T20", # print statement "UP", # alert you when better syntax is available in your python version + "W", # pycodestyle warnings "RUF", # the ruff developer's own rules ] @@ -98,7 +117,7 @@ fixable = [ [tool.ruff.per-file-ignores] "tests/**/*.py" = [ - # at least this three should be fine in tests: + # at least these three should be fine in tests: "S101", # asserts allowed in tests... "ANN", # TODO: do not care about type annotations in tests for now "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... @@ -107,6 +126,7 @@ fixable = [ "PLR2004", # Magic value used in comparison, ... "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes "D", # no pydocstyle + "N", # Argument, function to lowercase ] "noxfile.py" = [ @@ -115,4 +135,7 @@ fixable = [ ] [tool.ruff.pydocstyle] -convention = "google" \ No newline at end of file +convention = "google" + +[tool.mypy] +disallow_incomplete_defs = true diff --git a/quke/embed.py b/quke/embed.py index b4608fe..41d8432 100644 --- a/quke/embed.py +++ b/quke/embed.py @@ -7,6 +7,7 @@ from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path +from typing import Iterator # [ ] TODO: PyMU is faster, PyPDF more accurate: https://github.com/py-pdf/benchmarks from langchain.document_loaders import CSVLoader, PyMuPDFLoader, TextLoader @@ -20,7 +21,10 @@ class DocumentLoaderDef: ext: str = "pdf" loader: object = PyMuPDFLoader - kwargs: defaultdict[dict] = field(default_factory=dict) # empty dict + # TODO: Remove this - kwargs: defaultdict[dict] = field(default_factory=dict) # empty dict + kwargs: dict[str, str] = field( + default_factory=lambda: defaultdict(dict) + ) # empty dict DOC_LOADERS = [ @@ -41,6 +45,9 @@ def get_loaders(src_doc_folder: str, loader: DocumentLoaderDef) -> list: src_doc_folder: The folder of the source files. loader: Definition of the loader. Loaders exist for example for pdf, text and csv files. + + Returns: + A list of loaders to be used to read the text from source documents. 
""" ext = loader.ext @@ -51,15 +58,20 @@ def get_loaders(src_doc_folder: str, loader: DocumentLoaderDef) -> list: # TODO: Problem with embedding more than 2 files at once, or some number of pages/chunks (using HF)? # Error message does not really help. Appending in steps does work. - loaders = [ + return [ loader.loader(str(pdf_name), **loader.kwargs) for pdf_name in src_file_names ] - return loaders - def get_pages_from_document(src_doc_folder: str) -> list: - """Reads documents from the directory/folder provided and returns a list of pages and metadata.""" + """Reads documents from the directory/folder provided and returns a list of pages and metadata. + + Args: + src_doc_folder: Folder containing the source documents. + + Returns: + List containing one page per list item, as text. + """ pages = [] for docloader in DOC_LOADERS: for loader in get_loaders(src_doc_folder, docloader): @@ -79,9 +91,20 @@ def get_pages_from_document(src_doc_folder: str) -> list: def get_chunks_from_pages(pages: list, splitter_params: dict) -> list: - """Splits pages into smaller chunks used for embedding.""" - # for splitter args containing 'func', the yaml value is converted into a Python function. - # TODO: Security risk? Hence a safe_list of functions is provided; severly limiting flexibility. + """Splits pages into smaller chunks used for embedding. + + Args: + pages: List with page text of a document(s). + splitter_params: Dictionary with settings for splitting logic, having + keys splitter_args and splitter_import. + splitter_args are provided to the splitter function as **kwargs. Note that if a keyword + contains 'func' the value will be evaluated as a python function (only 'len' allowed). + + Returns: + A list of smaller text chunks from the pages. In a next step to be used for embedding. + """ + # TODO: eval() is a security risk. Hence a safe_list of functions is provided; severely + # limiting risk and flexibility. # TODO: The other limiting factor: any parameter containing 'func' is eval()-ed into a function reference; # also no other parameter is. safe_function_list = ["len"] @@ -115,7 +138,22 @@ def embed( splitter_params: dict, write_mode: DatabaseAction = DatabaseAction.NO_OVERWRITE, ) -> int: - """Reads documents from a provided directory, performs embedding and captures the embeddings in a vector store.""" + """Reads documents from a provided directory, performs embedding and captures the embeddings in a vector store. + + Args: + src_doc_folder: Folder containing the source documents. + vectordb_location (str): Folder of vector store database. + embedding_import: Definition for embedding model. + embedding_kwargs: **kwargs to be provided to embedding class. + vectordb_import: Definition of vector store. + rate_limit: Rate limiting info. Used as a basic limiter dealing with 3rd party API limits. + splitter_params: Specifications for text splitting logic. + write_mode: Wether to OVERWRITE, APPEND or NO_OVERWRITE the vector store. NO_OVERWRITE will + not embed anything if a vector store exists at the vectordb_location. + + Returns: + The number of text chunks embedded. + """ logging.info(f"Starting to embed into VectorDB: {vectordb_location}") # if folder does not exist, or write_mode is APPEND no need to do anything here. @@ -131,7 +169,7 @@ def embed( f"{vectordb_location!r}. Remove database folder, or change embedding config " "vectorstore_write_mode to OVERWRITE or APPEND." 
+        return 0
     if (
         write_mode == DatabaseAction.OVERWRITE
     ):  # remove exising database before embedding
@@ -153,7 +191,7 @@
     )

     # Use chunker to embed in chunks with a wait time in between. As a basic way to deal with some rate limiting.
-    def chunker(seq: list, size: int) -> list:
+    def chunker(seq: list, size: int) -> Iterator[list]:
         return (seq[pos : pos + size] for pos in range(0, len(seq), size))

     c = 0
@@ -181,7 +219,18 @@ def embed_these_chunks(
     embedding_kwargs: dict,
     vectordb_import: ClassImportDefinition,
 ) -> int:
-    """Embed the provided chunks and capture into a vector store."""
+    """Embed the provided chunks and capture into a vector store.
+
+    Args:
+        chunks: List of text chunks to be embedded.
+        vectordb_location: Location of the folder containing the embedding database.
+        embedding_import: Definition of the embedding model ('to build Python import statement').
+        embedding_kwargs: Dictionary provided as **kwargs for the embedding class.
+        vectordb_import: Definition of the vector store ('to build Python import statement').
+
+    Returns:
+        Number of chunks embedded and captured in the vector store.
+    """
     module = importlib.import_module(embedding_import.module_name)
     class_ = getattr(module, embedding_import.class_name)
     embedding = class_(**embedding_kwargs)
diff --git a/quke/llm_chat.py b/quke/llm_chat.py
index 1468ee6..97dd269 100644
--- a/quke/llm_chat.py
+++ b/quke/llm_chat.py
@@ -7,8 +7,8 @@
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory

-from mdutils.fileutils import MarkDownFile
-from mdutils.mdutils import MdUtils
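+# mdutils ships without type stubs, so its imports are exempted from mypy checking.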
+ """ logging.info("=======================") - logging.info(f"Q: {question}") + logging.info(f"Q: {result['question']}") logging.info(f"A: {result['answer']}") src_docs = [doc.metadata for doc in result["source_documents"]] - src_docs = dict_crosstab(src_docs, "source", "page") - for key, value in src_docs.items(): + src_docs_pages_used = dict_crosstab(src_docs, "source", "page") + for key, value in src_docs_pages_used.items(): logging.info(f"Source document: {key}, Pages used: {value}") # TODO: Either I do not understand mdutils or it is an unfriendly package when trying to append. def chat_output_to_file(result: dict, output_file: dict) -> None: - """Populates a record of the chat with the LLM into a markdown file.""" + """Populates a record of the chat with the LLM into a markdown file. + + Args: + result: dict with the answer from the LLM. Expects 'question', 'answer' and 'source' keys, + 'page' key optionally. + output_file: File name to which the record is saved. + """ first_write = not Path(output_file["path"]).is_file() - mdFile = MdUtils(file_name="tmp.md") + md_file = MdUtils(file_name="tmp.md") if first_write: - mdFile.new_header(1, "LLM Chat Session with quke") - mdFile.write( + md_file.new_header(1, "LLM Chat Session with quke") + md_file.write( datetime.now().astimezone().strftime("%a %d-%b-%Y %H:%M %Z"), align="center" ) - mdFile.new_paragraph("") - mdFile.new_header(2, "Experiment settings", header_id="settings") - mdFile.insert_code(output_file["conf_yaml"], language="yaml") - mdFile.new_header(2, "Chat", header_id="chat") + md_file.new_paragraph("") + md_file.new_header(2, "Experiment settings", header_id="settings") + md_file.insert_code(output_file["conf_yaml"], language="yaml") + md_file.new_header(2, "Chat", header_id="chat") else: existing_text = MarkDownFile().read_file(file_name=output_file["path"]) - mdFile.new_paragraph(existing_text) + md_file.new_paragraph(existing_text) - mdFile.new_paragraph(f"Q: {result['question']}") - mdFile.new_paragraph(f"A: {result['answer']}") + md_file.new_paragraph(f"Q: {result['question']}") + md_file.new_paragraph(f"A: {result['answer']}") src_docs = [doc.metadata for doc in result["source_documents"]] - src_docs = dict_crosstab(src_docs, "source", "page") - for key, value in src_docs.items(): - mdFile.new_paragraph(f"Source document: {key}, Pages used: {value}") + src_docs_pages_used = dict_crosstab(src_docs, "source", "page") + for key, value in src_docs_pages_used.items(): + md_file.new_paragraph(f"Source document: {key}, Pages used: {value}") new = MarkDownFile(name=output_file["path"]) - new.append_end((mdFile.get_md_text()).strip()) + new.append_end((md_file.get_md_text()).strip()) def dict_crosstab(source: list, key: str, listed: str, missing: str = "NA") -> dict: """Limited and simple version of a crosstab query on a dict. Args: - source (list): _description_ - key (str): _description_ - listed (str): _description_ - missing (str, optional): _description_. Defaults to "NA". + source: List of dicts. Two elements per dict will be considered: 'key' and 'listed'. + key: Every dict should contain an entry for 'key'. + listed: The key for the element in the dict considered to contain the value. + missing: Value to be used if dict has no key for 'listed'. Returns: - _type_: _description_ + A dictionary containing 'keys' and a list of values for each 'key'. 
+    dict_subs = [{key: d[key], listed: d.get(listed, missing)}.values() for d in source]
     d = defaultdict(list)
     for k, v in dict_subs:
diff --git a/quke/quke.py b/quke/quke.py
index 86c7f01..fed88dc 100644
--- a/quke/quke.py
+++ b/quke/quke.py
@@ -64,7 +64,7 @@ def __init__(self, cfg: DictConfig) -> None:
                 (cfg.embedding.vectordb.vectorstore_write_mode).upper()
             ]
         except Exception:
-            logging.warn(
+            logging.warning(
                 f"Invalid value configured for cfg.embedding.vectorstore_write_mode: "
                 f"{cfg.embedding.vectordb.vectorstore_write_mode}. Using no_overwrite instead."
             )
@@ -87,13 +87,12 @@ def __init__(self, cfg: DictConfig) -> None:
                 Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"])
                 / cfg.experiment_summary_file
             )
-
         except Exception:
             self.output_file = cfg.experiment_summary_file

     def get_embed_params(self) -> dict:
         """Based on the config files returns the set of parameters need to start embedding."""
-        embed_parameters = {
+        return {
             "src_doc_folder": self.src_doc_folder,
             "vectordb_location": self.vectordb_location,
             "embedding_import": self.embedding_import,
@@ -103,11 +102,10 @@
             "splitter_params": self.get_splitter_params(),
             "write_mode": self.write_mode,
         }
-        return embed_parameters

     def get_chat_params(self) -> dict:
         """Based on the config files returns the set of parameters need to start a chat."""
-        chat_parameters = {
+        return {
             "vectordb_location": self.vectordb_location,
             "embedding_import": self.embedding_import,
             "vectordb_import": self.vectordb_import,
@@ -116,7 +114,6 @@
             "prompt_parameters": self.questions,
             "output_file": self.get_chat_session_file_parameters(self.cfg),
         }
-        return chat_parameters

     def get_splitter_params(self) -> dict:
         """Based on the config files returns the set of parameters needed to split source documents."""
@@ -127,19 +124,20 @@

     def get_args_dict(self, cfg_sub: dict) -> dict:
         """Takes a subset of the Hydra configs and returns the same as a dict."""
-        return OmegaConf.to_container(cfg_sub, resolve=True)
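+        # to_container's return type is broader than dict (it may yield a list),
+        # so guard with isinstance to satisfy the annotated dict return type.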
+        res = OmegaConf.to_container(cfg_sub, resolve=True)
+        return res if isinstance(res, dict) else {}

     def get_llm_parameters(self) -> dict:
         """Based on the config files returns the set of parameters needed to setup an LLM."""
-        return OmegaConf.to_container(self.cfg.llm.llm_args, resolve=True)
+        res = OmegaConf.to_container(self.cfg.llm.llm_args, resolve=True)
+        return res if isinstance(res, dict) else {}

     def get_chat_session_file_parameters(self, cfg: DictConfig) -> dict:
         """Returns the full configuration in a single yaml and file location for output."""
-        chat_sesion_file_parameters = {
+        return {
             "path": self.output_file,
             "conf_yaml": OmegaConf.to_yaml(cfg),
         }
-        return chat_sesion_file_parameters

     def get_embedding_kwargs(self, cfg: DictConfig) -> dict:
         """Based on the config files returns the set of parameters needed for embedding."""
@@ -147,12 +145,6 @@
             embedding_kwargs = (
                 cfg.embedding.embedding.kwargs if cfg.embedding.embedding.kwargs else {}
             )
-            """
-            if cfg.embedding.embedding.kwargs:
-                embedding_kwargs = cfg.embedding.embedding.kwargs
-            else:
-                embedding_kwargs = {}
-            """
         except Exception:
             embedding_kwargs = {}
         return embedding_kwargs
@@ -163,6 +155,8 @@ def quke(cfg: DictConfig) -> None:
     """The main function to initiate a chat.

     Including the embedding of the provided source documents.
+
+    Questions, LLM, embedding model, vectordb are specified in config files (using Hydra).
     """
     console = Console()
     config_parser = ConfigParser(cfg)
diff --git a/tests/test_001.py b/tests/test_001.py
index 0dbcb01..e6184df 100644
--- a/tests/test_001.py
+++ b/tests/test_001.py
@@ -6,7 +6,7 @@
 from omegaconf import DictConfig

 from quke.embed import embed, get_chunks_from_pages, get_pages_from_document
-from quke.llm_chat import chat
+from quke.llm_chat import chat, dict_crosstab
 from quke.quke import ConfigParser

 OUTPUT_FILE = "chat_session.md"
@@ -18,7 +18,7 @@
 @pytest.fixture(scope="session")
 def GetConfigEmbedOnly():
     with initialize(version_base=None, config_path="./conf"):
-        cfg = compose(
+        return compose(
             config_name="config",
             overrides=[
                 "embed_only=True",
@@ -27,16 +27,14 @@ def GetConfigEmbedOnly():
                 "embedding.vectordb.vectorstore_write_mode=overwrite",
             ],
         )
-    return cfg


 @pytest.fixture(scope="session")
 def GetConfigLLMOnly(tmp_path_factory: pytest.TempPathFactory):
     folder = tmp_path_factory.mktemp("output")
-    # output_file = os.path.join(folder, OUTPUT_FILE)
     output_file = Path(folder) / OUTPUT_FILE
     with initialize(version_base=None, config_path="./conf"):
-        cfg = compose(
+        return compose(
             config_name="config",
             overrides=[
                 "embed_only=False",
@@ -45,7 +43,6 @@ def GetConfigLLMOnly(tmp_path_factory: pytest.TempPathFactory):
                 "embedding.vectordb.vectorstore_write_mode=no_overwrite",
             ],
         )
-    return cfg


 @pytest.fixture(scope="session")
@@ -60,6 +57,17 @@ def GetChunks(GetPages: list, GetConfigEmbedOnly: DictConfig) -> list:
     )


+@pytest.fixture(scope="session")
+def GetCrossTabDicts() -> list:
+    a = {"name": "a", "number": 2}
+    b = {"name": "a", "number": 3, "number_2": 3}
+    c = {"name": "a", "number": 2}
+    d = {"name": "d", "number": 1}
+    e = {"name": "e", "number_3": 2}
+
+    return [e, b, c, d, a]
+
+
 def test_documentloader(GetPages: list):
     assert len(GetPages) > 0

@@ -101,3 +109,8 @@ def test_chat(GetConfigLLMOnly: DictConfig):
     chat_result = chat(**ConfigParser(GetConfigLLMOnly).get_chat_params())
     assert isinstance(chat_result, ConversationalRetrievalChain)
     assert Path(ConfigParser(GetConfigLLMOnly).output_file).is_file()
+
+
+def test_crosstab_dict(GetCrossTabDicts: list):
+    x_result = dict_crosstab(GetCrossTabDicts, "name", "number")
+    assert x_result == {"e": ["NA"], "a": [2, 3], "d": [1]}