Skip to content

Commit

Permalink
[ added postgres to database - integration test + instructions ]
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-aftersource committed Aug 9, 2024
1 parent 4791f74 commit ce63826
Show file tree
Hide file tree
Showing 5 changed files with 508 additions and 2 deletions.
7 changes: 6 additions & 1 deletion .env_dev
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,9 @@ NEO4J_SHOWROOM_DATABASE="neo4j"
JWT_SECRET="terces_tj"
ONE_API_API_KEY="sk-etc"
SUPABASE_URL=
SUPABASE_KEY=
SUPABASE_KEY=
POSTGRES_DB=test_topos_db
POSTGRES_USER=your_username
POSTGRES_PASSWORD=your_password
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
83 changes: 82 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ textblob = "^0.18.0.post0"
tk = "0.1.0"

supabase = "^2.6.0"
psycopg2-binary = "^2.9.9"
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
Expand Down
142 changes: 142 additions & 0 deletions topos/services/database/postgres_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import psycopg2
from psycopg2 import pool
from psycopg2.extras import RealDictCursor, Json
from typing import List, Dict, Any
from topos.services.database.database_interface import DatabaseInterface

class PostgresDatabase(DatabaseInterface):
def __init__(self, dbname: str, user: str, password: str, host: str = 'localhost', port: str = '5432'):
super().__init__()
self.pool = psycopg2.pool.SimpleConnectionPool(1, 20,
dbname=dbname, user=user, password=password, host=host, port=port
)
print(f"\t[ PostgresDatabase init ]")

def __del__(self):
if hasattr(self, 'pool'):
self.pool.closeall()

def _get_conn(self):
return self.pool.getconn()

def _put_conn(self, conn):
self.pool.putconn(conn)

def add_entity(self, entity_id: str, entity_label: str, properties: Dict[str, Any]) -> None:
query = """
INSERT INTO entities (id, label, properties)
VALUES (%s, %s, %s)
ON CONFLICT (id) DO UPDATE
SET label = EXCLUDED.label, properties = EXCLUDED.properties
"""
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(query, (entity_id, entity_label, Json(properties)))
conn.commit()
finally:
self._put_conn(conn)

def add_relation(self, source_id: str, relation_type: str, target_id: str, properties: Dict[str, Any]) -> None:
query = """
INSERT INTO relations (source_id, relation_type, target_id, properties)
VALUES (%s, %s, %s, %s)
ON CONFLICT (source_id, relation_type, target_id) DO UPDATE
SET properties = EXCLUDED.properties
"""
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(query, (source_id, relation_type, target_id, Json(properties)))
conn.commit()
finally:
self._put_conn(conn)

def get_messages_by_user(self, user_id: str, relation_type: str) -> List[Dict[str, Any]]:
query = """
SELECT e.id as message_id, e.properties->>'content' as message, e.properties->>'timestamp' as timestamp
FROM relations r
JOIN entities e ON r.target_id = e.id
WHERE r.source_id = %s AND r.relation_type = %s AND e.label = 'MESSAGE'
"""
conn = self._get_conn()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, (user_id, relation_type))
return cur.fetchall()
finally:
self._put_conn(conn)

def get_messages_by_session(self, session_id: str, relation_type: str) -> List[Dict[str, Any]]:
query = """
SELECT e.id as message_id, e.properties->>'content' as message, e.properties->>'timestamp' as timestamp
FROM relations r
JOIN entities e ON r.target_id = e.id
WHERE r.source_id = %s AND r.relation_type = %s AND e.label = 'MESSAGE'
"""
conn = self._get_conn()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, (session_id, relation_type))
return cur.fetchall()
finally:
self._put_conn(conn)

def get_users_by_session(self, session_id: str, relation_type: str) -> List[Dict[str, Any]]:
query = """
SELECT r.source_id as user_id
FROM relations r
WHERE r.target_id = %s AND r.relation_type = %s
"""
conn = self._get_conn()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, (session_id, relation_type))
return cur.fetchall()
finally:
self._put_conn(conn)

def get_sessions_by_user(self, user_id: str, relation_type: str) -> List[Dict[str, Any]]:
query = """
SELECT r.target_id as session_id
FROM relations r
JOIN entities e ON r.target_id = e.id
WHERE r.source_id = %s AND r.relation_type = %s AND e.label = 'SESSION'
"""
conn = self._get_conn()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, (user_id, relation_type))
return cur.fetchall()
finally:
self._put_conn(conn)

def get_message_by_id(self, message_id: str) -> Dict[str, Any]:
query = """
SELECT properties->>'content' as message, properties->>'timestamp' as timestamp
FROM entities
WHERE id = %s AND label = 'MESSAGE'
"""
conn = self._get_conn()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, (message_id,))
result = cur.fetchone()
return result if result else {}
finally:
self._put_conn(conn)

def value_exists(self, label: str, key: str, value: str) -> bool:
query = """
SELECT 1
FROM entities
WHERE label = %s AND properties->>%s = %s
LIMIT 1
"""
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(query, (label, key, value))
return bool(cur.fetchone())
finally:
self._put_conn(conn)
Loading

0 comments on commit ce63826

Please sign in to comment.