diff --git a/poetry.lock b/poetry.lock index fff4018..a1ce859 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1258,6 +1258,23 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "jinja2" +version = "3.1.2" +description = "A very fast and expressive template engine." +optional = false +python-versions = ">=3.7" +files = [ + {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, + {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + [[package]] name = "langchain" version = "0.0.229" @@ -1402,6 +1419,65 @@ profiling = ["gprof2dot"] rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] +[[package]] +name = "markupsafe" +version = "2.1.3" +description = "Safely add untrusted strings to HTML/XML markup." +optional = false +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, + {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, +] + [[package]] name = "marshmallow" version = "3.20.1" @@ -1447,16 +1523,6 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] -[[package]] -name = "mdutils" -version = "1.6.0" -description = "Useful package for creating Markdown files while executing python code." -optional = false -python-versions = "^3.6" -files = [ - {file = "mdutils-1.6.0.tar.gz", hash = "sha256:647f3cf00df39fee6c57fa6738dc1160fce1788276b5530c87d43a70cdefdaf1"}, -] - [[package]] name = "monotonic" version = "1.6" @@ -3143,4 +3209,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "de1038901aa1a3fbb0227a7a99be9616df4558f604d0cd2e63bc76ce88d55ba8" +content-hash = "1b1f7097904aba38934895575efd54a135cb33ee18088ed1407f316f52d5ba79" diff --git a/pyproject.toml b/pyproject.toml index 1e97062..0e2a247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "quke" -version = "0.2.1" +version = "0.3.1" description = "Compare the answering capabilities of different LLMs - for example LlaMa, ChatGPT, Cohere, Falcon - against user provided document(s) and questions." authors = ["Erik Oosterop"] maintainers = ["Erik Oosterop"] @@ -33,8 +33,8 @@ huggingface-hub = "^0.16.4" openai = "^0.27.8" cohere = "^4.17.0" replicate = "^0.9.0" -mdutils = "^1.6.0" rich = "^13.5.2" +jinja2 = "^3.1.2" [tool.poetry.group.dev.dependencies] pytest-cov = "^4.1.0" diff --git a/quke/conf/embedding/huggingface.yaml b/quke/conf/embedding/huggingface.yaml index c9054e7..7526e7a 100644 --- a/quke/conf/embedding/huggingface.yaml +++ b/quke/conf/embedding/huggingface.yaml @@ -4,11 +4,11 @@ vectordb: vectorstore_location: vector_store/chromadb_hf_del # Possible values for vectorstore_write_mode: overwrite, no_overwrite, append - # This works at the vectorstore_location level. + # This works at the vectorstore_location level. # -If the folder exists and 'no_overwrite' is specified: document will not be embedded # -If the folder exists and 'overwrite' is specified, all contents of the vectordb folder will be deleted and a new vectordb will be created. # -If set to 'append' the new embeddings will be appended to any existing vectordb. If a source document is specified twice it will be embedded twice. - vectorstore_write_mode: overwrite + vectorstore_write_mode: no_overwrite embedding: module_name: langchain.embeddings diff --git a/quke/llm_chat.py b/quke/llm_chat.py index 97dd269..6b76449 100644 --- a/quke/llm_chat.py +++ b/quke/llm_chat.py @@ -4,11 +4,11 @@ from collections import defaultdict from datetime import datetime from pathlib import Path +from typing import Literal +from jinja2 import Environment, PackageLoader, select_autoescape from langchain.chains import ConversationalRetrievalChain from langchain.memory import ConversationBufferMemory -from mdutils.fileutils import MarkDownFile # type: ignore -from mdutils.mdutils import MdUtils # type: ignore from . import ClassImportDefinition @@ -71,71 +71,72 @@ def chat( # NOTE: trial API keys may have very restrictive rules. It is plausible that you run into # constraints after the 2nd question. - for question in prompt_parameters: - result = qa({"question": question}) - chat_output(result) - chat_output_to_file(result, output_file) + results = [qa({"question": question}) for question in prompt_parameters] + chat_output_to_html( + results, output_file + ) # TODO: infer output from output file name in cfg? + chat_output_to_html(results, output_file, output_extension=".md") + chat_output_to_html(results, output_file, output_extension="logging") logging.info("=======================") return qa -def chat_output(result: dict) -> None: - """Logs a chat question and anwer. +def chat_output_to_html( + results: list[dict], + output_file: dict, + output_extension: Literal[".html", ".md", "logging"] = ".html", +) -> None: + """Write summary of chat experiment into HTML file. Args: - result: dict with the answer from the LLM. Expects 'question', 'answer' and 'source' keys, - 'page' key optionally. + results: list of dicts with the answer from the LLM. Expects 'question', 'answer' + and 'source' keys; 'page' key optionally. + output_file: path and other information regarding the output file. + output_extension: .html or .md. Alteratively logging for python logging. """ - logging.info("=======================") - logging.info(f"Q: {result['question']}") - logging.info(f"A: {result['answer']}") - - src_docs = [doc.metadata for doc in result["source_documents"]] - src_docs_pages_used = dict_crosstab(src_docs, "source", "page") - for key, value in src_docs_pages_used.items(): - logging.info(f"Source document: {key}, Pages used: {value}") + env = Environment(loader=PackageLoader("quke"), autoescape=select_autoescape()) + + if output_extension.lower() == ".html": + template_name = "chat_session.html.jinja" + elif output_extension.lower() == ".md": + template_name = "chat_session.md.jinja" + elif output_extension.lower() == "logging": + template_name = "chat_session.logging.jinja" + else: + template_name = "chat_session.html.jinja" + template = env.get_template(template_name) + func_dict = {"dict_crosstab": _dict_crosstab_for_jinja} + template.globals.update(func_dict) -# TODO: Either I do not understand mdutils or it is an unfriendly package when trying to append. -def chat_output_to_file(result: dict, output_file: dict) -> None: - """Populates a record of the chat with the LLM into a markdown file. + output = template.render( + chat_time=datetime.now().astimezone().strftime("%a %d-%b-%Y %H:%M %Z"), + llm_results=results, + config=output_file["conf_yaml"], + ) - Args: - result: dict with the answer from the LLM. Expects 'question', 'answer' and 'source' keys, - 'page' key optionally. - output_file: File name to which the record is saved. - """ - first_write = not Path(output_file["path"]).is_file() - - md_file = MdUtils(file_name="tmp.md") - - if first_write: - md_file.new_header(1, "LLM Chat Session with quke") - md_file.write( - datetime.now().astimezone().strftime("%a %d-%b-%Y %H:%M %Z"), align="center" - ) - md_file.new_paragraph("") - md_file.new_header(2, "Experiment settings", header_id="settings") - md_file.insert_code(output_file["conf_yaml"], language="yaml") - md_file.new_header(2, "Chat", header_id="chat") + if output_extension.lower() == "logging": + logging.info(output) else: - existing_text = MarkDownFile().read_file(file_name=output_file["path"]) - md_file.new_paragraph(existing_text) + file_path = Path(output_file["path"]).with_suffix(output_extension) + with file_path.open("w") as fp: + fp.write(output) - md_file.new_paragraph(f"Q: {result['question']}") - md_file.new_paragraph(f"A: {result['answer']}") - src_docs = [doc.metadata for doc in result["source_documents"]] - src_docs_pages_used = dict_crosstab(src_docs, "source", "page") - for key, value in src_docs_pages_used.items(): - md_file.new_paragraph(f"Source document: {key}, Pages used: {value}") +def _dict_crosstab_for_jinja(sources: list) -> dict: + """Wrapper around dict_crostab for use from within Jinja. - new = MarkDownFile(name=output_file["path"]) + Args: + sources (list): _description_ - new.append_end((md_file.get_md_text()).strip()) + Returns: + dict: _description_ + """ + src_docs = [doc.metadata for doc in sources] + return dict_crosstab(src_docs, "source", "page") def dict_crosstab(source: list, key: str, listed: str, missing: str = "NA") -> dict: diff --git a/quke/quke.py b/quke/quke.py index fed88dc..a47b82a 100644 --- a/quke/quke.py +++ b/quke/quke.py @@ -166,7 +166,7 @@ def quke(cfg: DictConfig) -> None: with console.status("Embedding...", spinner="aesthetic"): # python -m rich.spinner to see options embed.embed(**embed_parameters) - logging.info("\n" + OmegaConf.to_yaml(cfg)) + # Used to log config here: logging.info("\n" + OmegaConf.to_yaml(cfg)) if not config_parser.embed_only: with console.status("Chatting...", spinner="aesthetic"): diff --git a/quke/templates/chat_session.html.jinja b/quke/templates/chat_session.html.jinja new file mode 100644 index 0000000..892f750 --- /dev/null +++ b/quke/templates/chat_session.html.jinja @@ -0,0 +1,31 @@ + + + + + LLM Chat Session with quke + + + +

LLM Chat Session with quke

+
{{ chat_time }}
+

Experiment Settings

+
+        {{ config }}
+    
+

Chat

+ {% for result in llm_results %} +
Q: {{ result.question }}
+
A: {{ result.answer }}
+
+
Source:
+
+ {% for key, value in dict_crosstab(result.source_documents).items() %} + {{ key }}, pages: {{ value }} + {% endfor %} +
+

+ {% endfor %} + + +{# 1) timestamp 2) conf summary 3) chat: [question, answer, source (optional)] #} + \ No newline at end of file diff --git a/quke/templates/chat_session.logging.jinja b/quke/templates/chat_session.logging.jinja new file mode 100644 index 0000000..c27083e --- /dev/null +++ b/quke/templates/chat_session.logging.jinja @@ -0,0 +1,8 @@ +======================= +{{ config }} +{% for result in llm_results %} +Q: {{ result.question }} +A: {{ result.answer }} +Source: {% for key, value in dict_crosstab(result.source_documents).items() %} + document: {{ key }}, page: {{ value }} {% endfor %} +{% endfor %}======================= \ No newline at end of file diff --git a/quke/templates/chat_session.md.jinja b/quke/templates/chat_session.md.jinja new file mode 100644 index 0000000..99f095b --- /dev/null +++ b/quke/templates/chat_session.md.jinja @@ -0,0 +1,22 @@ +# LLM Chat Session with quke +
{{ chat_time }}
+ +## Experiment settings + +```yaml +{{ config }} +``` + +## Chat + +{% for result in llm_results %} +Q: {{ result.question }} + +A: {{ result.answer }} + +Source: {% for key, value in dict_crosstab(result.source_documents).items() %} +{{ key }}, pages: {{ value }} +{% endfor %} +------- + +{% endfor %} \ No newline at end of file diff --git a/tests/test_001.py b/tests/test_001.py index e6184df..172b246 100644 --- a/tests/test_001.py +++ b/tests/test_001.py @@ -108,7 +108,10 @@ def test_chat(GetConfigLLMOnly: DictConfig): chat_result = chat(**ConfigParser(GetConfigLLMOnly).get_chat_params()) assert isinstance(chat_result, ConversationalRetrievalChain) - assert Path(ConfigParser(GetConfigLLMOnly).output_file).is_file() + assert ( + Path(ConfigParser(GetConfigLLMOnly).output_file).with_suffix(".html").is_file() + or Path(ConfigParser(GetConfigLLMOnly).output_file).with_suffix(".md").is_file() + ) def test_crosstab_dict(GetCrossTabDicts: list):