Skip to content

Commit

Permalink
feat(database): migrate from MySQL to PostgreSQL
Browse files Browse the repository at this point in the history
  • Loading branch information
Qwenty228 committed Dec 22, 2023
1 parent 16778b0 commit a6ba5b0
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 145 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ cython_debug/

/data
/test

*.session.sql
.vscode/settings.json
26 changes: 3 additions & 23 deletions cogs/Citizen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@



from utils.data_manager import QR_PATH
from settings import QR_PATH
from utils.embeds import user_embed, qr_confirmation_embed
from utils.views import Show_User_View, Confirmation_View
from utils.promptpay import PromptPay
Expand All @@ -16,16 +16,6 @@
class Citizen(commands.Cog):
def __init__(self, bot: 'Oppy') -> None:
self.bot = bot
set_phone = app_commands.ContextMenu(
name='set phone number',
callback=self.set_phone_from_message,
)
set_qr = app_commands.ContextMenu(
name='set qr code',
callback=self.set_qr_from_message,
)
self.bot.tree.add_command(set_phone)
self.bot.tree.add_command(set_qr)
self._routine = {}

async def cog_unload(self) -> None:
Expand Down Expand Up @@ -64,6 +54,7 @@ async def __show_user(self, interaction: discord.Interaction, user: discord.Memb
pass

all_users = await self.bot.database.get_user(guild_id=interaction.guild.id)

view = Show_User_View(interaction, self.bot, user, all_users, timeout=180.0, ephemeral=ephemeral)

if (u := await self.bot.database.get_user(user=user)):
Expand Down Expand Up @@ -109,18 +100,6 @@ async def __set_qr(self, interaction: discord.Interaction, qr: discord.Attachmen
else:
await interaction.edit_original_response(content="Cancelled.", view=None, embed=None)

# message menu commands ====================================================================================================

async def set_phone_from_message(self, interaction: discord.Interaction, message: discord.Message) -> None:
user = message.author
await self.__ensure_user(user)
phone = message.content
await self.__set_phone(interaction, phone, user, ephemeral=True)
async def set_qr_from_message(self, interaction: discord.Interaction, message: discord.Message) -> None:
user = message.author
await self.__ensure_user(user)
# TODO: check if the attachment is in the message

# slash and context commands ====================================================================================================

@commands.hybrid_command()
Expand Down Expand Up @@ -154,6 +133,7 @@ async def pay(self, interaction: discord.Interaction, amount: float = 0.00, user
await self.__ensure_user(user)
content = f"```Paying {user.display_name} {amount} Baht```" if amount else f"Paying {user.display_name}"
if (qr:= await self.bot.database.get_user_qr(user)) != '0' and image_path_check(qr):
print('qr found')
await interaction.response.send_message(content=content, file=discord.File(fp=qr, filename="qr.png"))
elif (p := await self.bot.database.get_user_phone(user)) != '0':
await interaction.response.send_message(content=content, file=discord.File(fp=PromptPay.to_byte_QR(p, amount), filename="qr.png"))
Expand Down
31 changes: 31 additions & 0 deletions database_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
if __name__ == "__main__":
import os, asyncio
import asyncpg
from settings import QR_PATH, credentials

async def main():
if not os.path.exists("data"):
os.mkdir("data")
print("Created data directory.")
if not os.path.exists(QR_PATH):
os.mkdir(QR_PATH)
print("Created QR code directory.")

async with asyncpg.create_pool(**credentials) as pool:

await pool.execute('''CREATE TABLE IF NOT EXISTS guilds (
id numeric PRIMARY KEY,
name text
);''')
await pool.execute('''CREATE TABLE IF NOT EXISTS users (
id numeric PRIMARY KEY,
guild_id numeric NOT NULL,
username text NOT NULL,
phone_number text DEFAULT '0',
promptpay_qr text DEFAULT '0',
CONSTRAINT fk_guild_id
FOREIGN KEY (guild_id)
REFERENCES guilds(id)
ON DELETE CASCADE
);''')
asyncio.run(main())
16 changes: 4 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@
import asyncio

from discord.ext import commands, tasks
from dotenv import load_dotenv
from random import choice
from aiohttp import ClientSession

from utils.data_manager import Database_Manager
from settings import description, get_activity, initial_extensions, activities
from settings import description, credentials, initial_extensions, activities

load_dotenv()

print(os.environ.get("DB_PASS"))
# print(os.environ.get("DB_PASS"))

class Oppy(commands.Bot):
def __init__(self, database_manager: Database_Manager, testing_guild_id: int = None):
Expand Down Expand Up @@ -47,7 +43,7 @@ async def setup_hook(self):
# This would also be a good place to connect to our database and
# load anything that should be in memory prior to handling events.
# Basically: don't 👏 do 👏 shit 👏 in 👏 on_ready. -R. Danny
async for guild in self.fetch_guilds(limit=None):
async for guild in self.fetch_guilds(limit=None): # create all guild that this bot is in
await self.database.create_guild(guild)

# In this case, we are using this to ensure that once we are connected, we sync for the testing guild.
Expand Down Expand Up @@ -93,11 +89,7 @@ async def main():

# Here we have a database pool, which do cleanup at exit.
# We also have our bot, which depends on both of these.
async with Database_Manager(host=os.environ.get("DB_HOST"), # database connection
port=int(os.environ.get("DB_PORT")),
user=os.environ.get("DB_USER"),
password=os.environ.get("DB_PASS"),
db=os.environ.get("DB_NAME")) as db_manager:
async with Database_Manager(**credentials) as db_manager:
# 2. We become responsible for starting the bot.
async with Oppy(database_manager=db_manager) as bot:
await bot.start(os.environ.get("TOKEN"))
Expand Down
11 changes: 11 additions & 0 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import os
import glob
from dotenv import load_dotenv

load_dotenv()


def get_activity(activity_type: typing.Literal['gaming', 'listening', 'watching', 'competing', 'streaming'], activity_name: str, url: str = "https://www.youtube.com/watch?v=8VgSyKl9vg0"):
Expand All @@ -29,3 +32,11 @@ def get_activity(activity_type: typing.Literal['gaming', 'listening', 'watching'
get_activity('streaming', 'shrek 2',
url='https://www.youtube.com/watch?v=8VgSyKl9vg0'),
get_activity('listening', 'theory will only take you so far')]

QR_PATH = "data/qr_codes"

credentials = {'host': os.environ.get("DB_HOST"), # database connection
'port': int(os.environ.get("DB_PORT")),
'user': "zu2",
'password': 'Oppy987',
'database': 'oppybot'}
151 changes: 43 additions & 108 deletions utils/data_manager.py
Original file line number Diff line number Diff line change
@@ -1,152 +1,87 @@
from __future__ import annotations
import aiomysql
import asyncpg
import asyncio
import discord
import logging
import os
import os, sys, traceback
from typing import Coroutine, Optional, Union

from .pattern_check import phone_check
from .promptpay import PromptPay


QR_PATH = "data/qr_codes"



# logging.basicConfig(filename='logs/discord.log', level=logging.DEBUG)


def log_data(func):
def wrapper(*args, **kwargs):
logging.debug(
f"Function {func.__name__} called. get_data() returned {args[0].get_data()}.")
return func(*args, **kwargs)
return wrapper


class Database_Manager:
# ============================ init ========================================
def __init__(self, **kwargs) -> None:
self.__auth = kwargs
if not os.path.exists("data"):
os.mkdir("data")
logging.debug("Created data directory.")
if not os.path.exists(QR_PATH):
os.mkdir(QR_PATH)
logging.debug("Created QR code directory.")


async def __aenter__(self) -> Database_Manager:
self.loop = asyncio.get_event_loop()
self.pool = await aiomysql.create_pool(loop=self.loop, **self.__auth)
self.pool = await asyncpg.create_pool(**self.__auth, loop=self.loop)
# ensure that the database exists
await self.create_database()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Close the database connection at exit."""
self.pool.close()
await self.pool.wait_closed()
await self.pool.close()
print("Database connection closed.")

# ============================ helper ========================================
async def __execute(self, query: str, *args) -> None:
async with self.pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(query, args)
await conn.commit()

async def __fetchall(self, query: str, *args) -> list:
async with self.pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
await cur.execute(query, args)
return await cur.fetchall()

async def __fetchone(self, query: str, *args, to_dict:bool = True) -> Optional[tuple]:
async with self.pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor if to_dict else aiomysql.Cursor) as cur:
await cur.execute(query, args)
return await cur.fetchone()

async def __create_table(self, table: str, columns: str) -> None:
await self.__execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns})")


# =========================== create functions ================================

async def create_database(self):
"""database per guild"""
await self.__create_table("guilds", """id BIGINT NOT NULL PRIMARY KEY,
guild_name VARCHAR(100) NOT NULL""")
await self.__create_table("users", """id BIGINT NOT NULL,
guild_id BIGINT NOT NULL,
user_name VARCHAR(32) NOT NULL,
phone_number VARCHAR(20) DEFAULT '0',
promptpay_qr VARCHAR(255) DEFAULT '0',
FOREIGN KEY (guild_id) REFERENCES guilds(id) ON DELETE CASCADE""")

async def create_guild(self, guild: discord.Guild):
try:
await self.__execute("INSERT INTO guilds (id, guild_name) VALUES (%s, %s)",
guild.id, guild.name)
logging.info(f"Created guild {guild.name} in database")
except aiomysql.IntegrityError:
logging.info(f"Guild {guild.name} already exists in database")



async def create_user(self, user: discord.Member):
"""create user in database"""
async def create_guild(self, guild: discord.Guild):
"""create guild"""
try:
await self.__execute("INSERT INTO users (id, guild_id, user_name) VALUES (%s, %s, %s)",
user.id, user.guild.id, user.name)
except aiomysql.IntegrityError:
pass

# =========================== read functions ================================
async def get_guild(self, guild_id: Optional[int] = None):
if guild_id:
return await self.__fetchone("SELECT * FROM guilds WHERE id = %s", guild_id)
return await self.__fetchall("SELECT * FROM guilds")
await self.pool.execute('''INSERT INTO guilds (id, name) VALUES ($1, $2)
ON CONFLICT (id)
DO NOTHING;''', int(guild.id), guild.name)
except Exception as e:
logging.error(traceback.format_exc())

async def create_user(self, user: discord.User):
"""create user"""
try:
await self.pool.execute('''INSERT INTO users (id, guild_id, username) VALUES ($1, $2, $3)
ON CONFLICT (id)
DO NOTHING;''', int(user.id), int(user.guild.id), user.name)
except Exception as e:
logging.error(traceback.format_exc())

async def get_user(self, user: Optional[discord.Member]=None, guild_id: Optional[int]=None) -> Union[dict, list[dict], None]:
"""get user"""
if user:
return await self.__fetchone("SELECT * FROM users WHERE id = %s AND guild_id = %s", user.id, user.guild.id)
return await self.pool.fetchrow('''SELECT * FROM users WHERE id = $1 AND guild_id = $2;''', int(user.id), int(user.guild.id))
elif guild_id:
return await self.__fetchall("SELECT * FROM users WHERE guild_id = %s", guild_id)
return await self.__fetchall("SELECT * FROM users")

async def get_user_phone(self, user: discord.Member) -> str:
if (phone := await self.__fetchone("SELECT phone_number FROM users WHERE id = %s AND guild_id = %s", user.id, user.guild.id, to_dict=False)):
return phone[0]
return None
return await self.pool.fetch('''SELECT * FROM users WHERE guild_id = $1;''', int(guild_id))
return await self.pool.fetch('''SELECT * FROM users;''')

async def get_user_qr(self, user: discord.Member) -> str:
if (qr := await self.__fetchone("SELECT promptpay_qr FROM users WHERE id = %s AND guild_id = %s", user.id, user.guild.id, to_dict=False)):
return qr[0]
return None


# =========================== update functions ================================
async def get_user_phone(self, user: discord.User) -> str:
"""get user phone"""
return await self.pool.fetchval('''SELECT phone_number FROM users WHERE id = $1 AND guild_id = $2;''', int(user.id), int(user.guild.id))

async def get_user_qr(self, user: discord.User) -> str:
"""get user promptpay"""
if (qr := await self.pool.fetchval('''SELECT promptpay_qr FROM users WHERE id = $1 AND guild_id = $2;''', int(user.id), int(user.guild.id))):
return qr
return '0'

# =========================== update functions ================================
async def set_phone(self, user: discord.Member, phone: str) -> [False, str]:
if not (p := phone_check(phone)):
return False
await self.__execute("UPDATE users SET phone_number = %s WHERE id = %s AND guild_id = %s", p, user.id, user.guild.id)
await self.__execute("UPDATE users SET promptpay_qr = %s WHERE id = %s AND guild_id = %s", str(PromptPay(p)), user.id, user.guild.id)

await self.pool.execute("UPDATE users SET phone_number = $1 WHERE id = $2 AND guild_id = $3", p, user.id, user.guild.id)
await self.pool.execute("UPDATE users SET promptpay_qr = $1 WHERE id = $2 AND guild_id = $3", str(PromptPay(p)), user.id, user.guild.id)
return p

async def set_promptpay_qr(self, user: discord.Member, qr: str):
await self.__execute("UPDATE users SET promptpay_qr = %s WHERE id = %s AND guild_id = %s", qr, user.id, user.guild.id)

# =========================== delete functions ================================




await self.pool.execute("UPDATE users SET promptpay_qr = $1 WHERE id = $1 AND guild_id = $3", qr, user.id, user.guild.id)

if __name__ == "__main__":
# db = DB_Manager()
# # db.create_user(u)
# u = User(123, "test", "0", "0")
# db.create_user(u)
# db.update_phone(123, "1234567890")
pass

3 changes: 1 addition & 2 deletions utils/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, bot: 'Oppy', view: Show_User_View, all_users: dict, ephemeral
# dropdown menus
# using guild to fetch member instead of bot.get_user() so it will only show users in the guild
options = [discord.SelectOption(label=self.ctx.guild.get_member(int(
user['id'])).display_name, emoji=self.bot.get_emoji(911502994468651010), value=user['id']) for user in all_users]
user.get('id'))).display_name, emoji=self.bot.get_emoji(911502994468651010), value=str(user.get('id'))) for user in all_users]

super().__init__(placeholder='Choose your target...',
min_values=1, max_values=1, options=options)
Expand Down Expand Up @@ -50,7 +50,6 @@ def __init__(self, ctx: commands.Context, bot: 'Oppy', user: discord.User, all_u
self.ctx = ctx
self.user = user
self.bot = bot

self.add_item(Show_User_Dropdown(bot, self, all_users, ephemeral))

@discord.ui.button(label="phone", style=discord.ButtonStyle.gray)
Expand Down

0 comments on commit a6ba5b0

Please sign in to comment.