Skip to content

Commit

Permalink
Merge pull request #146 from ggozad/feat/save-images-in-messages
Browse files Browse the repository at this point in the history
Persist images in message history / sqlite db.
  • Loading branch information
ggozad authored Dec 29, 2024
2 parents 7458180 + 356ac8a commit 207fc55
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 40 deletions.
6 changes: 6 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ Changelog
0.7.0 -
-------------------

- Enforce foreign key constraints in the sqlite db, to allow proper cascading deletes.
[ggozad]

- Perist images in the chat history & sqlite db.
[ggozad]

- Update OllamaLLM client to match the use of Pydantic in olllama-python.
[ggozad]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "oterm"
version = "0.6.9"
version = "0.7.0"
description = "A text-based terminal client for Ollama."
authors = [{ name = "Yiorgis Gozadinos", email = "[email protected]" }]
license = { text = "MIT" }
Expand Down
6 changes: 3 additions & 3 deletions src/oterm/app/chat_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ async def on_submit(self, event: Input.Submitted) -> None:
if not event.value:
return

messages: Sequence[tuple[int, Author, str]] = await store.get_messages(
self.chat_id
messages: Sequence[tuple[int, Author, str, list[str]]] = (
await store.get_messages(self.chat_id)
)
with open(event.value, "w", encoding="utf-8") as file:
for message in messages:
_, author, text = message
_, author, text, images = message
file.write(f"*{author.value}*\n")
file.write(f"{text}\n")
file.write("\n---\n")
Expand Down
68 changes: 40 additions & 28 deletions src/oterm/app/widgets/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class ChatContainer(Widget):
ollama = OllamaLLM()
messages: reactive[list[tuple[int, Author, str]]] = reactive([])
messages: reactive[list[tuple[int, Author, str, list[str]]]] = reactive([])
chat_name: str
system: str | None
format: Literal["", "json"]
Expand All @@ -54,7 +54,7 @@ def __init__(
db_id: int,
chat_name: str,
model: str = "llama3.2",
messages: list[tuple[int, Author, str]] = [],
messages: list[tuple[int, Author, str, list[str]]] = [],
system: str | None = None,
format: Literal["", "json"] = "",
parameters: Options,
Expand All @@ -63,15 +63,18 @@ def __init__(
**kwargs,
) -> None:
super().__init__(*children, **kwargs)

history: list[Message] = [
(
Message(role="user", content=message)
if author == Author.USER
else Message(role="assistant", content=message)
history = []
# This is wrong, the images should be a list of Image objects
# See https://github.com/ollama/ollama-python/issues/375
# Temp fix is to do msg.images = images # type: ignore

for _, author, message, images in messages:
msg = Message(
role="user" if author == Author.USER else "assistant",
content=message,
)
for _, author, message in messages
]
msg.images = images # type: ignore
history.append(msg)

used_tool_defs = [
tool_def for tool_def in available_tool_defs if tool_def["tool"] in tools
Expand All @@ -83,7 +86,7 @@ def __init__(
format=format,
options=parameters,
keep_alive=keep_alive,
history=history,
history=history, # type: ignore
tool_defs=used_tool_defs,
)

Expand All @@ -96,6 +99,7 @@ def __init__(
self.keep_alive = keep_alive
self.tools = tools
self.loaded = False
self.images = []

def on_mount(self) -> None:
self.query_one("#prompt").focus()
Expand All @@ -104,7 +108,7 @@ async def load_messages(self) -> None:
if self.loaded:
return
message_container = self.query_one("#messageContainer")
for _, author, message in self.messages:
for _, author, message, images in self.messages:
chat_item = ChatItem()
chat_item.text = message
chat_item.author = author
Expand Down Expand Up @@ -156,25 +160,28 @@ async def response_task() -> None:
if message_container.can_view_partial(response_chat_item):
message_container.scroll_end()

self.images = []

# Save to db
store = await Store.get_store()
id = await store.save_message(
id=None,
chat_id=self.db_id,
author=Author.USER.value,
text=message,
images=[img for _, img in self.images],
)
self.messages.append(
(id, Author.USER, message, [img for _, img in self.images])
)
self.messages.append((id, Author.USER, message))

id = await store.save_message(
id=None,
chat_id=self.db_id,
author=Author.OLLAMA.value,
text=response,
)
self.messages.append((id, Author.OLLAMA, response))
self.messages.append((id, Author.OLLAMA, response, []))
self.images = []

except asyncio.CancelledError:
user_chat_item.remove()
response_chat_item.remove()
Expand Down Expand Up @@ -230,14 +237,18 @@ async def action_edit_chat(self) -> None:
)

# load the history from messages
history: list[Message] = [
(
Message(role="user", content=message)
if author == Author.USER
else Message(role="assistant", content=message)
history: list[Message] = []
# This is wrong, the images should be a list of Image objects
# See https://github.com/ollama/ollama-python/issues/375
# Temp fix is to do msg.images = images # type: ignore
for _, author, message, images in self.messages:
msg = Message(
role="user" if author == Author.USER else "assistant",
content=message,
)
for _, author, message in self.messages
]
msg.images = images # type: ignore
history.append(msg)

used_tool_defs = [
tool_def
for tool_def in available_tool_defs
Expand All @@ -250,7 +261,7 @@ async def action_edit_chat(self) -> None:
format=model["format"],
options=self.parameters,
keep_alive=model["keep_alive"],
history=history,
history=history, # type: ignore
tool_defs=used_tool_defs,
)

Expand Down Expand Up @@ -290,14 +301,13 @@ async def response_task() -> None:
response = ""
async for text in self.ollama.stream(
message,
self.images,
[img for _, img in self.images],
Options(seed=random.randint(0, 32768)),
):
response = text
response_chat_item.text = text
if message_container.can_view_partial(response_chat_item):
message_container.scroll_end()
self.images = []

# Save to db
store = await Store.get_store()
Expand All @@ -307,7 +317,9 @@ async def response_task() -> None:
author=Author.OLLAMA.value,
text=response,
)
self.messages.append((response_message_id, Author.OLLAMA, response))
self.messages.append((response_message_id, Author.OLLAMA, response, []))
self.images = []

loading.remove()

asyncio.create_task(response_task())
Expand All @@ -323,7 +335,7 @@ def on_history_selected(text: str | None) -> None:
prompt.focus()

prompts = [
message for _, author, message in self.messages if author == Author.USER
message for _, author, message, _ in self.messages if author == Author.USER
]
prompts.reverse()
screen = PromptHistory(prompts)
Expand Down
32 changes: 25 additions & 7 deletions src/oterm/store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ async def get_store(cls) -> "Store":
"chat_id" INTEGER NOT NULL,
"author" TEXT NOT NULL,
"text" TEXT NOT NULL,
"images" TEXT DEFAULT "[]",
PRIMARY KEY("id" AUTOINCREMENT)
FOREIGN KEY("chat_id") REFERENCES "chat"("id") ON DELETE CASCADE
);
Expand Down Expand Up @@ -213,34 +214,51 @@ async def get_chat(

async def delete_chat(self, id: int) -> None:
async with aiosqlite.connect(self.db_path) as connection:
await connection.execute("PRAGMA foreign_keys = on;")
await connection.execute("DELETE FROM chat WHERE id = :id;", {"id": id})
await connection.commit()

async def save_message(
self, id: int | None, chat_id: int, author: str, text: str
self,
id: int | None,
chat_id: int,
author: str,
text: str,
images: list[str] = [],
) -> int:
async with aiosqlite.connect(self.db_path) as connection:
res = await connection.execute_insert(
"""
INSERT OR REPLACE
INTO message(id, chat_id, author, text)
VALUES(:id, :chat_id, :author, :text) RETURNING id;
INTO message(id, chat_id, author, text, images)
VALUES(:id, :chat_id, :author, :text, :images) RETURNING id;
""",
{"id": id, "chat_id": chat_id, "author": author, "text": text},
{
"id": id,
"chat_id": chat_id,
"author": author,
"text": text,
"images": json.dumps(images),
},
)
await connection.commit()
return res[0] if res else 0

async def get_messages(self, chat_id: int) -> list[tuple[int, Author, str]]:
async def get_messages(
self, chat_id: int
) -> list[tuple[int, Author, str, list[str]]]:

async with aiosqlite.connect(self.db_path) as connection:
messages = await connection.execute_fetchall(
"""
SELECT id, author, text
SELECT id, author, text, images
FROM message
WHERE chat_id = :chat_id;
""",
{"chat_id": chat_id},
)
messages = [(id, Author(author), text) for id, author, text in messages]
messages = [
(id, Author(author), text, json.loads(images))
for id, author, text, images in messages
]
return messages
2 changes: 2 additions & 0 deletions src/oterm/store/upgrades/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from oterm.store.upgrades.v0_4_0 import upgrades as v0_4_0_upgrades
from oterm.store.upgrades.v0_5_1 import upgrades as v0_5_1_upgrades
from oterm.store.upgrades.v0_6_0 import upgrades as v0_6_0_upgrades
from oterm.store.upgrades.v0_7_0 import upgrades as v0_7_0_upgrades

upgrades = (
v0_1_6_upgrades
Expand All @@ -18,4 +19,5 @@
+ v0_4_0_upgrades
+ v0_5_1_upgrades
+ v0_6_0_upgrades
+ v0_7_0_upgrades
)
37 changes: 37 additions & 0 deletions src/oterm/store/upgrades/v0_7_0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pathlib import Path
from typing import Awaitable, Callable

import aiosqlite


async def images(db_path: Path) -> None:
async with aiosqlite.connect(db_path) as connection:
try:
await connection.executescript(
"""
ALTER TABLE message ADD COLUMN images TEXT DEFAULT "[]";
"""
)
except aiosqlite.OperationalError:
pass

await connection.commit()


async def orphan_messages(db_path: Path) -> None:
async with aiosqlite.connect(db_path) as connection:
try:
await connection.executescript(
"""
DELETE FROM message WHERE chat_id NOT IN (SELECT id FROM chat);
"""
)
except aiosqlite.OperationalError:
pass

await connection.commit()


upgrades: list[tuple[str, list[Callable[[Path], Awaitable[None]]]]] = [
("0.7.0", [images, orphan_messages]),
]
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 207fc55

Please sign in to comment.