diff --git a/flake.nix b/flake.nix index 2e73e02..4c42525 100644 --- a/flake.nix +++ b/flake.nix @@ -129,6 +129,7 @@ listen_addresses = "127.0.0.1"; # dataDir = "${dataDirBase}/pg"; initialDatabases = [ + { name = "${envVars.POSTGRES_DB}"; } ]; initialScript = { diff --git a/topos/api/api.py b/topos/api/api.py index 48453db..6c5a6d5 100644 --- a/topos/api/api.py +++ b/topos/api/api.py @@ -1,6 +1,7 @@ from fastapi import FastAPI from ..config import setup_config, get_ssl_certificates import uvicorn +import signal # Create the FastAPI application instance app = FastAPI() @@ -52,7 +53,12 @@ def start_kafka_api(): from ..chat_api.api import start_messenger_server start_messenger_server() +# Global references to processes for cleanup +process1 = None +process2 = None + def start_local_api(): + global process1, process2 process1 = Process(target=start_topos_api) process2 = Process(target=start_kafka_api) process1.start() @@ -60,11 +66,21 @@ def start_local_api(): process1.join() process2.join() -# def start_local_api(): -# """Function to start the API in local mode.""" -# print("\033[92mINFO:\033[0m API docs available at: \033[1mhttp://0.0.0.0:13341/docs\033[0m") -# uvicorn.run(app, host="0.0.0.0", port=13341) - +def handle_cleanup(signum, frame): + """Cleanup function to terminate processes on exit.""" + print("Cleaning up processes...") + if process1 is not None: + process1.terminate() + process1.join() + if process2 is not None: + process2.terminate() + process2.join() + print("Processes terminated.") + exit(0) # Exit the program + +# Register the signal handler for cleanup +signal.signal(signal.SIGINT, handle_cleanup) +signal.signal(signal.SIGTERM, handle_cleanup) def start_web_api(): """Function to start the API in web mode with SSL.""" diff --git a/topos/chat_api/server.py b/topos/chat_api/server.py index c306c8e..ca89f99 100644 --- a/topos/chat_api/server.py +++ b/topos/chat_api/server.py @@ -1,6 +1,7 @@ import asyncio import datetime import json +import signal import os from fastapi import FastAPI, WebSocket, WebSocketDisconnect @@ -118,12 +119,22 @@ async def lifespan(app: FastAPI): # https://stackoverflow.com/questions/46890646/asyncio-weirdness-of-task-exception-was-never-retrieved # we need to keep a reference of this task alive else it will stop the consume task, there has to be a live refference for this to work consume_task = asyncio.create_task(consume_messages()) - yield - # Clean up the ML models and release the resources - consume_task.cancel() - await producer.stop() - await consumer.stop() + def shutdown(signal, loop): + print("Received exit signal", signal) + consume_task.cancel() + loop.stop() + + # Add signal handler for graceful shutdown on Ctrl+C + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal.SIGINT, shutdown, signal.SIGINT, loop) + + try: + yield + finally: + consume_task.cancel() + await producer.stop() + await consumer.stop() # FastAPI app app = FastAPI(lifespan=lifespan) diff --git a/topos/services/messages/group_manager.py b/topos/services/messages/group_manager.py index 3f4d35c..2ed8744 100644 --- a/topos/services/messages/group_manager.py +++ b/topos/services/messages/group_manager.py @@ -3,14 +3,60 @@ from datetime import datetime from typing import List, Optional, Dict from topos.utilities.utils import generate_deci_code +import os +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() class GroupManagerPostgres: def __init__(self, db_params: Dict[str, str]): self.db_params = db_params - + self._setup_tables() + def _get_connection(self): return psycopg2.connect(**self.db_params) + def _setup_tables(self): + """Ensures necessary tables exist with required permissions.""" + + setup_sql_commands = [ + """ + CREATE TABLE IF NOT EXISTS groups ( + group_id TEXT PRIMARY KEY, + group_name TEXT NOT NULL UNIQUE + ); + """, + """ + CREATE TABLE IF NOT EXISTS users ( + user_id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + last_seen_online TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """, + """ + CREATE TABLE IF NOT EXISTS user_groups ( + user_id TEXT, + group_id TEXT, + FOREIGN KEY (user_id) REFERENCES users (user_id), + FOREIGN KEY (group_id) REFERENCES groups (group_id), + PRIMARY KEY (user_id, group_id) + ); + """, + "CREATE INDEX IF NOT EXISTS idx_user_groups_user_id ON user_groups (user_id);", + "CREATE INDEX IF NOT EXISTS idx_user_groups_group_id ON user_groups (group_id);", + f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {os.getenv('POSTGRES_USER')};", + f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {os.getenv('POSTGRES_USER')};", + f"GRANT pg_read_all_data TO {os.getenv('POSTGRES_USER')};", + f"GRANT pg_write_all_data TO {os.getenv('POSTGRES_USER')};" + ] + + with self._get_connection() as conn: + with conn.cursor() as cur: + for command in setup_sql_commands: + cur.execute(command) + conn.commit() + def create_group(self, group_name: str) -> str: group_id = generate_deci_code(6) with self._get_connection() as conn: