diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index efdc83a..65af62d 100755
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -7,22 +7,13 @@
"dockerfile": "./dockerfile",
"context": "."
},
-
+
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {
- "ghcr.io/devcontainers/features/common-utils:2": {
- "installZsh": true,
- "installOhMyZsh": true,
- "installOhMyZshConfig": true,
- "upgradePackages": true,
- "username": "devcontainer"
- },
- "ghcr.io/devcontainers-contrib/features/poetry:2": {
- "version": "latest"
- },
- "ghcr.io/devcontainers-contrib/features/nox:2": {
- "version": "latest"
- }
+ "ghcr.io/devcontainers/features/common-utils:2": {},
+ "ghcr.io/devcontainers-contrib/features/poetry:2": {},
+ "ghcr.io/devcontainers-contrib/features/nox:2": {},
+ "ghcr.io/devcontainers-contrib/features/pre-commit:2": {}
},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
@@ -30,7 +21,8 @@
// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "pip3 install --user -r requirements.txt",
- "postCreateCommand": "poetry install" ,
+ "postCreateCommand": "poetry install && pre-commit install" ,
+ // "postCreateCommand": "poetry install" ,
// Configure tool-specific properties.
"customizations": {
@@ -38,7 +30,12 @@
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
- "Gruntfuggly.todo-tree"
+ "Gruntfuggly.todo-tree",
+ "GitHub.vscode-pull-request-github",
+ "ms-python.black-formatter",
+ "ms-python.flake8",
+ "ms-python.isort",
+ "njpwerner.autodocstring"
]
}
},
@@ -46,5 +43,5 @@
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
"remoteUser": "vscode"
-
+
}
diff --git a/.gitignore b/.gitignore
index 283dbe3..4edadf3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ docs/youtube/
docs/src_docs/
docs/pdf - Copy/
.chroma/
+.ruff_cache/
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..d68ce5e
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,14 @@
+exclude: '^$'
+fail_fast: false
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v1.2.3
+ hooks:
+ - id: trailing-whitespace
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ # Ruff version.
+ rev: v0.0.285
+ hooks:
+ - id: ruff
+ # args: [--fix, --exit-non-zero-on-fix]
+ args: [--fix]
\ No newline at end of file
diff --git a/README.md b/README.md
index 1f58850..915f4d7 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@
-
+
quke
@@ -78,7 +78,7 @@
Getting Started
@@ -110,7 +110,7 @@
## About The Project
Compare the answering capabilities of different LLMs - for example LlaMa, ChatGPT, Cohere, Falcon - against user provided document(s) and questions.
-Specify the different models, embedding tools and vector databases in configuration files.
+Specify the different models, embedding tools and vector databases in configuration files.
Maintain reproducable experiments reflecting combinations of these configurations.
@@ -122,7 +122,7 @@ Maintain reproducable experiments reflecting combinations of these configuration
The instructions assume a Python environment with [Poetry][poetry-url] installed. Development of the tool is done in Python 3.11. While Poetry is not actually needed for the tool to function, the examples assume Poetry is installed.
#### API keys
-The tool uses 3rd party hosted inference APIs. API keys need to be specified as environment variables.
+The tool uses 3rd party hosted inference APIs. API keys need to be specified as environment variables.
The services used:
- [HuggingFace][huggingface-url]
@@ -130,29 +130,29 @@ The services used:
- [Cohere][cohere-url]
- [Replicate][replicate-url]
-The API keys can be specied in a [.env file][.env-url]. Use the provided .env.example file as an example (enter your own API keys and rename it to '.env').
+The API keys can be specied in a [.env file][.env-url]. Use the provided .env.example file as an example (enter your own API keys and rename it to '.env').
At present, all services used in the example configuration have free tiers available.
(back to top)
### Installation
-Navigate to the directory that contains the pyproject.toml file, then execute the
+Navigate to the directory that contains the pyproject.toml file, then execute the
```sh
poetry install
-```
+```
command.
(back to top)
## Usage
-For the examples the project comes with a public financial document for a Canadian Bank (CIBC) as source pdf file.
+For the examples the project comes with a public financial document for a Canadian Bank (CIBC) as source pdf file.
### Base
In order to run the first example, ensure to specify your HuggingFace API key.
-Use the command
+Use the command
```sh
poetry run quke
```
@@ -167,9 +167,9 @@ The defaults are specified in the config.yaml file (in the ./quke/conf/ director
### Specify models and embeddings
*Ensure to specify your Cohere API key before running.*
-As per the configuration files, the default LLM is Falcon and the default embedding uses HuggingFace embedding.
+As per the configuration files, the default LLM is Falcon and the default embedding uses HuggingFace embedding.
-To specify a different LLM - Cohere in this example - run the following:
+To specify a different LLM - Cohere in this example - run the following:
```sh
poetry run quke embedding=huggingface llm=cohere question=eps
```
@@ -192,7 +192,7 @@ poetry run quke embedding=huggingface llm=cohere question=eps
The LLMs, embeddings, questions and other configurations can be captured in experiment config files. The command
```sh
poetry run quke +experiment=openai
-```
+```
uses an experiment file openai.yaml (see folder ./config/experiments) which specifies the LLM, embedding and questions to be used. It is equivalent to running:
```sh
poetry run quke embedding=openai llm=gpt3-5 question=eps
@@ -227,7 +227,7 @@ Note to set `vectorstore_write_mode` to `append` or `overwrite` in the embedding
### Limitations
The free tiers for the third party services generally come with fairly strict limitations. They differ between services; and may differ over time.
- To try out the tool with your own documents it is best to start with a single small source document, no more than two questions and only one combination of LLM/embedding.
+ To try out the tool with your own documents it is best to start with a single small source document, no more than two questions and only one combination of LLM/embedding.
Error messages due to limitations of the APIs are not always clearly indicated as such.
@@ -240,7 +240,7 @@ The tool uses the [LangChain][langchain-url] Python package to interact with the
In general I do not know to what extent any of the data is encrypted during transmission.
-The tool shares no information with me.
+The tool shares no information with me.
(back to top)
diff --git a/noxfile.py b/noxfile.py
index 697f949..ece230c 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -4,8 +4,11 @@
# especially look at roughly line 51, flake/lint
import nox
+nox.options.sessions = ["black", "ruff"]
+
# TODO: Is this an option: https://nox-poetry.readthedocs.io/en/stable/
+# TODO: or a better option: https://github.com/pdm-project/pdm (instead of Poetry)
@nox.session
def flake(session):
session.install(
@@ -38,11 +41,9 @@ def ruff(session):
def test(session):
# Not certain this is a good approach. But it currently works.
# session.install("pytest")
+ # session.install("pytest-cov")
- if session.posargs:
- test_files = session.posargs
- else:
- test_files = []
+ session.run("pytest", "--cov=quke", "tests/")
- session.run("pytest", *test_files)
- # session.run("pytest")
+ # test_files = session.posargs if session.posargs else []
+ # session.run("pytest", "--cov=quke", *test_files)
diff --git a/poetry.lock b/poetry.lock
index be4d9af..c0dc6b5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -705,6 +705,70 @@ humanfriendly = ">=9.1"
[package.extras]
cron = ["capturer (>=2.4)"]
+[[package]]
+name = "coverage"
+version = "7.3.0"
+description = "Code coverage measurement for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "coverage-7.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:db76a1bcb51f02b2007adacbed4c88b6dee75342c37b05d1822815eed19edee5"},
+ {file = "coverage-7.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c02cfa6c36144ab334d556989406837336c1d05215a9bdf44c0bc1d1ac1cb637"},
+ {file = "coverage-7.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477c9430ad5d1b80b07f3c12f7120eef40bfbf849e9e7859e53b9c93b922d2af"},
+ {file = "coverage-7.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce2ee86ca75f9f96072295c5ebb4ef2a43cecf2870b0ca5e7a1cbdd929cf67e1"},
+ {file = "coverage-7.3.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68d8a0426b49c053013e631c0cdc09b952d857efa8f68121746b339912d27a12"},
+ {file = "coverage-7.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b3eb0c93e2ea6445b2173da48cb548364f8f65bf68f3d090404080d338e3a689"},
+ {file = "coverage-7.3.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:90b6e2f0f66750c5a1178ffa9370dec6c508a8ca5265c42fbad3ccac210a7977"},
+ {file = "coverage-7.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:96d7d761aea65b291a98c84e1250cd57b5b51726821a6f2f8df65db89363be51"},
+ {file = "coverage-7.3.0-cp310-cp310-win32.whl", hash = "sha256:63c5b8ecbc3b3d5eb3a9d873dec60afc0cd5ff9d9f1c75981d8c31cfe4df8527"},
+ {file = "coverage-7.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:97c44f4ee13bce914272589b6b41165bbb650e48fdb7bd5493a38bde8de730a1"},
+ {file = "coverage-7.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:74c160285f2dfe0acf0f72d425f3e970b21b6de04157fc65adc9fd07ee44177f"},
+ {file = "coverage-7.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b543302a3707245d454fc49b8ecd2c2d5982b50eb63f3535244fd79a4be0c99d"},
+ {file = "coverage-7.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad0f87826c4ebd3ef484502e79b39614e9c03a5d1510cfb623f4a4a051edc6fd"},
+ {file = "coverage-7.3.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:13c6cbbd5f31211d8fdb477f0f7b03438591bdd077054076eec362cf2207b4a7"},
+ {file = "coverage-7.3.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fac440c43e9b479d1241fe9d768645e7ccec3fb65dc3a5f6e90675e75c3f3e3a"},
+ {file = "coverage-7.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3c9834d5e3df9d2aba0275c9f67989c590e05732439b3318fa37a725dff51e74"},
+ {file = "coverage-7.3.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4c8e31cf29b60859876474034a83f59a14381af50cbe8a9dbaadbf70adc4b214"},
+ {file = "coverage-7.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7a9baf8e230f9621f8e1d00c580394a0aa328fdac0df2b3f8384387c44083c0f"},
+ {file = "coverage-7.3.0-cp311-cp311-win32.whl", hash = "sha256:ccc51713b5581e12f93ccb9c5e39e8b5d4b16776d584c0f5e9e4e63381356482"},
+ {file = "coverage-7.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:887665f00ea4e488501ba755a0e3c2cfd6278e846ada3185f42d391ef95e7e70"},
+ {file = "coverage-7.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d000a739f9feed900381605a12a61f7aaced6beae832719ae0d15058a1e81c1b"},
+ {file = "coverage-7.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:59777652e245bb1e300e620ce2bef0d341945842e4eb888c23a7f1d9e143c446"},
+ {file = "coverage-7.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9737bc49a9255d78da085fa04f628a310c2332b187cd49b958b0e494c125071"},
+ {file = "coverage-7.3.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5247bab12f84a1d608213b96b8af0cbb30d090d705b6663ad794c2f2a5e5b9fe"},
+ {file = "coverage-7.3.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2ac9a1de294773b9fa77447ab7e529cf4fe3910f6a0832816e5f3d538cfea9a"},
+ {file = "coverage-7.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:85b7335c22455ec12444cec0d600533a238d6439d8d709d545158c1208483873"},
+ {file = "coverage-7.3.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:36ce5d43a072a036f287029a55b5c6a0e9bd73db58961a273b6dc11a2c6eb9c2"},
+ {file = "coverage-7.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:211a4576e984f96d9fce61766ffaed0115d5dab1419e4f63d6992b480c2bd60b"},
+ {file = "coverage-7.3.0-cp312-cp312-win32.whl", hash = "sha256:56afbf41fa4a7b27f6635bc4289050ac3ab7951b8a821bca46f5b024500e6321"},
+ {file = "coverage-7.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f297e0c1ae55300ff688568b04ff26b01c13dfbf4c9d2b7d0cb688ac60df479"},
+ {file = "coverage-7.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac0dec90e7de0087d3d95fa0533e1d2d722dcc008bc7b60e1143402a04c117c1"},
+ {file = "coverage-7.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:438856d3f8f1e27f8e79b5410ae56650732a0dcfa94e756df88c7e2d24851fcd"},
+ {file = "coverage-7.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1084393c6bda8875c05e04fce5cfe1301a425f758eb012f010eab586f1f3905e"},
+ {file = "coverage-7.3.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49ab200acf891e3dde19e5aa4b0f35d12d8b4bd805dc0be8792270c71bd56c54"},
+ {file = "coverage-7.3.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67e6bbe756ed458646e1ef2b0778591ed4d1fcd4b146fc3ba2feb1a7afd4254"},
+ {file = "coverage-7.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8f39c49faf5344af36042b293ce05c0d9004270d811c7080610b3e713251c9b0"},
+ {file = "coverage-7.3.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7df91fb24c2edaabec4e0eee512ff3bc6ec20eb8dccac2e77001c1fe516c0c84"},
+ {file = "coverage-7.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:34f9f0763d5fa3035a315b69b428fe9c34d4fc2f615262d6be3d3bf3882fb985"},
+ {file = "coverage-7.3.0-cp38-cp38-win32.whl", hash = "sha256:bac329371d4c0d456e8d5f38a9b0816b446581b5f278474e416ea0c68c47dcd9"},
+ {file = "coverage-7.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b859128a093f135b556b4765658d5d2e758e1fae3e7cc2f8c10f26fe7005e543"},
+ {file = "coverage-7.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc0ed8d310afe013db1eedd37176d0839dc66c96bcfcce8f6607a73ffea2d6ba"},
+ {file = "coverage-7.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e61260ec93f99f2c2d93d264b564ba912bec502f679793c56f678ba5251f0393"},
+ {file = "coverage-7.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97af9554a799bd7c58c0179cc8dbf14aa7ab50e1fd5fa73f90b9b7215874ba28"},
+ {file = "coverage-7.3.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3558e5b574d62f9c46b76120a5c7c16c4612dc2644c3d48a9f4064a705eaee95"},
+ {file = "coverage-7.3.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d5576d35fcb765fca05654f66aa71e2808d4237d026e64ac8b397ffa66a56a"},
+ {file = "coverage-7.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:07ea61bcb179f8f05ffd804d2732b09d23a1238642bf7e51dad62082b5019b34"},
+ {file = "coverage-7.3.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:80501d1b2270d7e8daf1b64b895745c3e234289e00d5f0e30923e706f110334e"},
+ {file = "coverage-7.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4eddd3153d02204f22aef0825409091a91bf2a20bce06fe0f638f5c19a85de54"},
+ {file = "coverage-7.3.0-cp39-cp39-win32.whl", hash = "sha256:2d22172f938455c156e9af2612650f26cceea47dc86ca048fa4e0b2d21646ad3"},
+ {file = "coverage-7.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:60f64e2007c9144375dd0f480a54d6070f00bb1a28f65c408370544091c9bc9e"},
+ {file = "coverage-7.3.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:5492a6ce3bdb15c6ad66cb68a0244854d9917478877a25671d70378bdc8562d0"},
+ {file = "coverage-7.3.0.tar.gz", hash = "sha256:49dbb19cdcafc130f597d9e04a29d0a032ceedf729e41b181f51cd170e6ee865"},
+]
+
+[package.extras]
+toml = ["tomli"]
+
[[package]]
name = "dataclasses-json"
version = "0.5.9"
@@ -2106,6 +2170,24 @@ pluggy = ">=0.12,<2.0"
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+[[package]]
+name = "pytest-cov"
+version = "4.1.0"
+description = "Pytest plugin for measuring coverage."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
+ {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
+]
+
+[package.dependencies]
+coverage = {version = ">=5.2.1", extras = ["toml"]}
+pytest = ">=4.6"
+
+[package.extras]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+
[[package]]
name = "python-dateutil"
version = "2.8.2"
@@ -3061,4 +3143,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "9d35e1aac8dba36c9273f5d1556a086968e34fd53181913c2fb310b42edefebe"
+content-hash = "6b037c9ccd135d742aa5c851d8d512f64f3691645c01e690055fa922344ffb21"
diff --git a/pyproject.toml b/pyproject.toml
index c6125a4..a900c1b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "quke"
-version = "0.1.3"
+version = "0.2.0"
description = "Compare the answering capabilities of different LLMs - for example LlaMa, ChatGPT, Cohere, Falcon - against user provided document(s) and questions."
authors = ["Erik Oosterop"]
maintainers = ["Erik Oosterop"]
@@ -41,6 +41,10 @@ rich = "^13.5.2"
pytest = "^7.4.0"
requests-mock = "^1.11.0"
+
+[tool.poetry.group.dev.dependencies]
+pytest-cov = "^4.1.0"
+
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
@@ -53,16 +57,15 @@ markers = [
[tool.ruff]
include = ["*.py", "*.pyi", "**/pyproject.toml"]
-fix = false
line-length = 119
select = [ # https://beta.ruff.rs/docs/rules/
"A", # prevent using keywords that clobber python builtins
"ANN", # type annotation
"B", # bugbear: security warnings
- "C",
- "C90",
+ "C",
+ "C90",
"D", # pydocstyle
- "DAR", # darglint, but does not seem to be implemented at the moment
+ # "DAR", # darglint, but does not seem to be implemented at the moment
"DTZ", # date timezone
"E", # pycodestyle
"F", # pyflakes
@@ -80,19 +83,24 @@ select = [ # https://beta.ruff.rs/docs/rules/
ignore = [
"E203", # comments allowed
"E501",
+ "ANN101", # type annotation for self
]
# fixing is off by default
+fix = true
fixable = [
"F401", # Remove unused imports.
"NPY001", # Fix numpy types, which are removed in 1.24.
"RUF100", # Remove unused noqa comments.
+ "I", # Fix import order
+ "PTH", # Path.cwd()
]
[tool.ruff.per-file-ignores]
"tests/**/*.py" = [
# at least this three should be fine in tests:
"S101", # asserts allowed in tests...
+ "ANN", # TODO: do not care about type annotations in tests for now
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
# The below are debateable
@@ -101,5 +109,10 @@ fixable = [
"D", # no pydocstyle
]
+"noxfile.py" = [
+ "ANN",
+ "D",
+]
+
[tool.ruff.pydocstyle]
convention = "google"
\ No newline at end of file
diff --git a/quke/conf/__init__.py b/quke/conf/__init__.py
index e69de29..2968641 100644
--- a/quke/conf/__init__.py
+++ b/quke/conf/__init__.py
@@ -0,0 +1 @@
+"""Hydra-type configuration files for quke LLM semantic search."""
diff --git a/quke/embed.py b/quke/embed.py
index 15b619e..b4608fe 100644
--- a/quke/embed.py
+++ b/quke/embed.py
@@ -6,7 +6,7 @@
import time
from collections import defaultdict
from dataclasses import dataclass, field
-from glob import glob
+from pathlib import Path
# [ ] TODO: PyMU is faster, PyPDF more accurate: https://github.com/py-pdf/benchmarks
from langchain.document_loaders import CSVLoader, PyMuPDFLoader, TextLoader
@@ -46,13 +46,14 @@ def get_loaders(src_doc_folder: str, loader: DocumentLoaderDef) -> list:
# to make ext case insensitive
ext = "".join([f"[{ch}{ch.swapcase()}]" for ch in ext])
- src_file_names = glob(
- os.path.join(src_doc_folder, "**", f"*.{ext}"), recursive=True
- ) # also looks in subfolders
+
+ src_file_names = Path(src_doc_folder).rglob(f"**/*.{ext}")
# TODO: Problem with embedding more than 2 files at once, or some number of pages/chunks (using HF)?
# Error message does not really help. Appending in steps does work.
- loaders = [loader.loader(pdf_name, **loader.kwargs) for pdf_name in src_file_names]
+ loaders = [
+ loader.loader(str(pdf_name), **loader.kwargs) for pdf_name in src_file_names
+ ]
return loaders
@@ -77,7 +78,7 @@ def get_pages_from_document(src_doc_folder: str) -> list:
return pages
-def get_chunks_from_pages(pages: list, splitter_params) -> list:
+def get_chunks_from_pages(pages: list, splitter_params: dict) -> list:
"""Splits pages into smaller chunks used for embedding."""
# for splitter args containing 'func', the yaml value is converted into a Python function.
# TODO: Security risk? Hence a safe_list of functions is provided; severly limiting flexibility.
@@ -88,7 +89,7 @@ def get_chunks_from_pages(pages: list, splitter_params) -> list:
if ("func".lower() in key.lower()) and splitter_params["splitter_args"][
key
] in safe_function_list:
- splitter_params["splitter_args"][key] = eval(
+ splitter_params["splitter_args"][key] = eval( # noqa: S307
splitter_params["splitter_args"][key]
)
@@ -118,28 +119,28 @@ def embed(
logging.info(f"Starting to embed into VectorDB: {vectordb_location}")
# if folder does not exist, or write_mode is APPEND no need to do anything here.
- if os.path.exists(vectordb_location) and not os.path.isfile( # noqa: SIM102
- vectordb_location
+ if (
+ Path(vectordb_location).exists()
+ and (not Path(vectordb_location).is_file())
+ and os.listdir(vectordb_location)
):
- if os.listdir(
- vectordb_location
- ): # path exists and is not empty - assumed to contain vectordb
- if write_mode == DatabaseAction.NO_OVERWRITE: # skip embedding
- logging.info(
- f"No new embeddings created. Embedding database already exists at "
- f"{vectordb_location!r}. Remove database folder, or change embedding config "
- "vectorstore_write_mode to OVERWRITE or APPEND."
- )
- return
- if (
- write_mode == DatabaseAction.OVERWRITE
- ): # remove exising database before embedding
- # TODO: Is this too harsh to delete the full folder? At least create a backup?
- logging.warning(
- f"The folder containing the embedding database ({vectordb_location}) and all its contents "
- "about to be overwritten."
- )
- shutil.rmtree(vectordb_location)
+ # path exists and is not empty - assumed to contain vectordb
+ if write_mode == DatabaseAction.NO_OVERWRITE: # skip embedding
+ logging.info(
+ f"No new embeddings created. Embedding database already exists at "
+ f"{vectordb_location!r}. Remove database folder, or change embedding config "
+ "vectorstore_write_mode to OVERWRITE or APPEND."
+ )
+ return
+ if (
+ write_mode == DatabaseAction.OVERWRITE
+ ): # remove exising database before embedding
+ # TODO: Is this too harsh to delete the full folder? At least create a backup?
+ logging.warning(
+ f"The folder containing the embedding database ({vectordb_location}) and all its contents "
+ "about to be overwritten."
+ )
+ shutil.rmtree(vectordb_location)
# get bite sized chunks from source documents
chunks = get_chunks_from_pages(
@@ -152,7 +153,7 @@ def embed(
)
# Use chunker to embed in chunks with a wait time in between. As a basic way to deal with some rate limiting.
- def chunker(seq, size):
+ def chunker(seq: list, size: int) -> list:
return (seq[pos : pos + size] for pos in range(0, len(seq), size))
c = 0
@@ -174,7 +175,7 @@ def chunker(seq, size):
def embed_these_chunks(
- chunks,
+ chunks: list,
vectordb_location: str,
embedding_import: ClassImportDefinition,
embedding_kwargs: dict,
diff --git a/quke/llm_chat.py b/quke/llm_chat.py
index db797af..1468ee6 100644
--- a/quke/llm_chat.py
+++ b/quke/llm_chat.py
@@ -1,9 +1,9 @@
"""Sets up all elements required for a chat session."""
import importlib
import logging # functionality managed by Hydra
-import os
from collections import defaultdict
from datetime import datetime
+from pathlib import Path
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
@@ -18,9 +18,9 @@ def chat(
embedding_import: ClassImportDefinition,
vectordb_import: ClassImportDefinition,
llm_import: ClassImportDefinition,
- llm_parameters,
- prompt_parameters,
- output_file,
+ llm_parameters: dict,
+ prompt_parameters: dict,
+ output_file: dict,
) -> object:
"""Initiates a chat with an LLM.
@@ -88,9 +88,9 @@ def chat_output(question: str, result: dict) -> None:
# TODO: Either I do not understand mdutils or it is an unfriendly package when trying to append.
-def chat_output_to_file(result: dict, output_file) -> None:
+def chat_output_to_file(result: dict, output_file: dict) -> None:
"""Populates a record of the chat with the LLM into a markdown file."""
- first_write = not os.path.isfile(output_file["path"])
+ first_write = not Path(output_file["path"]).is_file()
mdFile = MdUtils(file_name="tmp.md")
@@ -120,8 +120,18 @@ def chat_output_to_file(result: dict, output_file) -> None:
new.append_end((mdFile.get_md_text()).strip())
-def dict_crosstab(source, key, listed, missing="NA"):
- """Limited and simple version of a crosstab query on a dict."""
+def dict_crosstab(source: list, key: str, listed: str, missing: str = "NA") -> dict:
+ """Limited and simple version of a crosstab query on a dict.
+
+ Args:
+ source (list): _description_
+ key (str): _description_
+ listed (str): _description_
+ missing (str, optional): _description_. Defaults to "NA".
+
+ Returns:
+ _type_: _description_
+ """
dict_subs = []
for d in source:
dict_subs.append({key: d[key], listed: d.get(listed, missing)}.values())
diff --git a/quke/quke.py b/quke/quke.py
index 0792ff6..86c7f01 100644
--- a/quke/quke.py
+++ b/quke/quke.py
@@ -3,7 +3,7 @@
LLMs, embedding model, vector store and other components can be congfigured.
"""
import logging # functionality managed by Hydra
-import os
+from pathlib import Path
import hydra
from dotenv import find_dotenv, load_dotenv
@@ -34,12 +34,10 @@ def __init__(self, cfg: DictConfig) -> None:
self.src_doc_folder = cfg.source_document_folder
# TODO: Is this sufficiently robust? What if the user wants a folder not related to wcd/pwd?
- # Consider using os.path.isabs(path) or os.path.abspath(path). Or just pass the string and
- # 'Python will handle it'?
- self.vectordb_location = os.path.join(
- os.getcwd(),
- cfg.internal_data_folder,
- cfg.embedding.vectordb.vectorstore_location,
+ self.vectordb_location = str(
+ Path.cwd()
+ / cfg.internal_data_folder
+ / cfg.embedding.vectordb.vectorstore_location
)
self.embedding_import = ClassImportDefinition(
cfg.embedding.embedding.module_name, cfg.embedding.embedding.class_name
@@ -85,14 +83,15 @@ def __init__(self, cfg: DictConfig) -> None:
# TODO: need something better for output folder
# https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory/
try: # try statement done for testing suite
- self.output_file = os.path.join(
- hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"],
- cfg.experiment_summary_file,
+ self.output_file = str(
+ Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"])
+ / cfg.experiment_summary_file
)
+
except Exception:
self.output_file = cfg.experiment_summary_file
- def get_embed_params(self):
+ def get_embed_params(self) -> dict:
"""Based on the config files returns the set of parameters need to start embedding."""
embed_parameters = {
"src_doc_folder": self.src_doc_folder,
@@ -106,7 +105,7 @@ def get_embed_params(self):
}
return embed_parameters
- def get_chat_params(self):
+ def get_chat_params(self) -> dict:
"""Based on the config files returns the set of parameters need to start a chat."""
chat_parameters = {
"vectordb_location": self.vectordb_location,
@@ -126,11 +125,11 @@ def get_splitter_params(self) -> dict:
"splitter_args": self.splitter_args,
}
- def get_args_dict(self, cfg_sub):
+ def get_args_dict(self, cfg_sub: dict) -> dict:
"""Takes a subset of the Hydra configs and returns the same as a dict."""
return OmegaConf.to_container(cfg_sub, resolve=True)
- def get_llm_parameters(self):
+ def get_llm_parameters(self) -> dict:
"""Based on the config files returns the set of parameters needed to setup an LLM."""
return OmegaConf.to_container(self.cfg.llm.llm_args, resolve=True)
diff --git a/tests/test_001.py b/tests/test_001.py
index c131ddc..0dbcb01 100644
--- a/tests/test_001.py
+++ b/tests/test_001.py
@@ -1,7 +1,9 @@
import os
+from pathlib import Path
import pytest
from hydra import compose, initialize
+from omegaconf import DictConfig
from quke.embed import embed, get_chunks_from_pages, get_pages_from_document
from quke.llm_chat import chat
@@ -29,9 +31,10 @@ def GetConfigEmbedOnly():
@pytest.fixture(scope="session")
-def GetConfigLLMOnly(tmp_path_factory):
+def GetConfigLLMOnly(tmp_path_factory: pytest.TempPathFactory):
folder = tmp_path_factory.mktemp("output")
- output_file = os.path.join(folder, OUTPUT_FILE)
+ # output_file = os.path.join(folder, OUTPUT_FILE)
+ output_file = Path(folder) / OUTPUT_FILE
with initialize(version_base=None, config_path="./conf"):
cfg = compose(
config_name="config",
@@ -51,7 +54,7 @@ def GetPages() -> list:
@pytest.fixture(scope="session")
-def GetChunks(GetPages: list, GetConfigEmbedOnly) -> list:
+def GetChunks(GetPages: list, GetConfigEmbedOnly: DictConfig) -> list:
return get_chunks_from_pages(
GetPages, ConfigParser(GetConfigEmbedOnly).get_splitter_params()
)
@@ -67,7 +70,7 @@ def test_documentloader(GetPages: list):
assert text_file_found
-def test_getchunks(GetChunks):
+def test_getchunks(GetChunks: list):
assert len(GetChunks) == 1 # with current basic test just equals 1
@@ -84,7 +87,7 @@ def test_config():
@pytest.mark.expensive()
# Do the following to exlude this
# poetry run pytest -m 'not expensive'
-def test_embed(GetConfigEmbedOnly):
+def test_embed(GetConfigEmbedOnly: DictConfig):
chunks_embedded = embed(**ConfigParser(GetConfigEmbedOnly).get_embed_params())
assert chunks_embedded == 1
@@ -92,9 +95,9 @@ def test_embed(GetConfigEmbedOnly):
@pytest.mark.expensive()
# @pytest.mark.skipif(not os.path.exists(os.path.dirname(OUTPUT_FILE)),
# reason=f"Output folder {os.path.dirname(OUTPUT_FILE)} should exist before running test.")
-def test_chat(GetConfigLLMOnly):
+def test_chat(GetConfigLLMOnly: DictConfig):
from langchain.chains import ConversationalRetrievalChain
chat_result = chat(**ConfigParser(GetConfigLLMOnly).get_chat_params())
assert isinstance(chat_result, ConversationalRetrievalChain)
- assert os.path.isfile(ConfigParser(GetConfigLLMOnly).output_file)
+ assert Path(ConfigParser(GetConfigLLMOnly).output_file).is_file()