From 311cb39c19321d95752838d419a635336afc93c0 Mon Sep 17 00:00:00 2001 From: Soumik Sarkar Date: Tue, 30 Apr 2019 21:41:13 +0530 Subject: [PATCH] Split the database into one for user and one for cache. (#34) * Split the database into one for user and one for cache. * Minor improvements. --- tle/__main__.py | 6 +- tle/cogs/codeforces.py | 28 ++-- tle/cogs/future_contests.py | 14 +- tle/cogs/graphs.py | 2 +- tle/cogs/handles.py | 20 +-- tle/constants.py | 4 +- tle/util/cache_system.py | 10 +- tle/util/cache_system2.py | 10 +- tle/util/codeforces_common.py | 38 ++++-- tle/util/db/__init__.py | 2 + tle/util/db/cache_db_conn.py | 128 ++++++++++++++++++ .../{handle_conn.py => db/user_db_conn.py} | 67 +-------- tle/util/discord_common.py | 4 +- 13 files changed, 202 insertions(+), 131 deletions(-) create mode 100644 tle/util/db/__init__.py create mode 100644 tle/util/db/cache_db_conn.py rename tle/util/{handle_conn.py => db/user_db_conn.py} (81%) diff --git a/tle/__main__.py b/tle/__main__.py index 3058ac44..1b79a954 100644 --- a/tle/__main__.py +++ b/tle/__main__.py @@ -64,11 +64,7 @@ def no_dm_check(ctx): @bot.event async def on_ready(): - if args.nodb: - dbfile = None - else: - dbfile = os.path.join(constants.FILEDIR, constants.DB_FILENAME) - await cf_common.initialize(dbfile, constants.CONTEST_CACHE_PERIOD) + await cf_common.initialize(args.nodb) bot.add_listener(discord_common.bot_error_handler, name='on_command_error') diff --git a/tle/cogs/codeforces.py b/tle/cogs/codeforces.py index 756f7d00..fa7d6bd1 100644 --- a/tle/cogs/codeforces.py +++ b/tle/cogs/codeforces.py @@ -19,7 +19,7 @@ def __init__(self, bot): @commands.has_role('Admin') async def _updatestatus(self, ctx): active_ids = [m.id for m in ctx.guild.members] - rc = cf_common.conn.update_status(active_ids) + rc = cf_common.user_db.update_status(active_ids) await ctx.send(f'{rc} members active with handle') @commands.command(brief='force cache refresh of contests and problems') @@ -37,7 +37,7 @@ async def cache_cfuser_subs(self, handle: str): solved = {prob.contest_identifier for prob in solved if prob.has_metadata()} solved = json.dumps(list(solved)) stamp = time.time() - cf_common.conn.cache_cfuser_full(info + (solved, stamp)) + cf_common.user_db.cache_cfuser_full(info + (solved, stamp)) return stamp, info.rating, solved @commands.command(brief='Recommend a problem') @@ -51,7 +51,7 @@ async def gimme(self, ctx, *args): bounds.append(int(arg)) else: tags.append(arg) - handle = cf_common.conn.gethandle(ctx.message.author.id) + handle = cf_common.user_db.gethandle(ctx.message.author.id) rating, solved = None, None if handle: @@ -97,11 +97,11 @@ async def gimme(self, ctx, *args): @cf_common.user_guard(group='gitgud') async def gitgud(self, ctx, delta: int = 0): user_id = ctx.message.author.id - handle = cf_common.conn.gethandle(user_id) + handle = cf_common.user_db.gethandle(user_id) if not handle: await ctx.send('You must link your handle to be able to use this feature.') return - active = cf_common.conn.check_challenge(user_id) + active = cf_common.user_db.check_challenge(user_id) if active is not None: challenge_id, issue_time, name, contest_id, index, c_delta = active url = f'{cf.CONTEST_BASE_URL}{contest_id}/problem/{index}' @@ -138,7 +138,7 @@ def check(problem): issue_time = datetime.datetime.now().timestamp() - rc = cf_common.conn.new_challenge(user_id, issue_time, problem, delta) + rc = cf_common.user_db.new_challenge(user_id, issue_time, problem, delta) if rc != 1: # await ctx.send('Error updating the database') await ctx.send('Your challenge has already been added to the database!') @@ -154,11 +154,11 @@ def check(problem): @cf_common.user_guard(group='gitgud') async def gotgud(self, ctx): user_id = ctx.message.author.id - handle = cf_common.conn.gethandle(user_id) + handle = cf_common.user_db.gethandle(user_id) if not handle: await ctx.send('You must link your handle to be able to use this feature.') return - active = cf_common.conn.check_challenge(user_id) + active = cf_common.user_db.check_challenge(user_id) if not active: await ctx.send(f'You do not have an active challenge') return @@ -172,7 +172,7 @@ async def gotgud(self, ctx): return delta = delta // 100 + 3 finish_time = int(datetime.datetime.now().timestamp()) - rc = cf_common.conn.complete_challenge(user_id, challenge_id, finish_time, delta) + rc = cf_common.user_db.complete_challenge(user_id, challenge_id, finish_time, delta) if rc == 1: await ctx.send(f'Challenge completed. {handle} gained {delta} points.') else: @@ -182,11 +182,11 @@ async def gotgud(self, ctx): @cf_common.user_guard(group='gitgud') async def nogud(self, ctx): user_id = ctx.message.author.id - handle = cf_common.conn.gethandle(user_id) + handle = cf_common.user_db.gethandle(user_id) if not handle: await ctx.send('You must link your handle to be able to use this feature.') return - active = cf_common.conn.check_challenge(user_id) + active = cf_common.user_db.check_challenge(user_id) if not active: await ctx.send(f'You do not have an active challenge') return @@ -195,14 +195,14 @@ async def nogud(self, ctx): if finish_time - issue_time < 10800: await ctx.send(f'You can\'t skip your challenge yet. Think more.') return - cf_common.conn.skip_challenge(user_id, challenge_id) + cf_common.user_db.skip_challenge(user_id, challenge_id) await ctx.send(f'Challenge skipped.') @commands.command(brief='Force skip a challenge') @cf_common.user_guard(group='gitgud') @commands.has_role('Admin') async def _nogud(self, ctx, user: str): - rc = cf_common.conn.force_skip_challenge(user) + rc = cf_common.user_db.force_skip_challenge(user) if rc == 1: await ctx.send(f'Challenge skip forced.') else: @@ -240,7 +240,7 @@ async def vc(self, ctx, *handles: str): contest_id = sorted(list(recommendations))[choice] # from and count are for ranklist, set to minimum (1) because we only need name str_handles = '`, `'.join(handles) - contest, _, _ = await cf.contest.standings(contestid=contest_id, from_=1, count=1) + contest, _, _ = await cf.contest.standings(contest_id=contest_id, from_=1, count=1) embed = discord.Embed(title=contest.name, url=contest.url) await ctx.send(f'Recommended contest for `{str_handles}`', embed=embed) diff --git a/tle/cogs/future_contests.py b/tle/cogs/future_contests.py index 28c5657e..5b8adeb3 100644 --- a/tle/cogs/future_contests.py +++ b/tle/cogs/future_contests.py @@ -9,8 +9,8 @@ from discord.ext import commands from tle.util import codeforces_common as cf_common +from tle.util import db from tle.util import discord_common -from tle.util import handle_conn from tle.util import paginator _CONTEST_RELOAD_INTERVAL = 60 * 60 # 1 hour @@ -139,8 +139,8 @@ def _reschedule_tasks(self, guild_id): if not self.start_time_map: return try: - settings = cf_common.conn.get_reminder_settings(guild_id) - except handle_conn.DatabaseDisabledError: + settings = cf_common.user_db.get_reminder_settings(guild_id) + except db.DatabaseDisabledError: return if settings is None: return @@ -215,21 +215,21 @@ async def here(self, ctx, role: discord.Role, *before: int): return if not before or any(before_mins <= 0 for before_mins in before): return - cf_common.conn.set_reminder_settings(ctx.guild.id, ctx.channel.id, role.id, json.dumps(before)) + cf_common.user_db.set_reminder_settings(ctx.guild.id, ctx.channel.id, role.id, json.dumps(before)) await ctx.send(embed=discord_common.embed_success('Reminder settings saved successfully')) self._reschedule_tasks(ctx.guild.id) @remind.command(brief='Clear all reminder settings') @commands.has_role('Admin') async def clear(self, ctx): - cf_common.conn.clear_reminder_settings(ctx.guild.id) + cf_common.user_db.clear_reminder_settings(ctx.guild.id) await ctx.send(embed=discord_common.embed_success('Reminder settings cleared')) self._reschedule_tasks(ctx.guild.id) @remind.command(brief='Show reminder settings') async def settings(self, ctx): """Shows the role, channel and before time settings.""" - settings = cf_common.conn.get_reminder_settings(ctx.guild.id) + settings = cf_common.user_db.get_reminder_settings(ctx.guild.id) if settings is None: await ctx.send(embed=discord_common.embed_neutral('Reminder not set')) return @@ -252,7 +252,7 @@ async def settings(self, ctx): @remind.command(brief='Subscribe to or unsubscribe from contest reminders', usage='[not]') async def me(self, ctx, arg: str = None): - settings = cf_common.conn.get_reminder_settings(ctx.guild.id) + settings = cf_common.user_db.get_reminder_settings(ctx.guild.id) if settings is None: await ctx.send( embed=discord_common.embed_alert('To use this command, reminder settings must be set by an admin')) diff --git a/tle/cogs/graphs.py b/tle/cogs/graphs.py index 04f91b00..1fa35520 100644 --- a/tle/cogs/graphs.py +++ b/tle/cogs/graphs.py @@ -235,7 +235,7 @@ async def scatter(self, ctx, handle: str = None, bin_size: int = 10): @plot.command(brief='Show server rating distribution') async def distrib(self, ctx): """Plots rating distribution of server members.""" - res = cf_common.conn.getallhandleswithrating() + res = cf_common.user_db.getallhandleswithrating() ratings = [rating for _, _, rating in res] bin_count = min(len(ratings), 30) diff --git a/tle/cogs/handles.py b/tle/cogs/handles.py index 7598806a..edcc9b23 100644 --- a/tle/cogs/handles.py +++ b/tle/cogs/handles.py @@ -137,8 +137,8 @@ async def sethandle(self, ctx, member: discord.Member, handle: str): # CF API returns correct handle ignoring case, update to it handle = user.handle - cf_common.conn.cache_cfuser(user) - cf_common.conn.sethandle(member.id, handle) + cf_common.user_db.cache_cfuser(user) + cf_common.user_db.sethandle(member.id, handle) embed = _make_profile_embed(member, user, mode='set') await ctx.send(embed=embed) @@ -146,11 +146,11 @@ async def sethandle(self, ctx, member: discord.Member, handle: str): @commands.command(brief='gethandle [name]') async def gethandle(self, ctx, member: discord.Member): """Show Codeforces handle of a user""" - handle = cf_common.conn.gethandle(member.id) + handle = cf_common.user_db.gethandle(member.id) if not handle: await ctx.send(f'Handle for user {member.display_name} not found in database') return - user = cf_common.conn.fetch_cfuser(handle) + user = cf_common.user_db.fetch_cfuser(handle) if user is None: # Not cached, should not happen logging.error(f'Handle info for {handle} not cached') @@ -167,7 +167,7 @@ async def removehandle(self, ctx, member: discord.Member): await ctx.send('Member not found!') return try: - r = cf_common.conn.removehandle(member.id) + r = cf_common.user_db.removehandle(member.id) if r == 1: msg = f'removehandle: {member.name} removed' else: @@ -181,7 +181,7 @@ async def removehandle(self, ctx, member: discord.Member): async def gudgitters(self, ctx): try: converter = commands.MemberConverter() - res = cf_common.conn.get_gudgitters() + res = cf_common.user_db.get_gudgitters() res.sort(key=lambda r: r[1], reverse=True) style = table.Style('{:>} {:<}') @@ -209,7 +209,7 @@ async def showhandles(self, ctx): """Shows all members of the server who have registered their handles and their Codeforces ratings. """ - res = cf_common.conn.getallhandleswithrating() + res = cf_common.user_db.getallhandleswithrating() users = [(ctx.guild.get_member(int(user_id)), handle, rating) for user_id, handle, rating in res] users = [(member, handle, rating) for member, handle, rating in users if member is not None] users.sort(key=lambda x: (1 if x[2] is None else -x[2], x[1])) # Sorting by (-rating, handle) @@ -220,7 +220,7 @@ async def showhandles(self, ctx): async def prettyhandles(self, ctx: discord.ext.commands.Context, page_no: int = None): try: converter = commands.MemberConverter() - res = cf_common.conn.getallhandleswithrating() + res = cf_common.user_db.getallhandleswithrating() res.sort(key=lambda r: r[2] if r[2] is not None else -1, reverse=True) rankings = [] pos = 0 @@ -278,13 +278,13 @@ async def _updateroles(self, ctx): return try: - res = cf_common.conn.getallhandles() + res = cf_common.user_db.getallhandles() handles = [handle for _, handle in res] users = await cf.user.info(handles=handles) await ctx.send('caching handles...') try: for user in users: - cf_common.conn.cache_cfuser(user) + cf_common.user_db.cache_cfuser(user) except Exception as e: print(e) except Exception as e: diff --git a/tle/constants.py b/tle/constants.py index 843083d2..150e9223 100644 --- a/tle/constants.py +++ b/tle/constants.py @@ -1,4 +1,4 @@ FILEDIR = './files' -DB_FILENAME = 'handles.db' -CONTEST_CACHE_PERIOD = 3600 +USER_DB_FILENAME = 'user.db' +CACHE_DB_FILENAME = 'cache.db' CONTEST_WRITERS_JSON_FILE = 'contest_writers.json' diff --git a/tle/util/cache_system.py b/tle/util/cache_system.py index a1dbe6a3..724f932f 100644 --- a/tle/util/cache_system.py +++ b/tle/util/cache_system.py @@ -6,7 +6,7 @@ import time from tle.util import codeforces_api as cf -from tle.util import handle_conn +from tle.util import db logger = logging.getLogger(__name__) @@ -74,7 +74,7 @@ async def force_update(self): await self.cache_problems() def try_disk(self): - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): contests = self.conn.fetch_contests() problem_res = self.conn.fetch_problems() if not contests or not problem_res: @@ -102,7 +102,7 @@ async def cache_contests(self): } self.contest_last_cache = time.time() self.logger.info(f'{len(self.contest_dict)} contests cached') - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): rc = self.conn.cache_contests(contests) self.logger.info(f'{rc} contests stored in database') @@ -125,7 +125,7 @@ async def cache_problems(self): } self.problems_last_cache = time.time() self.logger.info(f'{len(self.problem_dict)} problems cached') - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): rc = self.conn.cache_problems([ ( prob.name, prob.contestId, prob.index, @@ -140,7 +140,7 @@ async def cache_problems(self): async def get_rating_solved(self, handle: str, time_out: int): cached = self._user_rating_solved(handle) stamp, rating, solved = cached - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): if stamp is None: # Try from disk first stamp, rating, solved = await self._retrieve_rating_solved(handle) diff --git a/tle/util/cache_system2.py b/tle/util/cache_system2.py index c9bc2a70..eeaea998 100644 --- a/tle/util/cache_system2.py +++ b/tle/util/cache_system2.py @@ -8,7 +8,7 @@ from tle.util import codeforces_common as cf_common from tle.util import codeforces_api as cf -from tle.util import handle_conn +from tle.util import db logger = logging.getLogger(__name__) @@ -64,7 +64,7 @@ async def reload_now(self): raise self.reload_exception async def _try_disk(self): - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): async with self.reload_lock: contests = self.cache_master.conn.fetch_contests() if not contests: @@ -99,7 +99,7 @@ async def _update(self, contests, from_api=True): contests.sort(key=lambda contest: (contest.startTimeSeconds, contest.id)) if from_api: - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): rc = self.cache_master.conn.cache_contests(contests) self.logger.info(f'{rc} contests stored in database') @@ -172,7 +172,7 @@ async def reload_now(self): raise self.reload_exception async def _try_disk(self): - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): async with self.reload_lock: problem_res = self.cache_master.conn.fetch_problems() if not problem_res: @@ -229,7 +229,7 @@ def keep(problem): self.problem_start = problem_start self.problems_last_cache = time.time() - with suppress(handle_conn.DatabaseDisabledError): + with suppress(db.DatabaseDisabledError): def get_tuple_repr(problem): return (problem.name, problem.contestId, diff --git a/tle/util/codeforces_common.py b/tle/util/codeforces_common.py index 26db9988..b2d97c0c 100644 --- a/tle/util/codeforces_common.py +++ b/tle/util/codeforces_common.py @@ -2,22 +2,23 @@ import functools import json import logging +import os from collections import defaultdict from discord.ext import commands from tle import constants +from tle.util import cache_system2 from tle.util import codeforces_api as cf +from tle.util import db from tle.util import discord_common -from tle.util import handle_conn from tle.util import event_system -from tle.util import cache_system2 from tle.util.cache_system import CacheSystem logger = logging.getLogger(__name__) # Connection to database -conn = None +user_db = None # Cache system cache = None @@ -31,18 +32,20 @@ active_groups = defaultdict(set) -async def initialize(dbfile, cache_refresh_interval): +async def initialize(nodb): global cache global cache2 - global conn + global user_db global event_sys global _contest_id_to_writers_map - if dbfile is None: - conn = handle_conn.DummyConn() + if nodb: + user_db = db.DummyUserDbConn() else: - conn = handle_conn.HandleConn(dbfile) - cache = CacheSystem(conn) + user_db_file = os.path.join(constants.FILEDIR, constants.USER_DB_FILENAME) + user_db = db.UserDbConn(user_db_file) + + cache = CacheSystem(user_db) # Initial fetch from CF API await cache.force_update() if cache.contest_last_cache and cache.problems_last_cache: @@ -51,12 +54,14 @@ async def initialize(dbfile, cache_refresh_interval): # If fetch failed, load from disk logger.info('Loading cache from disk') cache.try_disk() - asyncio.create_task(_cache_refresher_task(cache_refresh_interval)) + asyncio.create_task(_cache_refresher_task()) - cache2 = cache_system2.CacheSystem(conn) + cache_db_file = os.path.join(constants.FILEDIR, constants.CACHE_DB_FILENAME) + cache_db = db.CacheDbConn(cache_db_file) + cache2 = cache_system2.CacheSystem(cache_db) await cache2.run() - jsonfile = f'{constants.FILEDIR}/{constants.CONTEST_WRITERS_JSON_FILE}' + jsonfile = os.path.join(constants.FILEDIR, constants.CONTEST_WRITERS_JSON_FILE) try: with open(jsonfile) as f: data = json.load(f) @@ -88,9 +93,12 @@ async def f(self, ctx, *args, **kwargs): return guard -async def _cache_refresher_task(refresh_interval): +_CACHE_REFRESH_INTERVAL = 60 * 60 + + +async def _cache_refresher_task(): while True: - await asyncio.sleep(refresh_interval) + await asyncio.sleep(_CACHE_REFRESH_INTERVAL) logger.info('Attempting cache refresh') await cache.force_update() @@ -169,7 +177,7 @@ async def resolve_handles(ctx, converter, handles, *, mincnt=1, maxcnt=5): member = await converter.convert(ctx, member_identifier) except commands.errors.CommandError: raise FindMemberFailedError(member_identifier) - handle = conn.gethandle(member.id) + handle = user_db.gethandle(member.id) if handle is None: raise HandleNotRegisteredError(member) if handle in HandleIsVjudgeError.HANDLES: diff --git a/tle/util/db/__init__.py b/tle/util/db/__init__.py new file mode 100644 index 00000000..dfe49d8c --- /dev/null +++ b/tle/util/db/__init__.py @@ -0,0 +1,2 @@ +from .cache_db_conn import CacheDbConn +from .user_db_conn import DummyUserDbConn, UserDbConn, DatabaseDisabledError diff --git a/tle/util/db/cache_db_conn.py b/tle/util/db/cache_db_conn.py new file mode 100644 index 00000000..fc2b1e1d --- /dev/null +++ b/tle/util/db/cache_db_conn.py @@ -0,0 +1,128 @@ +import sqlite3 + +from tle.util import codeforces_api as cf + + +class CacheDbConn: + def __init__(self, db_file): + self.conn = sqlite3.connect(db_file) + self.create_tables() + + def create_tables(self): + self.conn.execute( + 'CREATE TABLE IF NOT EXISTS contest (' + 'id INTEGER NOT NULL,' + 'name TEXT,' + 'start_time INTEGER,' + 'duration INTEGER,' + 'type TEXT,' + 'phase TEXT,' + 'prepared_by TEXT,' + 'PRIMARY KEY (id)' + ')' + ) + self.conn.execute( + 'CREATE TABLE IF NOT EXISTS problem (' + 'name TEXT NOT NULL,' + 'contest_id INTEGER,' + 'p_index TEXT,' + 'start_time INTEGER,' + 'rating INTEGER,' + 'type TEXT,' + 'tags TEXT,' + 'PRIMARY KEY (name)' + ')' + ) + self.conn.execute( + 'CREATE TABLE IF NOT EXISTS rating_change (' + 'contest_id INTEGER NOT NULL,' + 'handle TEXT NOT NULL,' + 'rank INTEGER,' + 'rating_update_time INTEGER,' + 'old_rating INTEGER,' + 'new_rating INTEGER,' + 'UNIQUE (contest_id, handle)' + ')' + ) + self.conn.execute('CREATE INDEX IF NOT EXISTS ix_rating_change_contest_id ' + 'ON rating_change (contest_id)') + self.conn.execute('CREATE INDEX IF NOT EXISTS ix_rating_change_handle ' + 'ON rating_change (handle)') + + def fetch_contests(self): + query = ('SELECT id, name, start_time, duration, type, phase, prepared_by ' + 'FROM contest') + res = self.conn.execute(query).fetchall() + return [cf.Contest._make(contest) for contest in res] + + def fetch_problems(self): + query = ('SELECT contest_id, p_index, name, type, rating, tags, start_time ' + 'FROM problem') + res = self.conn.execute(query).fetchall() + return [(cf.Problem._make(problem[:6]), problem[6]) for problem in res] + + def cache_contests(self, contests): + query = ('INSERT OR REPLACE INTO contest ' + '(id, name, start_time, duration, type, phase, prepared_by) ' + 'VALUES (?, ?, ?, ?, ?, ?, ?)') + rc = self.conn.executemany(query, contests).rowcount + self.conn.commit() + return rc + + def cache_problems(self, problems): + query = ('INSERT OR REPLACE INTO problem ' + '(name, contest_id, p_index, start_time, rating, type, tags) ' + 'VALUES (?, ?, ?, ?, ?, ?, ?)') + rc = self.conn.executemany(query, problems).rowcount + self.conn.commit() + return rc + + def save_rating_changes(self, changes): + change_tuples = [(change.contestId, + change.handle, + change.rank, + change.ratingUpdateTimeSeconds, + change.oldRating, + change.newRating) for change in changes] + query = ('INSERT OR REPLACE INTO rating_change ' + '(contest_id, handle, rank, rating_update_time, old_rating, new_rating) ' + 'VALUES (?, ?, ?, ?, ?, ?)') + rc = self.conn.executemany(query, change_tuples).rowcount + self.conn.commit() + return rc + + def get_all_rating_changes(self): + query = ('SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating ' + 'FROM rating_change r ' + 'LEFT JOIN contest c ' + 'ON r.contest_id = c.id') + res = self.conn.execute(query).fetchall() + return [cf.RatingChange._make(change) for change in res] + + def get_rating_changes_for_contest(self, contest_id): + query = ('SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating ' + 'FROM rating_change r ' + 'LEFT JOIN contest c ' + 'ON r.contest_id = c.id ' + 'WHERE r.contest_id = ?') + res = self.conn.execute(query, (contest_id,)).fetchall() + return [cf.RatingChange._make(change) for change in res] + + def has_rating_changes_saved(self, contest_id): + query = ('SELECT contest_id ' + 'FROM rating_change ' + 'WHERE contest_id = ?') + res = self.conn.execute(query, (contest_id,)).fetchone() + return res is not None + + def get_rating_changes_for_handle(self, handle): + query = ('SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating ' + 'FROM rating_change r ' + 'LEFT JOIN contest c ' + 'ON r.contest_id = c.id ' + 'WHERE r.handle = ?') + res = self.conn.execute(query, (handle,)).fetchall() + return [cf.RatingChange._make(change) for change in res] + + def close(self): + self.conn.close() diff --git a/tle/util/handle_conn.py b/tle/util/db/user_db_conn.py similarity index 81% rename from tle/util/handle_conn.py rename to tle/util/db/user_db_conn.py index 1024b34c..8cf53dbf 100644 --- a/tle/util/handle_conn.py +++ b/tle/util/db/user_db_conn.py @@ -10,12 +10,12 @@ class DatabaseDisabledError(commands.CommandError): pass -class DummyConn: +class DummyUserDbConn: def __getattribute__(self, item): raise DatabaseDisabledError -class HandleConn: +class UserDbConn: def __init__(self, dbfile): self.conn = sqlite3.connect(dbfile) self.create_tables() @@ -96,21 +96,6 @@ def create_tables(self): before TEXT ) ''') - self.conn.execute(''' - CREATE TABLE IF NOT EXISTS rating_changes( - contest_id INTEGER NOT NULL, - handle TEXT NOT NULL, - rank INTEGER, - rating_update_time INTEGER, - old_rating INTEGER, - new_rating INTEGER, - UNIQUE (contest_id, handle) - ) - ''') - self.conn.execute( - 'CREATE INDEX IF NOT EXISTS ix_rating_changes_contest_id ON rating_changes (contest_id)') - self.conn.execute( - 'CREATE INDEX IF NOT EXISTS ix_rating_changes_handle ON rating_changes (handle)') def fetch_contests(self): query = 'SELECT id, name, start_time, duration, type, phase, prepared_by FROM contest' @@ -367,53 +352,5 @@ def clear_reminder_settings(self, guild_id): self.conn.execute(query, (guild_id,)) self.conn.commit() - def save_rating_changes(self, changes): - change_tuples = [(change.contestId, - change.handle, - change.rank, - change.ratingUpdateTimeSeconds, - change.oldRating, - change.newRating) for change in changes] - return self._insert_many('rating_changes', - 'contest_id handle rank rating_update_time old_rating new_rating'.split(), - change_tuples) - - def get_all_rating_changes(self): - query = ''' - SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating - FROM rating_changes r - LEFT JOIN contest c - ON r.contest_id = c.id - ''' - res = self.conn.execute(query).fetchall() - return [cf.RatingChange._make(change) for change in res] - - def get_rating_changes_for_contest(self, contest_id): - query = ''' - SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating - FROM rating_changes r - LEFT JOIN contest c - ON r.contest_id = c.id - WHERE r.contest_id = ? - ''' - res = self.conn.execute(query, (contest_id,)).fetchall() - return [cf.RatingChange._make(change) for change in res] - - def has_rating_changes_saved(self, contest_id): - query = 'SELECT contest_id FROM rating_changes WHERE contest_id = ?' - res = self.conn.execute(query, (contest_id,)).fetchone() - return res is not None - - def get_rating_changes_for_handle(self, handle): - query = ''' - SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating - FROM rating_changes r - LEFT JOIN contest c - ON r.contest_id = c.id - WHERE r.handle = ? - ''' - res = self.conn.execute(query, (handle,)).fetchall() - return [cf.RatingChange._make(change) for change in res] - def close(self): self.conn.close() diff --git a/tle/util/discord_common.py b/tle/util/discord_common.py index 7d73dcab..0cbd6d73 100644 --- a/tle/util/discord_common.py +++ b/tle/util/discord_common.py @@ -4,7 +4,7 @@ import discord from discord.ext import commands -from tle.util import handle_conn +from tle.util import db logger = logging.getLogger(__name__) @@ -42,7 +42,7 @@ async def bot_error_handler(ctx, exception): # Errors already handled in cogs should have .handled = True return - if isinstance(exception, handle_conn.DatabaseDisabledError): + if isinstance(exception, db.DatabaseDisabledError): await ctx.send(embed=embed_alert('Sorry, the database is not available. Some features are disabled.')) elif isinstance(exception, commands.NoPrivateMessage): await ctx.send(embed=embed_alert('Commands are disabled in private channels'))