From 7b8ae43d414ec9a6cd82209ef89057254a6d8612 Mon Sep 17 00:00:00 2001 From: jonny <32085184+jonnyjohnson1@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:09:28 -0600 Subject: [PATCH 1/6] migrate initial server over --- topos/chat_api/server.py | 440 ++++++++++-------- topos/services/messages/__init__.py | 0 .../messages/group_management_service.py | 49 ++ .../services/messages/group_manager_sqlite.py | 247 ++++++++++ .../messages/missed_message_manager.py | 64 +++ .../messages/missed_message_service.py | 22 + topos/utilities/utils.py | 23 + 7 files changed, 654 insertions(+), 191 deletions(-) create mode 100644 topos/services/messages/__init__.py create mode 100644 topos/services/messages/group_management_service.py create mode 100644 topos/services/messages/group_manager_sqlite.py create mode 100644 topos/services/messages/missed_message_manager.py create mode 100644 topos/services/messages/missed_message_service.py diff --git a/topos/chat_api/server.py b/topos/chat_api/server.py index 1424e4f..b740501 100644 --- a/topos/chat_api/server.py +++ b/topos/chat_api/server.py @@ -1,233 +1,291 @@ -from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends -from fastapi.middleware.cors import CORSMiddleware -from starlette.websockets import WebSocketState -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Tuple -import uvicorn -import json import asyncio import datetime -from ..utilities.utils import generate_deci_code +import json +from typing import Dict, List +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from aiokafka import AIOKafkaProducer, AIOKafkaConsumer +from fastapi.concurrency import asynccontextmanager +from services.group_management_service import GroupManagementService +from services.missed_message_service import MissedMessageService +from services.startup.initialize_database import init_sqlite_database,ensure_file_exists +from utils.utils import generate_deci_code, generate_group_name +from pydantic import BaseModel +# MissedMessageRequest model // subject to change +class MissedMessagesRequest(BaseModel): + user_id: str + # last_sync_time: datetime -app = FastAPI() +# Kafka configuration +KAFKA_BOOTSTRAP_SERVERS = 'localhost:9092' +KAFKA_TOPIC = 'chat_topic' -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Change this to the specific origins you want to allow - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -class SessionManager: +# WebSocket connection manager +class ConnectionManager: def __init__(self): - self.active_sessions: Dict[str, List[Tuple[str, WebSocket]]] = {} - self.user_sessions: Dict[str, str] = {} - self.usernames: Dict[str, str] = {} + self.active_connections: Dict[WebSocket, str] = {} + + async def connect(self, websocket: WebSocket): + await websocket.accept() + + def register(self, websocket: WebSocket,user_id:str): + self.active_connections[websocket] = user_id - def add_session(self, session_id: str, user_id: str, websocket: WebSocket): - print(f"[ adding {session_id} to active_sessions ]") - if session_id not in self.active_sessions: - self.active_sessions[session_id] = [] - self.active_sessions[session_id].append((user_id, websocket)) + async def disconnect(self, websocket: WebSocket): + user_id = self.active_connections[websocket] + user = group_management_service.get_user_by_id(user_id) + username = user['username'] + group_management_service.set_user_last_seen_online(user_id) + users_list_of_groups = group_management_service.get_user_groups(user_id) + 
for group in users_list_of_groups: + + disconnect_message = f"{username} left the chat" + message = { + "message_id": generate_deci_code(16), + "message_type": "server", + "username": username, + "from_user_id":user_id, + "session_id": group["group_id"], + "message": disconnect_message, + "timestamp": datetime.datetime.now().isoformat() + } + await producer.send_and_wait(KAFKA_TOPIC, key=group["group_id"].encode('utf-8'),value=json.dumps(message).encode('utf-8')) + print(f"removing {user_id}") + try: + del self.active_connections[websocket] + print("Successfully Disconnected") + except: + print("Disconnect Failed. Error encountered. Check Logs") - def get_active_sessions(self): - return self.active_sessions + async def broadcast(self, from_user_id:str,message, group_id:str):# + print(message) + if(group_management_service.get_group_by_id(group_id=group_id)): + print(f"{group_id} exists" ) + group_users_info = group_management_service.get_group_users(group_id=group_id) + group_user_ids = [user['user_id'] for user in group_users_info] + print(group_users_info) + for connection,user_id in self.active_connections.items(): + print(f"Testing {user_id}") + if(user_id in group_user_ids): + print(f"{user_id} is in {group_id}") + print(f"sending to {user_id}") + print(f"connection state is {connection.application_state}") + print(f"Sending message: {message}") + await connection.send_json(message) + print("next connection") + - def get_user_sessions(self): - return self.user_sessions +group_management_service = GroupManagementService() +manager = ConnectionManager() - def add_user_session(self, user_id: str, session_id: str): - print(f"[ adding {user_id} to user_sessions ]") - self.user_sessions[user_id] = session_id +producer = None +consumer = None - def add_username(self, user_id: str, username: str): - print(f"[ adding {username} for {user_id} ]") - self.usernames[user_id] = username + +async def consume_messages(): + async for msg in consumer: + # print(msg.offset) + message = json.loads(msg.value.decode('utf-8')) + group_id = msg.key.decode('utf-8') + await manager.broadcast(message=message,from_user_id=message["from_user_id"],group_id=group_id) + - def get_username(self, user_id: str) -> str: - return self.usernames.get(user_id, "Unknown user") +@asynccontextmanager +async def lifespan(app: FastAPI): + # Load the ML model + # Kafka producer + global producer + global consumer + ensure_file_exists("../db/user.db") + init_sqlite_database("../db/user.db") + producer = AIOKafkaProducer(bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS) -session_manager = SessionManager() + # Kafka consumer + consumer = AIOKafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS, + # group_id="chat_group" + ) -async def send_message_to_client(client: WebSocket, message: dict): - if not isinstance(message, dict): - print("Message is not a dictionary") - return + await producer.start() + await consumer.start() + # https://stackoverflow.com/questions/46890646/asyncio-weirdness-of-task-exception-was-never-retrieved + # we need to keep a reference of this task alive else it will stop the consume task, there has to be a live refference for this to work + consume_task = asyncio.create_task(consume_messages()) + yield + # Clean up the ML models and release the resources + consume_task.cancel() + await producer.stop() + await consumer.stop() - if not client.application_state == WebSocketState.CONNECTED: - print("Client is not connected") - return +# FastAPI app +app = FastAPI(lifespan=lifespan) - try: - await 
client.send_json(message) - except Exception as e: - print(e) - -async def send_message_to_all(session_id: str, sender_user_id: str, message: dict, session_manager: SessionManager): - active_sessions = session_manager.get_active_sessions() - print("send_message_to_all") - print(session_id in active_sessions) - if message['message_type'] != 'server': - print(f"[ message to user :: {message['content']['text']}]") - if session_id in active_sessions: - for user_id, client in active_sessions[session_id]: - if message['message_type'] == 'server': - await send_message_to_client(client, message) - elif user_id != sender_user_id: - await send_message_to_client(client, message) - -async def send_to_all_clients_on_all_sessions(sender_user_id: str, message: dict, session_manager: SessionManager): - active_sessions = session_manager.get_active_sessions() - print("send_message_to_all") - if message['message_type'] != 'server': - print(f"[ message to user :: {message['content']['text']}]") - for session_id in active_sessions: - message["session_id"] = session_id - for user_id, client in active_sessions[session_id]: - if message['message_type'] == 'server': - await send_message_to_client(client, message) - elif user_id != sender_user_id: - await send_message_to_client(client, message) - -async def handle_client(websocket: WebSocket, session_manager: SessionManager, inactivity_event: asyncio.Event): - await websocket.accept() - print("client joined") +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await manager.connect(websocket) + print("client joined") # think about what needs to be done when the client joins like missed message services etc try: while True: - data = await asyncio.wait_for(websocket.receive_text(), timeout=600.0) # removes user if they haven't spoken in 10 minutes + data = await websocket.receive_text() if data: payload = json.loads(data) - inactivity_event.set() # Reset the inactivity event print(payload) + # get user details and associate with ws? 
we a separate client message to declare its identity after a potential disconnect, 1 time after + manager.register(websocket,payload['user_id']) + if(group_management_service.get_user_by_id(payload["user_id"])== None): + group_management_service.create_user(username=payload["username"],user_id=payload["user_id"]) message_type = payload['message_type'] print(message_type) - active_sessions = session_manager.get_active_sessions() - user_sessions = session_manager.get_user_sessions() - + group_management_service.set_user_last_seen_online(payload['user_id']) if message_type == "create_server": - session_id = generate_deci_code(6) - print(f"[ client created chat :: session_id {session_id} ]") - user_id = payload['user_id'] - host_name = payload['host_name'] - username = payload['username'] - session_manager.add_session(session_id, user_id, websocket) - session_manager.add_user_session(user_id, session_id) - session_manager.add_username(user_id, username) - print(session_manager.get_active_sessions()) # shows value - active_sessions = session_manager.get_active_sessions() - - prompt_message = f"{host_name} created the chat" - data = { + group_name = generate_group_name() # you can create group name on the frontend , this is just a basic util that can be swapped out if needed + group_id = group_management_service.create_group(group_name=group_name) + group_management_service.add_user_to_group(payload["user_id"],group_id=group_id) + print(f"[ client created chat :: group : {"group_name " + group_name + " : gid:"+ group_id} ]") + prompt_message = f"{payload["username"]} created the chat" + message = { "message_type": "server", - "session_id": session_id, + "from_user_id": payload["user_id"], + "session_id": group_id, "message": prompt_message, "timestamp": datetime.datetime.now().isoformat() } - await send_message_to_all(session_id, user_id, data, session_manager) - + await producer.send_and_wait(KAFKA_TOPIC, key=group_id.encode('utf-8'),value=json.dumps(message).encode('utf-8')) elif message_type == "join_server": - session_id = payload['session_id'] + group_id = payload['session_id'] user_id = payload['user_id'] username = payload['username'] - active_sessions = session_manager.get_active_sessions() - print(session_id) - print("ACTIVE SESSIONS: ", session_manager.get_active_sessions()) - print("ACTIVE SESSIONS: ", active_sessions) # shows empty when client connects - print(session_id in active_sessions) - if session_id in active_sessions: - print(f"[ {username} joined chat :: session_id {session_id} ]") - session_manager.add_session(session_id, user_id, websocket) - session_manager.add_user_session(user_id, session_id) - session_manager.add_username(user_id, username) + # see if session exists + if(group_management_service.get_group_by_id(group_id=group_id) == None): + await websocket.send_json({"error": "Invalid session"}) + else: + group_management_service.add_user_to_group(user_id=user_id,group_id=group_id) join_message = f"{username} joined the chat" - data = { - "message_type": "server", - "session_id": session_id, - "message": join_message, - "timestamp": datetime.datetime.now().isoformat() - } - await send_message_to_all(session_id, user_id, data, session_manager) + print("Hells bells") + print(join_message) + message = { + "message_type": "server", + "from_user_id": payload["user_id"], + "session_id": group_id, + "message": join_message, + "timestamp": datetime.datetime.now().isoformat() + } + await producer.send_and_wait(KAFKA_TOPIC, key=group_id.encode('utf-8'),value= 
json.dumps(message).encode('utf-8')) + else: + print("RECEIVED: ", payload) + group_id = payload['session_id'] + user_id = payload['user_id'] + message_id = payload['message_id'] # generate_deci_code(16) + user = group_management_service.get_user_by_id(user_id) + message = { + "message_type": "user", + "message_id": message_id, + "from_user_id": user_id, + "username": user['username'], + "session_id": group_id, + "message": payload["content"]["text"], + "timestamp": datetime.datetime.now().isoformat() + } + if (group_management_service.get_group_by_id(group_id=group_id)): + print(f"sending {group_id}") + await producer.send_and_wait(KAFKA_TOPIC, key=group_id.encode('utf-8'),value=json.dumps(message).encode('utf-8')) else: - await websocket.send_json({"error": "Invalid session ID"}) - break - while True: - data = await websocket.receive_text() - if data: - payload = json.loads(data) - inactivity_event.set() # Reset the inactivity event - print("RECEIVED: ", payload) - session_id = payload['content']['session_id'] - user_id = payload['content']['user_id'] - if session_id: - print(f"sending {session_id}") - await send_message_to_all(session_id, user_id, payload, session_manager) - else: - print(f"[ Message from client is empty ]") + await websocket.send_json({"error": "Invalid session"}) + except WebSocketDisconnect: - print("client disconnected") - await handle_disconnect(websocket, session_manager) + await manager.disconnect(websocket) except asyncio.TimeoutError: print("client disconnected due to timeout") - await handle_disconnect(websocket, session_manager) + await manager.disconnect(websocket) except Exception as e: print(f"client disconnected due to error: {e}") - await handle_disconnect(websocket, session_manager) - -async def handle_disconnect(websocket, session_manager): - active_sessions = session_manager.get_active_sessions() - user_sessions = session_manager.get_user_sessions() - for session_id, clients in active_sessions.items(): - for user_id, client in clients: - if client == websocket: - clients.remove((user_id, client)) - if not clients: - del active_sessions[session_id] - username = session_manager.get_username(user_id) - disconnect_message = f"{username} left the chat" - await asyncio.shield(send_message_to_all(session_id, user_id, { - "message_type": "server", - "session_id": session_id, - "message": disconnect_message, - "timestamp": datetime.datetime.now().isoformat() - }, session_manager)) - break - user_sessions.pop(user_id, None) - session_manager.usernames.pop(user_id, None) - -@app.websocket("/ws/chat") -async def websocket_endpoint(websocket: WebSocket): - print("[ client connected :: preparing setup ]") - print(f" current connected sessions :: {session_manager.get_active_sessions()}") - inactivity_event = asyncio.Event() - # inactivity_task = asyncio.create_task(check_inactivity(inactivity_event)) # not applicable for local builds - await handle_client(websocket, session_manager, inactivity_event) - # inactivity_task.cancel() - -async def check_inactivity(inactivity_event: asyncio.Event): - while True: - try: - await asyncio.wait_for(inactivity_event.wait(), timeout=600.0) - inactivity_event.clear() - except asyncio.TimeoutError: - print("No activity detected for 10 minutes, shutting down...") - disconnect_message = f"Conserving power...shutting down..." 
-            await asyncio.shield(send_to_all_clients_on_all_sessions("senderUSErID#45",
-                {
-                    "message_type": "server",
-                    "message": disconnect_message,
-                    "timestamp": datetime.datetime.now().isoformat()
-                }, session_manager))
-            asyncio.get_event_loop().stop()
-
-# perform healthcheck for GCP requirement
-@app.get("/healthcheck/")
+        await manager.disconnect(websocket)
+
+
+@app.get("/")
 async def root():
-    return {"message": "Status: OK"}
+    return {"message": "Welcome to the FastAPI aiokafka WebSocket Chat Server"}

-@app.post("/test")
-async def test():
-    return {"response": True}
+@app.post("/chat/missed-messages")
+async def get_missed_messages(request: MissedMessagesRequest):
+    # get the user id and then pass it to the missed message service and invoke it
+    missed_message_service = MissedMessageService()
+    return await missed_message_service.get_missed_messages(user_id=request.user_id,group_management_service=group_management_service)

 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
\ No newline at end of file
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=13394)
+
+
+""" Message JSON Schema
+
+{
+"message_id": "",
+"message_type": "",        // OPTIONS: user, ai, server
+"num_participants": ,
+"content":
+    {
+        "sender_id": "",
+        "conversation_id": "",
+        "username": "",
+        "text": ""
+    },
+"timestamp": "",
+"metadata": {
+    "priority": "",        // e.g., normal, high
+    "tags": ["", ""]
+    },
+"attachments": [
+    {
+        "file_name": "",
+        "file_type": "",
+        "url": ""
+    }
+] }
+
+"""
+"""
+create server message
+{
+    "message_type": "create_server",
+    "num_participants": "5",
+    "host_name": "anshul",
+    "user_id": "1",
+    "created_at": "t0",
+    "username": "anshul"
+}
+"""
+
+"""
+Revising the message format
+{
+    "message_id": "69",
+    "message_type": "user",
+    "user_id": "2",
+    "username": "jonny",
+    "session_id":"961198",
+    "content":
+        {
+            "metadata": {
+                "priority": "",
+                "tags": ["", ""]
+            },
+            "attachments": [
+                {
+                    "file_name": "",
+                    "file_type": "",
+                    "url": ""
+                }
+            ],
+            "text": "kafka chatserver works"
+        },
+    "timestamp": "t5"
+}
+"""
+
+"""
+Notes:
+We do not need to pass along information such as the username; what travels with a message should probably be a display name associated with a specific group.
+Right now it is being treated as a fixed username rather than a per-group display name.
+"""
\ No newline at end of file
diff --git a/topos/services/messages/__init__.py b/topos/services/messages/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/topos/services/messages/group_management_service.py b/topos/services/messages/group_management_service.py
new file mode 100644
index 0000000..77fdda8
--- /dev/null
+++ b/topos/services/messages/group_management_service.py
@@ -0,0 +1,49 @@
+from typing import List, Optional
+from managers.group_manager_sqlite import GroupManagerSQLite
+class GroupManagementService:
+    def __init__(self) -> None:
+        self.group_manager = GroupManagerSQLite() # this implementation can be swapped for oother implementations out based on env var, use if statements
+        # any other house keeping can be done here too
+
+    def create_group(self, group_name: str) -> str:
+        return self.group_manager.create_group(group_name=group_name)
+
+    def create_user(self, user_id:str,username: str) -> str:
+        return self.group_manager.create_user(user_id,username)
+
+    def add_user_to_group(self, user_id: str, group_id: str) -> bool:
+        return self.group_manager.add_user_to_group(user_id=user_id,group_id=group_id)
+
+    def remove_user_from_group(self, user_id: str, group_id: str) -> bool:
return self.group_manager.remove_user_from_group(user_id=user_id,group_id=group_id) + + def get_user_groups(self, user_id: str) -> List[dict]: + return self.group_manager.get_user_groups(user_id) + + def get_group_users(self, group_id: str) -> List[dict]: + return self.group_manager.get_group_users(group_id) + + def get_group_by_id(self, group_id: str) -> Optional[dict]: + return self.group_manager.get_group_by_id(group_id) + + def get_user_by_id(self, user_id: str) -> Optional[dict]: + return self.group_manager.get_user_by_id(user_id) + + def get_group_by_name(self, group_name: str) -> Optional[dict]: + return self.group_manager.get_group_by_name(group_name) + + def get_user_by_username(self, username: str) -> Optional[dict]: + return self.get_user_by_username(username) + + def delete_group(self, group_id: str) -> bool: + return self.group_manager.delete_group(group_id) + + def delete_user(self, user_id: str) -> bool: + return self.group_manager.delete_user(user_id) + + def set_user_last_seen_online(self,user_id:str)-> bool: + return self.group_manager.set_user_last_seen_online(user_id) + + def get_user_last_seen_online(self,user_id:str)-> bool: + return self.group_manager.get_user_last_seen_online(user_id) + diff --git a/topos/services/messages/group_manager_sqlite.py b/topos/services/messages/group_manager_sqlite.py new file mode 100644 index 0000000..99023ca --- /dev/null +++ b/topos/services/messages/group_manager_sqlite.py @@ -0,0 +1,247 @@ +from datetime import datetime +import sqlite3 +import uuid +from typing import List, Optional, Dict +from utils.utils import generate_deci_code +class GroupManagerSQLite: + def __init__(self, db_file: str = '../db/user.db'): + self.db_file = db_file + + # Initialize empty caches + self.groups_cache: Dict[str, Dict] = {} # group_id -> group_info + self.users_cache: Dict[str, Dict] = {} # user_id -> user_info + self.user_groups_cache: Dict[str, List[str]] = {} # user_id -> list of group_ids + self.group_users_cache: Dict[str, List[str]] = {} # group_id -> list of user_ids + + def _get_group_from_db(self, group_id: str) -> Optional[Dict]: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('SELECT group_id, group_name FROM groups WHERE group_id = ?', (group_id,)) + result = cursor.fetchone() + print(result) + if result: + return {"group_id": result[0], "group_name": result[1]} + return None + + def _get_user_from_db(self, user_id: str) -> Optional[Dict]: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('SELECT user_id, username FROM users WHERE user_id = ?', (user_id,)) + result = cursor.fetchone() + if result: + return {"user_id": result[0], "username": result[1]} + return None + + def _get_user_groups_from_db(self, user_id: str) -> List[str]: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('SELECT group_id FROM user_groups WHERE user_id = ?', (user_id,)) + return [row[0] for row in cursor.fetchall()] + + def _get_group_users_from_db(self, group_id: str) -> List[str]: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('SELECT user_id FROM user_groups WHERE group_id = ?', (group_id,)) + return [row[0] for row in cursor.fetchall()] + + def create_group(self, group_name: str) -> str: + group_id = generate_deci_code(6) + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('INSERT INTO groups (group_id, group_name) VALUES (?, ?)', (group_id, group_name)) + conn.commit() 
+ + # Update cache + self.groups_cache[group_id] = {"group_id": group_id, "group_name": group_name} + self.group_users_cache[group_id] = [] + + return group_id + + def create_user(self, user_id:str,username: str,) -> str: + + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('INSERT INTO users (user_id, username) VALUES (?, ?)', (user_id, username)) + conn.commit() + + # Update cache + self.users_cache[user_id] = {"user_id": user_id, "username": username} + self.user_groups_cache[user_id] = [] + + return user_id + + def add_user_to_group(self, user_id: str, group_id: str) -> bool: + try: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('INSERT INTO user_groups (user_id, group_id) VALUES (?, ?)', (user_id, group_id)) + conn.commit() + + # Update cache if the entries exist + if user_id in self.user_groups_cache: + self.user_groups_cache[user_id].append(group_id) + if group_id in self.group_users_cache: + self.group_users_cache[group_id].append(user_id) + + return True + except sqlite3.IntegrityError: + return False # User already in group or user/group doesn't exist + + def remove_user_from_group(self, user_id: str, group_id: str) -> bool: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM user_groups WHERE user_id = ? AND group_id = ?', (user_id, group_id)) + conn.commit() + if cursor.rowcount > 0: + # Update cache if the entries exist + if user_id in self.user_groups_cache: + self.user_groups_cache[user_id].remove(group_id) + if group_id in self.group_users_cache: + self.group_users_cache[group_id].remove(user_id) + return True + return False + + def get_user_groups(self, user_id: str) -> List[dict]: + if user_id not in self.user_groups_cache: + self.user_groups_cache[user_id] = self._get_user_groups_from_db(user_id) + + return [self.get_group_by_id(group_id) for group_id in self.user_groups_cache[user_id]] + + def get_group_users(self, group_id: str) -> List[dict]: + if group_id not in self.group_users_cache: + self.group_users_cache[group_id] = self._get_group_users_from_db(group_id) + + return [self.get_user_by_id(user_id) for user_id in self.group_users_cache[group_id]] + + def get_group_by_id(self, group_id: str) -> Optional[dict]: + if group_id not in self.groups_cache: + group = self._get_group_from_db(group_id) + if group: + self.groups_cache[group_id] = group + else: + return None + return self.groups_cache[group_id] + + def get_user_by_id(self, user_id: str) -> Optional[dict]: + if user_id not in self.users_cache: + user = self._get_user_from_db(user_id) + if user: + self.users_cache[user_id] = user + else: + return None + return self.users_cache[user_id] + + def get_group_by_name(self, group_name: str) -> Optional[dict]: + # This operation requires a full DB scan if not in cache + for group in self.groups_cache.values(): + if group['group_name'] == group_name: + return group + + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('SELECT group_id, group_name FROM groups WHERE group_name = ?', (group_name,)) + result = cursor.fetchone() + if result: + group = {"group_id": result[0], "group_name": result[1]} + self.groups_cache[group['group_id']] = group + return group + return None + + def get_user_by_username(self, username: str) -> Optional[dict]: + # This operation requires a full DB scan if not in cache + for user in self.users_cache.values(): + if user['username'] == username: + return user + + with 
sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('SELECT user_id, username FROM users WHERE username = ?', (username,)) + result = cursor.fetchone() + if result: + user = {"user_id": result[0], "username": result[1]} + self.users_cache[user['user_id']] = user + return user + return None + + def delete_group(self, group_id: str) -> bool: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM user_groups WHERE group_id = ?', (group_id,)) + cursor.execute('DELETE FROM groups WHERE group_id = ?', (group_id,)) + conn.commit() + if cursor.rowcount > 0: + # Update cache + self.groups_cache.pop(group_id, None) + self.group_users_cache.pop(group_id, None) + for user_groups in self.user_groups_cache.values(): + if group_id in user_groups: + user_groups.remove(group_id) + return True + return False + + def delete_user(self, user_id: str) -> bool: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM user_groups WHERE user_id = ?', (user_id,)) + cursor.execute('DELETE FROM users WHERE user_id = ?', (user_id,)) + conn.commit() + if cursor.rowcount > 0: + # Update cache + self.users_cache.pop(user_id, None) + self.user_groups_cache.pop(user_id, None) + for group_users in self.group_users_cache.values(): + if user_id in group_users: + group_users.remove(user_id) + return True + return False + def get_user_last_seen_online(self, user_id: str) -> str: + """ + Get the last_seen_online timestamp for a given user_id. + + :param user_id: The ID of the user + :return: The last seen timestamp as a string, or None if user not found + """ + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + try: + cursor.execute(''' + SELECT last_seen_online + FROM users + WHERE user_id = ? + ''', (user_id,)) + + result = cursor.fetchone() + if result: + return result[0] + else: + print(f"User with ID {user_id} not found.") + return None + except sqlite3.Error as e: + print(f"An error occurred: {e}") + return None + + def set_user_last_seen_online(self, user_id: str) -> bool: + """ + Set the last_seen_online timestamp for a given user_id to the current time. + + :param user_id: The ID of the user + :return: True if successful, False if user not found + """ + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + try: + cursor.execute(''' + UPDATE users + SET last_seen_online = ? + WHERE user_id = ? 
+                ''', (datetime.now().replace(microsecond=0), user_id))
+
+                if cursor.rowcount == 0:
+                    print(f"User with ID {user_id} not found.")
+                    return False
+
+                return True
+            except sqlite3.Error as e:
+                print(f"An error occurred: {e}")
+                return False
+
diff --git a/topos/services/messages/missed_message_manager.py b/topos/services/messages/missed_message_manager.py
new file mode 100644
index 0000000..da32846
--- /dev/null
+++ b/topos/services/messages/missed_message_manager.py
@@ -0,0 +1,64 @@
+import asyncio
+import json
+from aiokafka import AIOKafkaConsumer, TopicPartition
+from typing import List, Set, Dict, Any
+
+KAFKA_BOOTSTRAP_SERVERS = 'localhost:9092'
+KAFKA_TOPIC = 'chat_topic'
+class MissedMessageManager:
+
+    async def get_filtered_missed_messages(self,
+                                           timestamp_ms: int,
+                                           key_filter: Set[str]
+                                           # max_messages: int = 1000
+                                           ) -> List[Dict[str, str]]:
+        consumer = AIOKafkaConsumer(
+            KAFKA_TOPIC,
+            bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
+            group_id=None,  # Set to None to avoid committing offsets
+            auto_offset_reset='earliest'
+        )
+
+        try:
+            await consumer.start()
+
+            # Get partitions for the topic
+            partitions = consumer.partitions_for_topic(KAFKA_TOPIC)
+            if not partitions:
+                raise ValueError(f"Topic '{KAFKA_TOPIC}' not found")
+
+            # Create TopicPartition objects
+            tps = [TopicPartition(KAFKA_TOPIC, p) for p in partitions]
+
+            # Find offsets for the given timestamp
+            offsets = await consumer.offsets_for_times({tp: timestamp_ms for tp in tps})
+            print(offsets)
+            # Seek to the correct offset for each partition
+            for tp, offset_and_timestamp in offsets.items():
+                if offset_and_timestamp is None:
+                    # If no offset found for the timestamp, seek to the end
+                    consumer.seek_to_end(tp)
+                else:
+                    print(tp)
+                    print(offset_and_timestamp.offset)
+                    consumer.seek(tp, offset_and_timestamp.offset)
+
+            # Collect filtered messages
+            missed_messages = []
+            while True:
+                try:
+                    message = await asyncio.wait_for(consumer.getone(), timeout=1.0)
+                    if message.key and message.key.decode() in key_filter:
+                        missed_messages.append({
+                            "key": message.key.decode(),
+                            "value": json.loads(message.value.decode()),
+                            "msg_type": "MISSED"
+                        })
+                except asyncio.TimeoutError:
+                    # No more messages within the timeout period
+                    break
+
+            return missed_messages
+
+        finally:
+            await consumer.stop()
diff --git a/topos/services/messages/missed_message_service.py b/topos/services/messages/missed_message_service.py
new file mode 100644
index 0000000..85f07a5
--- /dev/null
+++ b/topos/services/messages/missed_message_service.py
@@ -0,0 +1,22 @@
+from chat_server.managers.missed_message_manager import MissedMessageManager
+from chat_server.services.group_management_service import GroupManagementService
+from chat_server.utils.utils import sqlite_timestamp_to_ms
+
+KAFKA_TOPIC = 'chat_topic'
+
+class MissedMessageService:
+    def __init__(self) -> None:
+        self.missed_message_manager = MissedMessageManager()
+        pass
+        # housekeeping if required
+        # if you need to inject the group management service here it could be an option ??
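# Illustrative aside (a sketch, not part of the patch): how the replay pieces
# above fit together. The service converts the caller's last_seen_online value
# to epoch milliseconds and has MissedMessageManager re-read the Kafka topic
# from that point, keeping only messages keyed by one of the caller's group
# ids. It assumes this module's imports; "user-123" is a made-up id.
import asyncio

async def _demo_missed_messages() -> None:
    service = MissedMessageService()
    groups = GroupManagementService()   # the same service the chat server holds
    missed = await service.get_missed_messages(
        user_id="user-123",
        group_management_service=groups,
    )
    for item in missed:
        # each entry carries the group id (Kafka key) and the decoded message payload
        print(item["key"], item["value"].get("message"))

if __name__ == "__main__":
    asyncio.run(_demo_missed_messages())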
+
+    async def get_missed_messages(self,user_id :str ,group_management_service :GroupManagementService):
+        last_seen = group_management_service.get_user_last_seen_online(user_id=user_id)
+        if(last_seen):
+            users_groups = group_management_service.get_user_groups(user_id=user_id)
+            group_ids = [group["group_id"] for group in users_groups]
+            # get the last timestamp msg processed by the user
+            return await self.missed_message_manager.get_filtered_missed_messages(key_filter=group_ids,timestamp_ms=sqlite_timestamp_to_ms(last_seen))
+        else:
+            return []
\ No newline at end of file
diff --git a/topos/utilities/utils.py b/topos/utilities/utils.py
index e9df92b..3df6870 100644
--- a/topos/utilities/utils.py
+++ b/topos/utilities/utils.py
@@ -3,6 +3,9 @@
 import os
 import shutil

+from datetime import datetime
+import string
+
 def get_python_command():
     if shutil.which("python"):
         return "python"
@@ -70,3 +73,23 @@ def generate_hex_code(n_digits):

 def generate_deci_code(n_digits):
     return ''.join(random.choice('0123456789') for _ in range(n_digits))
+
+def generate_group_name() -> str:
+    return 'GRP-' + ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
+
+def sqlite_timestamp_to_ms(sqlite_timestamp: str) -> int:
+    """
+    Convert a SQLite timestamp string to milliseconds since Unix epoch.
+
+    :param sqlite_timestamp: A timestamp string in the format "YYYY-MM-DD HH:MM:SS"
+    :return: Milliseconds since Unix epoch
+    """
+    try:
+        # Parse the SQLite timestamp string
+        dt = datetime.strptime(sqlite_timestamp, "%Y-%m-%d %H:%M:%S")
+
+        # Convert to milliseconds since Unix epoch
+        return int(dt.timestamp() * 1000)
+    except ValueError as e:
+        print(f"Error parsing timestamp: {e}")
+        return None
\ No newline at end of file

From db7cf1b5c42a13f120b70a9505b13e4672ae1e24 Mon Sep 17 00:00:00 2001
From: jonny <32085184+jonnyjohnson1@users.noreply.github.com>
Date: Mon, 4 Nov 2024 16:22:53 -0600
Subject: [PATCH 2/6] migrate sql -> postgres

---
 topos/chat_api/server.py                      |  25 +-
 .../messages/group_management_service.py      |  29 +-
 topos/services/messages/group_manager.py      | 117 +++++++++
 .../services/messages/group_manager_sqlite.py | 247 ------------------
 .../messages/missed_message_service.py        |   6 +-
 5 files changed, 150 insertions(+), 274 deletions(-)
 create mode 100644 topos/services/messages/group_manager.py
 delete mode 100644 topos/services/messages/group_manager_sqlite.py

diff --git a/topos/chat_api/server.py b/topos/chat_api/server.py
index b740501..58622d7 100644
--- a/topos/chat_api/server.py
+++ b/topos/chat_api/server.py
@@ -1,14 +1,15 @@
 import asyncio
 import datetime
 import json
+import os
+
 from typing import Dict, List
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from aiokafka import AIOKafkaProducer, AIOKafkaConsumer
 from fastapi.concurrency import asynccontextmanager
-from services.group_management_service import GroupManagementService
-from services.missed_message_service import MissedMessageService
-from services.startup.initialize_database import init_sqlite_database,ensure_file_exists
-from utils.utils import generate_deci_code, generate_group_name
+from topos.services.messages.group_management_service import GroupManagementService
+from topos.services.messages.missed_message_service import MissedMessageService
+from topos.utilities.utils import generate_deci_code, generate_group_name
 from pydantic import BaseModel
 # MissedMessageRequest model // subject to change
 class MissedMessagesRequest(BaseModel):
@@ -73,15 +74,22 @@ async def broadcast(self, from_user_id:str,message,
group_id:str):# print(f"Sending message: {message}") await connection.send_json(message) print("next connection") - -group_management_service = GroupManagementService() +db_config = { + "dbname": os.getenv("POSTGRES_DB"), + "user": os.getenv("POSTGRES_USER"), + "password": os.getenv("POSTGRES_PASSWORD"), + "host": os.getenv("POSTGRES_HOST"), + "port": os.getenv("POSTGRES_PORT") + } + +group_management_service = GroupManagementService(db_params=db_config) manager = ConnectionManager() producer = None consumer = None - + async def consume_messages(): async for msg in consumer: # print(msg.offset) @@ -96,8 +104,7 @@ async def lifespan(app: FastAPI): # Kafka producer global producer global consumer - ensure_file_exists("../db/user.db") - init_sqlite_database("../db/user.db") + producer = AIOKafkaProducer(bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS) # Kafka consumer diff --git a/topos/services/messages/group_management_service.py b/topos/services/messages/group_management_service.py index 77fdda8..0faa67a 100644 --- a/topos/services/messages/group_management_service.py +++ b/topos/services/messages/group_management_service.py @@ -1,21 +1,21 @@ from typing import List, Optional -from managers.group_manager_sqlite import GroupManagerSQLite +from topos.services.messages.group_manager import GroupManagerPostgres + class GroupManagementService: - def __init__(self) -> None: - self.group_manager = GroupManagerSQLite() # this implementation can be swapped for oother implementations out based on env var, use if statements - # any other house keeping can be done here too + def __init__(self, db_params: dict) -> None: + self.group_manager = GroupManagerPostgres(db_params) def create_group(self, group_name: str) -> str: return self.group_manager.create_group(group_name=group_name) - def create_user(self, user_id:str,username: str) -> str: - return self.group_manager.create_user(user_id,username) + def create_user(self, user_id: str, username: str) -> str: + return self.group_manager.create_user(user_id, username) def add_user_to_group(self, user_id: str, group_id: str) -> bool: - return self.group_manager.add_user_to_group(user_id=user_id,group_id=group_id) + return self.group_manager.add_user_to_group(user_id=user_id, group_id=group_id) def remove_user_from_group(self, user_id: str, group_id: str) -> bool: - return self.group_manager.remove_user_from_group(user_id=user_id,group_id=group_id) + return self.group_manager.remove_user_from_group(user_id=user_id, group_id=group_id) def get_user_groups(self, user_id: str) -> List[dict]: return self.group_manager.get_user_groups(user_id) @@ -31,19 +31,18 @@ def get_user_by_id(self, user_id: str) -> Optional[dict]: def get_group_by_name(self, group_name: str) -> Optional[dict]: return self.group_manager.get_group_by_name(group_name) - + def get_user_by_username(self, username: str) -> Optional[dict]: - return self.get_user_by_username(username) + return self.group_manager.get_user_by_username(username) def delete_group(self, group_id: str) -> bool: return self.group_manager.delete_group(group_id) def delete_user(self, user_id: str) -> bool: return self.group_manager.delete_user(user_id) - - def set_user_last_seen_online(self,user_id:str)-> bool: + + def set_user_last_seen_online(self, user_id: str) -> bool: return self.group_manager.set_user_last_seen_online(user_id) - - def get_user_last_seen_online(self,user_id:str)-> bool: - return self.group_manager.get_user_last_seen_online(user_id) + def get_user_last_seen_online(self, user_id: str) -> Optional[str]: + return 
self.group_manager.get_user_last_seen_online(user_id) \ No newline at end of file diff --git a/topos/services/messages/group_manager.py b/topos/services/messages/group_manager.py new file mode 100644 index 0000000..3f4d35c --- /dev/null +++ b/topos/services/messages/group_manager.py @@ -0,0 +1,117 @@ +import psycopg2 +from psycopg2.extras import DictCursor +from datetime import datetime +from typing import List, Optional, Dict +from topos.utilities.utils import generate_deci_code + +class GroupManagerPostgres: + def __init__(self, db_params: Dict[str, str]): + self.db_params = db_params + + def _get_connection(self): + return psycopg2.connect(**self.db_params) + + def create_group(self, group_name: str) -> str: + group_id = generate_deci_code(6) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('INSERT INTO groups (group_id, group_name) VALUES (%s, %s)', (group_id, group_name)) + return group_id + + def create_user(self, user_id: str, username: str) -> str: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('INSERT INTO users (user_id, username) VALUES (%s, %s)', (user_id, username)) + return user_id + + def add_user_to_group(self, user_id: str, group_id: str) -> bool: + try: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('INSERT INTO user_groups (user_id, group_id) VALUES (%s, %s)', (user_id, group_id)) + return True + except psycopg2.IntegrityError: + return False + + def remove_user_from_group(self, user_id: str, group_id: str) -> bool: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('DELETE FROM user_groups WHERE user_id = %s AND group_id = %s', (user_id, group_id)) + return cur.rowcount > 0 + + def get_user_groups(self, user_id: str) -> List[dict]: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=DictCursor) as cur: + cur.execute(''' + SELECT g.group_id, g.group_name + FROM groups g + JOIN user_groups ug ON g.group_id = ug.group_id + WHERE ug.user_id = %s + ''', (user_id,)) + return [dict(row) for row in cur.fetchall()] + + def get_group_users(self, group_id: str) -> List[dict]: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=DictCursor) as cur: + cur.execute(''' + SELECT u.user_id, u.username + FROM users u + JOIN user_groups ug ON u.user_id = ug.user_id + WHERE ug.group_id = %s + ''', (group_id,)) + return [dict(row) for row in cur.fetchall()] + + def get_group_by_id(self, group_id: str) -> Optional[dict]: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=DictCursor) as cur: + cur.execute('SELECT group_id, group_name FROM groups WHERE group_id = %s', (group_id,)) + result = cur.fetchone() + return dict(result) if result else None + + def get_user_by_id(self, user_id: str) -> Optional[dict]: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=DictCursor) as cur: + cur.execute('SELECT user_id, username FROM users WHERE user_id = %s', (user_id,)) + result = cur.fetchone() + return dict(result) if result else None + + def get_group_by_name(self, group_name: str) -> Optional[dict]: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=DictCursor) as cur: + cur.execute('SELECT group_id, group_name FROM groups WHERE group_name = %s', (group_name,)) + result = cur.fetchone() + return dict(result) if result else None + + def get_user_by_username(self, username: str) -> Optional[dict]: + with self._get_connection() as conn: + with 
conn.cursor(cursor_factory=DictCursor) as cur: + cur.execute('SELECT user_id, username FROM users WHERE username = %s', (username,)) + result = cur.fetchone() + return dict(result) if result else None + + def delete_group(self, group_id: str) -> bool: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('DELETE FROM user_groups WHERE group_id = %s', (group_id,)) + cur.execute('DELETE FROM groups WHERE group_id = %s', (group_id,)) + return cur.rowcount > 0 + + def delete_user(self, user_id: str) -> bool: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('DELETE FROM user_groups WHERE user_id = %s', (user_id,)) + cur.execute('DELETE FROM users WHERE user_id = %s', (user_id,)) + return cur.rowcount > 0 + + def get_user_last_seen_online(self, user_id: str) -> Optional[str]: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('SELECT last_seen_online FROM users WHERE user_id = %s', (user_id,)) + result = cur.fetchone() + return result[0].isoformat() if result else None + + def set_user_last_seen_online(self, user_id: str) -> bool: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute('UPDATE users SET last_seen_online = %s WHERE user_id = %s', (datetime.now(), user_id)) + return cur.rowcount > 0 \ No newline at end of file diff --git a/topos/services/messages/group_manager_sqlite.py b/topos/services/messages/group_manager_sqlite.py deleted file mode 100644 index 99023ca..0000000 --- a/topos/services/messages/group_manager_sqlite.py +++ /dev/null @@ -1,247 +0,0 @@ -from datetime import datetime -import sqlite3 -import uuid -from typing import List, Optional, Dict -from utils.utils import generate_deci_code -class GroupManagerSQLite: - def __init__(self, db_file: str = '../db/user.db'): - self.db_file = db_file - - # Initialize empty caches - self.groups_cache: Dict[str, Dict] = {} # group_id -> group_info - self.users_cache: Dict[str, Dict] = {} # user_id -> user_info - self.user_groups_cache: Dict[str, List[str]] = {} # user_id -> list of group_ids - self.group_users_cache: Dict[str, List[str]] = {} # group_id -> list of user_ids - - def _get_group_from_db(self, group_id: str) -> Optional[Dict]: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('SELECT group_id, group_name FROM groups WHERE group_id = ?', (group_id,)) - result = cursor.fetchone() - print(result) - if result: - return {"group_id": result[0], "group_name": result[1]} - return None - - def _get_user_from_db(self, user_id: str) -> Optional[Dict]: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('SELECT user_id, username FROM users WHERE user_id = ?', (user_id,)) - result = cursor.fetchone() - if result: - return {"user_id": result[0], "username": result[1]} - return None - - def _get_user_groups_from_db(self, user_id: str) -> List[str]: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('SELECT group_id FROM user_groups WHERE user_id = ?', (user_id,)) - return [row[0] for row in cursor.fetchall()] - - def _get_group_users_from_db(self, group_id: str) -> List[str]: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('SELECT user_id FROM user_groups WHERE group_id = ?', (group_id,)) - return [row[0] for row in cursor.fetchall()] - - def create_group(self, group_name: str) -> str: - group_id = generate_deci_code(6) - with sqlite3.connect(self.db_file) as conn: - cursor = 
conn.cursor() - cursor.execute('INSERT INTO groups (group_id, group_name) VALUES (?, ?)', (group_id, group_name)) - conn.commit() - - # Update cache - self.groups_cache[group_id] = {"group_id": group_id, "group_name": group_name} - self.group_users_cache[group_id] = [] - - return group_id - - def create_user(self, user_id:str,username: str,) -> str: - - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('INSERT INTO users (user_id, username) VALUES (?, ?)', (user_id, username)) - conn.commit() - - # Update cache - self.users_cache[user_id] = {"user_id": user_id, "username": username} - self.user_groups_cache[user_id] = [] - - return user_id - - def add_user_to_group(self, user_id: str, group_id: str) -> bool: - try: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('INSERT INTO user_groups (user_id, group_id) VALUES (?, ?)', (user_id, group_id)) - conn.commit() - - # Update cache if the entries exist - if user_id in self.user_groups_cache: - self.user_groups_cache[user_id].append(group_id) - if group_id in self.group_users_cache: - self.group_users_cache[group_id].append(user_id) - - return True - except sqlite3.IntegrityError: - return False # User already in group or user/group doesn't exist - - def remove_user_from_group(self, user_id: str, group_id: str) -> bool: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('DELETE FROM user_groups WHERE user_id = ? AND group_id = ?', (user_id, group_id)) - conn.commit() - if cursor.rowcount > 0: - # Update cache if the entries exist - if user_id in self.user_groups_cache: - self.user_groups_cache[user_id].remove(group_id) - if group_id in self.group_users_cache: - self.group_users_cache[group_id].remove(user_id) - return True - return False - - def get_user_groups(self, user_id: str) -> List[dict]: - if user_id not in self.user_groups_cache: - self.user_groups_cache[user_id] = self._get_user_groups_from_db(user_id) - - return [self.get_group_by_id(group_id) for group_id in self.user_groups_cache[user_id]] - - def get_group_users(self, group_id: str) -> List[dict]: - if group_id not in self.group_users_cache: - self.group_users_cache[group_id] = self._get_group_users_from_db(group_id) - - return [self.get_user_by_id(user_id) for user_id in self.group_users_cache[group_id]] - - def get_group_by_id(self, group_id: str) -> Optional[dict]: - if group_id not in self.groups_cache: - group = self._get_group_from_db(group_id) - if group: - self.groups_cache[group_id] = group - else: - return None - return self.groups_cache[group_id] - - def get_user_by_id(self, user_id: str) -> Optional[dict]: - if user_id not in self.users_cache: - user = self._get_user_from_db(user_id) - if user: - self.users_cache[user_id] = user - else: - return None - return self.users_cache[user_id] - - def get_group_by_name(self, group_name: str) -> Optional[dict]: - # This operation requires a full DB scan if not in cache - for group in self.groups_cache.values(): - if group['group_name'] == group_name: - return group - - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('SELECT group_id, group_name FROM groups WHERE group_name = ?', (group_name,)) - result = cursor.fetchone() - if result: - group = {"group_id": result[0], "group_name": result[1]} - self.groups_cache[group['group_id']] = group - return group - return None - - def get_user_by_username(self, username: str) -> Optional[dict]: - # This operation requires a full DB scan if 
not in cache - for user in self.users_cache.values(): - if user['username'] == username: - return user - - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('SELECT user_id, username FROM users WHERE username = ?', (username,)) - result = cursor.fetchone() - if result: - user = {"user_id": result[0], "username": result[1]} - self.users_cache[user['user_id']] = user - return user - return None - - def delete_group(self, group_id: str) -> bool: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('DELETE FROM user_groups WHERE group_id = ?', (group_id,)) - cursor.execute('DELETE FROM groups WHERE group_id = ?', (group_id,)) - conn.commit() - if cursor.rowcount > 0: - # Update cache - self.groups_cache.pop(group_id, None) - self.group_users_cache.pop(group_id, None) - for user_groups in self.user_groups_cache.values(): - if group_id in user_groups: - user_groups.remove(group_id) - return True - return False - - def delete_user(self, user_id: str) -> bool: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute('DELETE FROM user_groups WHERE user_id = ?', (user_id,)) - cursor.execute('DELETE FROM users WHERE user_id = ?', (user_id,)) - conn.commit() - if cursor.rowcount > 0: - # Update cache - self.users_cache.pop(user_id, None) - self.user_groups_cache.pop(user_id, None) - for group_users in self.group_users_cache.values(): - if user_id in group_users: - group_users.remove(user_id) - return True - return False - def get_user_last_seen_online(self, user_id: str) -> str: - """ - Get the last_seen_online timestamp for a given user_id. - - :param user_id: The ID of the user - :return: The last seen timestamp as a string, or None if user not found - """ - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - try: - cursor.execute(''' - SELECT last_seen_online - FROM users - WHERE user_id = ? - ''', (user_id,)) - - result = cursor.fetchone() - if result: - return result[0] - else: - print(f"User with ID {user_id} not found.") - return None - except sqlite3.Error as e: - print(f"An error occurred: {e}") - return None - - def set_user_last_seen_online(self, user_id: str) -> bool: - """ - Set the last_seen_online timestamp for a given user_id to the current time. - - :param user_id: The ID of the user - :return: True if successful, False if user not found - """ - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - try: - cursor.execute(''' - UPDATE users - SET last_seen_online = ? - WHERE user_id = ? 
- ''', (datetime.now().replace(microsecond=0), user_id)) - - if cursor.rowcount == 0: - print(f"User with ID {user_id} not found.") - return False - - return True - except sqlite3.Error as e: - print(f"An error occurred: {e}") - return False - diff --git a/topos/services/messages/missed_message_service.py b/topos/services/messages/missed_message_service.py index 85f07a5..1e1ae05 100644 --- a/topos/services/messages/missed_message_service.py +++ b/topos/services/messages/missed_message_service.py @@ -1,6 +1,6 @@ -from chat_server.managers.missed_message_manager import MissedMessageManager -from chat_server.services.group_management_service import GroupManagementService -from chat_server.utils.utils import sqlite_timestamp_to_ms +from topos.services.messages.missed_message_manager import MissedMessageManager +from topos.services.messages.group_management_service import GroupManagementService +from topos.utilities.utils import sqlite_timestamp_to_ms KAFKA_TOPIC = 'chat_topic' From 0dd59588deca91e3a5a0656461493121c91995bf Mon Sep 17 00:00:00 2001 From: jonny <32085184+jonnyjohnson1@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:59:32 -0600 Subject: [PATCH 3/6] updates to chat api --- flake.nix | 22 ++++++++++++++++++++++ topos/api/__init__.py | 0 topos/api/api.py | 21 ++++++++++++++++++++- topos/chat_api/__init__.py | 0 topos/chat_api/api.py | 12 +++++++----- topos/chat_api/server.py | 8 ++++---- topos/cli.py | 2 +- topos/downloaders/__init__.py | 0 8 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 topos/api/__init__.py create mode 100644 topos/chat_api/__init__.py create mode 100644 topos/downloaders/__init__.py diff --git a/flake.nix b/flake.nix index f4f5268..772e950 100644 --- a/flake.nix +++ b/flake.nix @@ -172,6 +172,28 @@ emo_27_label VARCHAR ); + CREATE TABLE IF NOT EXISTS groups ( + group_id TEXT PRIMARY KEY, + group_name TEXT NOT NULL UNIQUE + ); + + CREATE TABLE IF NOT EXISTS users ( + user_id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + last_seen_online TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS user_groups ( + user_id TEXT, + group_id TEXT, + FOREIGN KEY (user_id) REFERENCES users (user_id), + FOREIGN KEY (group_id) REFERENCES groups (group_id), + PRIMARY KEY (user_id, group_id) + ); + + CREATE INDEX IF NOT EXISTS idx_user_groups_user_id ON user_groups (user_id); + CREATE INDEX IF NOT EXISTS idx_user_groups_group_id ON user_groups (group_id); + GRANT ALL PRIVILEGES ON DATABASE ${envVars.POSTGRES_DB} TO ${envVars.POSTGRES_USER}; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO ${envVars.POSTGRES_USER}; GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO ${envVars.POSTGRES_USER}; diff --git a/topos/api/__init__.py b/topos/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/topos/api/api.py b/topos/api/api.py index 6378c95..48453db 100644 --- a/topos/api/api.py +++ b/topos/api/api.py @@ -40,12 +40,31 @@ """ +from multiprocessing import Process +import uvicorn -def start_local_api(): +def start_topos_api(): """Function to start the API in local mode.""" print("\033[92mINFO:\033[0m API docs available at: \033[1mhttp://0.0.0.0:13341/docs\033[0m") uvicorn.run(app, host="0.0.0.0", port=13341) +def start_kafka_api(): + from ..chat_api.api import start_messenger_server + start_messenger_server() + +def start_local_api(): + process1 = Process(target=start_topos_api) + process2 = Process(target=start_kafka_api) + process1.start() + process2.start() + process1.join() + process2.join() + +# 
def start_local_api(): +# """Function to start the API in local mode.""" +# print("\033[92mINFO:\033[0m API docs available at: \033[1mhttp://0.0.0.0:13341/docs\033[0m") +# uvicorn.run(app, host="0.0.0.0", port=13341) + def start_web_api(): """Function to start the API in web mode with SSL.""" diff --git a/topos/chat_api/__init__.py b/topos/chat_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/topos/chat_api/api.py b/topos/chat_api/api.py index 786a238..5f71d93 100644 --- a/topos/chat_api/api.py +++ b/topos/chat_api/api.py @@ -1,11 +1,13 @@ -import subprocess +from fastapi import FastAPI +from ..config import setup_config, get_ssl_certificates +import uvicorn +from .server import app as chat_app -def start_chat(): +def start_messenger_server(): """Function to start the API in local mode.""" - # print("\033[92mINFO:\033[0m API docs available at: \033[1mhttp://127.0.0.1:13394/docs\033[0m") - # subprocess.run(["python", "topos/chat_api/chat_server.py"]) # A barebones chat server - subprocess.run(["uvicorn", "topos.chat_api.server:app", "--host", "0.0.0.0", "--port", "13394", "--workers", "1"]) + print("\033[92mINFO:\033[0m API docs available at: \033[1mhttp://127.0.0.1:13394/docs\033[0m") + uvicorn.run(chat_app, host="127.0.0.1", port=13394) # start through zrok # uvicorn main:app --host 127.0.0.1 --port 13394 & zrok expose http://localhost:13394 diff --git a/topos/chat_api/server.py b/topos/chat_api/server.py index 58622d7..e6bf1a2 100644 --- a/topos/chat_api/server.py +++ b/topos/chat_api/server.py @@ -3,8 +3,8 @@ import json import os -from typing import Dict, List from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from typing import Dict, List from aiokafka import AIOKafkaProducer, AIOKafkaConsumer from fastapi.concurrency import asynccontextmanager from topos.services.messages.group_management_service import GroupManagementService @@ -220,9 +220,9 @@ async def get_missed_messages(request: MissedMessagesRequest): missed_message_service = MissedMessageService() return await missed_message_service.get_missed_messages(user_id=request.user_id,group_management_service=group_management_service) -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=13394) +# if __name__ == "__main__": +# import uvicorn +# uvicorn.run(app, host="0.0.0.0", port=13394) """ Message JSON Schema diff --git a/topos/cli.py b/topos/cli.py index e4b8e9b..27986fc 100644 --- a/topos/cli.py +++ b/topos/cli.py @@ -34,7 +34,7 @@ def main(): """ # import chat_api from .chat_api import api - api.start_chat() + api.start_messenger_server() if args.command == 'zrok': """ diff --git a/topos/downloaders/__init__.py b/topos/downloaders/__init__.py new file mode 100644 index 0000000..e69de29 From 2c742541f34cd43d43f969346267d04454b29223 Mon Sep 17 00:00:00 2001 From: jonny <32085184+jonnyjohnson1@users.noreply.github.com> Date: Mon, 4 Nov 2024 17:07:33 -0600 Subject: [PATCH 4/6] + aiokafka import --- pyproject.toml | 3 ++- topos/chat_api/server.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 595d38b..afdd663 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,13 +39,14 @@ python-multipart = "^0.0.9" pytest-asyncio = "^0.23.7" textblob = "^0.18.0.post0" pystray = "0.19.5" - +aiokafka = "^0.11.0" supabase = "^2.6.0" psycopg2-binary = "^2.9.9" en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"} # en-core-web-lg = {url 
= "https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.8.0/en_core_web_lg-3.8.0-py3-none-any.whl"} # en-core-web-md = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl"} # en-core-web-trf = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl"} + [tool.poetry.group.dev.dependencies] pytest = "^7.4.3" pytest-asyncio = "^0.23.2" diff --git a/topos/chat_api/server.py b/topos/chat_api/server.py index e6bf1a2..ca06a3d 100644 --- a/topos/chat_api/server.py +++ b/topos/chat_api/server.py @@ -20,7 +20,6 @@ class MissedMessagesRequest(BaseModel): KAFKA_BOOTSTRAP_SERVERS = 'localhost:9092' KAFKA_TOPIC = 'chat_topic' - # WebSocket connection manager class ConnectionManager: def __init__(self): From 567769ff1b72b740dfc8fb514c26852951f856e8 Mon Sep 17 00:00:00 2001 From: jonny <32085184+jonnyjohnson1@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:34:48 -0600 Subject: [PATCH 5/6] fix: Fixing the postgres init sequence (WIP) --- .env_dev | 2 +- .env_template | 15 --- default.nix | 27 +++++- flake.nix | 35 ++----- justfile | 1 + tests/database/postgres_init_test.py | 97 +++++++++++++++++++ topos/chat_api/server.py | 2 + .../database/conversation_cache_manager.py | 2 +- 8 files changed, 136 insertions(+), 45 deletions(-) delete mode 100644 .env_template create mode 100644 tests/database/postgres_init_test.py diff --git a/.env_dev b/.env_dev index 9aa8394..e919969 100644 --- a/.env_dev +++ b/.env_dev @@ -8,7 +8,7 @@ OPEN_AI_API_KEY="sk-openai.com123" ONE_API_API_KEY="sk-oneapi.local123" SUPABASE_URL= SUPABASE_KEY= -POSTGRES_DB=test_topos_db +POSTGRES_DB=test_topos_db_1 POSTGRES_USER=jonny POSTGRES_PASSWORD=1234589034 POSTGRES_HOST=127.0.0.1 diff --git a/.env_template b/.env_template deleted file mode 100644 index 6caf27a..0000000 --- a/.env_template +++ /dev/null @@ -1,15 +0,0 @@ -NEO4J_URI="bolt://localhost:7687" -NEO4J_USER="neo4j" -NEO4J_PASSWORD="password" -NEO4J_TEST_DATABASE="neo4j" -NEO4J_SHOWROOM_DATABASE="neo4j" -JWT_SECRET="terces_tj" -OPEN_AI_API_KEY="sk-openai.com123" -ONE_API_API_KEY="sk-oneapi.local123" -SUPABASE_URL= -SUPABASE_KEY= -POSTGRES_DB=test_topos_db -POSTGRES_USER=username -POSTGRES_PASSWORD=your_password_here -POSTGRES_HOST=127.0.0.1 -POSTGRES_PORT=5432 diff --git a/default.nix b/default.nix index 123ff1f..1b5788f 100644 --- a/default.nix +++ b/default.nix @@ -41,8 +41,10 @@ in pkgs.mkShell { sleep 2 # Set up the test database, role, and tables + echo "Setting up the test database..." + # psql -U $POSTGRES_USER -c "CREATE DATABASE $POSTGRES_DB;" || echo "Database $POSTGRES_DB already exists." 
+    psql -d $POSTGRES_DB <

Date: Wed, 6 Nov 2024 08:19:55 -0600
Subject: [PATCH 6/6] feat: Added PostgreSQL table init on Python side

---
 flake.nix                                |  1 +
 topos/api/api.py                         | 26 ++++++++++---
 topos/chat_api/server.py                 | 21 ++++++++---
 topos/services/messages/group_manager.py | 48 +++++++++++++++++++++++-
 4 files changed, 85 insertions(+), 11 deletions(-)

diff --git a/flake.nix b/flake.nix
index 2e73e02..4c42525 100644
--- a/flake.nix
+++ b/flake.nix
@@ -129,6 +129,7 @@ listen_addresses = "127.0.0.1";
       # dataDir = "${dataDirBase}/pg";
       initialDatabases = [
+        { name = "${envVars.POSTGRES_DB}"; }
       ];
       initialScript = {
diff --git a/topos/api/api.py b/topos/api/api.py
index 48453db..6c5a6d5 100644
--- a/topos/api/api.py
+++ b/topos/api/api.py
@@ -1,6 +1,7 @@
 from fastapi import FastAPI
 from ..config import setup_config, get_ssl_certificates
 import uvicorn
+import signal
 
 # Create the FastAPI application instance
 app = FastAPI()
@@ -52,7 +53,12 @@ def start_kafka_api():
     from ..chat_api.api import start_messenger_server
     start_messenger_server()
 
+# Global references to processes for cleanup
+process1 = None
+process2 = None
+
 def start_local_api():
+    global process1, process2
     process1 = Process(target=start_topos_api)
     process2 = Process(target=start_kafka_api)
     process1.start()
@@ -60,11 +66,21 @@ def start_local_api():
     process1.join()
     process2.join()
 
-# def start_local_api():
-#     """Function to start the API in local mode."""
-#     print("\033[92mINFO:\033[0m API docs available at: \033[1mhttp://0.0.0.0:13341/docs\033[0m")
-#     uvicorn.run(app, host="0.0.0.0", port=13341)
-
+def handle_cleanup(signum, frame):
+    """Cleanup function to terminate processes on exit."""
+    print("Cleaning up processes...")
+    if process1 is not None:
+        process1.terminate()
+        process1.join()
+    if process2 is not None:
+        process2.terminate()
+        process2.join()
+    print("Processes terminated.")
+    exit(0)  # Exit the program
+
+# Register the signal handler for cleanup
+signal.signal(signal.SIGINT, handle_cleanup)
+signal.signal(signal.SIGTERM, handle_cleanup)
 
 def start_web_api():
     """Function to start the API in web mode with SSL."""
diff --git a/topos/chat_api/server.py b/topos/chat_api/server.py
index c306c8e..ca89f99 100644
--- a/topos/chat_api/server.py
+++ b/topos/chat_api/server.py
@@ -1,6 +1,7 @@
 import asyncio
 import datetime
 import json
+import signal
 import os
 
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
@@ -118,12 +119,22 @@ async def lifespan(app: FastAPI):
     # https://stackoverflow.com/questions/46890646/asyncio-weirdness-of-task-exception-was-never-retrieved
     # we need to keep a reference to this task alive, otherwise the consume task will be stopped; there has to be a live reference for this to work
     consume_task = asyncio.create_task(consume_messages())
-    yield
-    # Clean up the ML models and release the resources
-    consume_task.cancel()
-    await producer.stop()
-    await consumer.stop()
+    def shutdown(signal, loop):
+        print("Received exit signal", signal)
+        consume_task.cancel()
+        loop.stop()
+
+    # Add signal handler for graceful shutdown on Ctrl+C
+    loop = asyncio.get_event_loop()
+    loop.add_signal_handler(signal.SIGINT, shutdown, signal.SIGINT, loop)
+
+    try:
+        yield
+    finally:
+        consume_task.cancel()
+        await producer.stop()
+        await consumer.stop()
 
 # FastAPI app
 app = FastAPI(lifespan=lifespan)
diff --git a/topos/services/messages/group_manager.py b/topos/services/messages/group_manager.py
index 3f4d35c..2ed8744 100644
--- a/topos/services/messages/group_manager.py
+++ b/topos/services/messages/group_manager.py
@@ -3,14 +3,60 @@
 from datetime import datetime
 from typing import List, Optional, Dict
 from topos.utilities.utils import generate_deci_code
+import os
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
 
 class GroupManagerPostgres:
     def __init__(self, db_params: Dict[str, str]):
         self.db_params = db_params
-
+        self._setup_tables()
+
     def _get_connection(self):
         return psycopg2.connect(**self.db_params)
 
+    def _setup_tables(self):
+        """Ensures necessary tables exist with required permissions."""
+
+        setup_sql_commands = [
+            """
+            CREATE TABLE IF NOT EXISTS groups (
+                group_id TEXT PRIMARY KEY,
+                group_name TEXT NOT NULL UNIQUE
+            );
+            """,
+            """
+            CREATE TABLE IF NOT EXISTS users (
+                user_id TEXT PRIMARY KEY,
+                username TEXT NOT NULL UNIQUE,
+                last_seen_online TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+            );
+            """,
+            """
+            CREATE TABLE IF NOT EXISTS user_groups (
+                user_id TEXT,
+                group_id TEXT,
+                FOREIGN KEY (user_id) REFERENCES users (user_id),
+                FOREIGN KEY (group_id) REFERENCES groups (group_id),
+                PRIMARY KEY (user_id, group_id)
+            );
+            """,
+            "CREATE INDEX IF NOT EXISTS idx_user_groups_user_id ON user_groups (user_id);",
+            "CREATE INDEX IF NOT EXISTS idx_user_groups_group_id ON user_groups (group_id);",
+            f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {os.getenv('POSTGRES_USER')};",
+            f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {os.getenv('POSTGRES_USER')};",
+            f"GRANT pg_read_all_data TO {os.getenv('POSTGRES_USER')};",
+            f"GRANT pg_write_all_data TO {os.getenv('POSTGRES_USER')};"
+        ]
+
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                for command in setup_sql_commands:
+                    cur.execute(command)
+            conn.commit()
+
     def create_group(self, group_name: str) -> str:
         group_id = generate_deci_code(6)
         with self._get_connection() as conn:
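
For reference, a minimal usage sketch of the table-init pattern above, assuming db_params carries the standard psycopg2 connection keys and the same POSTGRES_* variables as .env_dev; the group name and variable names here are illustrative, not part of the patch:

# Usage sketch (assumed wiring): build db_params from the environment and let
# GroupManagerPostgres.__init__ call _setup_tables() to create the schema.
import os
from dotenv import load_dotenv
from topos.services.messages.group_manager import GroupManagerPostgres

load_dotenv()

db_params = {
    "dbname": os.getenv("POSTGRES_DB"),
    "user": os.getenv("POSTGRES_USER"),
    "password": os.getenv("POSTGRES_PASSWORD"),
    "host": os.getenv("POSTGRES_HOST", "127.0.0.1"),
    "port": os.getenv("POSTGRES_PORT", "5432"),
}

manager = GroupManagerPostgres(db_params)           # __init__ runs _setup_tables()
new_group_id = manager.create_group("demo-group")   # returns a generated group id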