From b3a1dc7465d159a90b1c73fbe3029357ba1af50b Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 27 Feb 2024 03:40:15 +0400 Subject: [PATCH] Drawing output result (#10) * drawing works fine * ref cleaning --- pyproject.toml | 3 +- src/config.py | 2 +- src/data_models/Game.py | 18 ++- src/data_models/Player.py | 3 +- src/data_models/Poll.py | 3 +- src/data_models/Record.py | 13 +- src/db.py | 14 +- src/handlers/game.py | 6 +- src/handlers/save.py | 64 +++++---- src/services/db_service.py | 6 - src/services/draw_result_image.py | 229 ++++++++++++++++++++++++++++++ 11 files changed, 307 insertions(+), 54 deletions(-) create mode 100644 src/services/draw_result_image.py diff --git a/pyproject.toml b/pyproject.toml index 79c51d3..789ff14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,12 @@ authors = ["alex"] readme = "README.md" [tool.poetry.dependencies] -python = "^3.12" +python = ">=3.11,<3.13" aiosqlite = "^0.19.0" python-dotenv = "^1.0.1" python-telegram-bot = "^20.7" pydantic = "^2.5.3" +cairosvg = "^2.7.1" [build-system] diff --git a/src/config.py b/src/config.py index 37df29e..19ede4e 100644 --- a/src/config.py +++ b/src/config.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Literal, List, Tuple +from typing import Tuple from dotenv import load_dotenv diff --git a/src/data_models/Game.py b/src/data_models/Game.py index 7540e4b..1cca916 100644 --- a/src/data_models/Game.py +++ b/src/data_models/Game.py @@ -1,8 +1,9 @@ -from datetime import datetime from typing import Literal from pydantic import BaseModel, field_validator +from src import config + class Game(BaseModel): poll_id: int @@ -18,7 +19,18 @@ def validate_results(cls, v: dict) -> Literal["CH", "DH", "FW", "LW"]: return "CH" if "I'm Dead Hitler" in outcomes: return "DH" - if "I'm Liberal Winner" in outcomes: + if ( + "I'm Liberal Winner" + or "I'm Hitler Looser" + or "I'm Fascistic Looser" in outcomes + ): return "LW" - if "I'm Fascistic Winner" in outcomes: + if ( + "I'm Fascistic Winner" + or "I'm Hitler Winner" + or "I'm Liberal Looser" in outcomes + ): return "FW" + raise ValueError( + f"Invalid results '{v}' for Game. Results must be one of {config.GAME_POLL_OUTCOMES}" + ) diff --git a/src/data_models/Player.py b/src/data_models/Player.py index 0a6cd65..7c8f182 100644 --- a/src/data_models/Player.py +++ b/src/data_models/Player.py @@ -1,5 +1,4 @@ from typing import Optional - from pydantic import BaseModel, field_validator @@ -15,4 +14,4 @@ class Player(BaseModel): @field_validator("is_bot", mode="after") @classmethod def validate_bot(cls, v: bool) -> str: - return "TRUE" if v else "FALSE" # sqlite3 does not support boolean type + return "TRUE" if v else "FALSE" # sqlite3 does not support a boolean type diff --git a/src/data_models/Poll.py b/src/data_models/Poll.py index af47ef2..cb2bade 100644 --- a/src/data_models/Poll.py +++ b/src/data_models/Poll.py @@ -1,7 +1,6 @@ -from datetime import datetime +from typing import Literal from pydantic import BaseModel -from typing import Literal class Poll(BaseModel): diff --git a/src/data_models/Record.py b/src/data_models/Record.py index c010501..fc64a0b 100644 --- a/src/data_models/Record.py +++ b/src/data_models/Record.py @@ -1,8 +1,9 @@ -from enum import Enum from typing import Literal, Optional -from src import config + from pydantic import BaseModel, field_validator +from src import config + class Record(BaseModel): creator_id: int @@ -39,3 +40,11 @@ def shorten_role( raise ValueError( f"Invalid role '{v}' for Record. Role must be one of {config.GAME_POLL_OUTCOMES}" ) + + def get_team(self) -> Optional[Literal["Fascist", "Liberal"]]: + if self.role in {"CH", "DH", "FW", "FL", "HL"}: + return "Fascist" + elif self.role in {"LW", "LL"}: + return "Liberal" + else: + return None diff --git a/src/db.py b/src/db.py index a1987da..c6affef 100644 --- a/src/db.py +++ b/src/db.py @@ -7,16 +7,16 @@ async def get_db() -> aiosqlite.Connection: if not getattr(get_db, "db", None): - db = await aiosqlite.connect(config.SQLITE_DB_FILE_PATH, - timeout=60 * 60 * 24 * 1 # 1 day - ) + db = await aiosqlite.connect( + config.SQLITE_DB_FILE_PATH, timeout=60 * 60 * 24 * 1 # 1 day + ) get_db.db = db return get_db.db async def fetch_all( - sql: LiteralString, params: Iterable[Any] | None = None + sql: LiteralString, params: Iterable[Any] | None = None ) -> list[dict]: cursor = await _get_cursor(sql, params) rows = await cursor.fetchall() @@ -26,7 +26,7 @@ async def fetch_all( async def fetch_one( - sql: LiteralString, params: Iterable[Any] | None = None + sql: LiteralString, params: Iterable[Any] | None = None ) -> dict | None: cursor = await _get_cursor(sql, params) row = await cursor.fetchone() @@ -36,7 +36,7 @@ async def fetch_one( async def execute( - sql: LiteralString, params: Iterable[Any] | None = None, *, autocommit: bool = True + sql: LiteralString, params: Iterable[Any] | None = None, *, autocommit: bool = True ) -> None: db = await get_db() await db.execute(sql, params) @@ -53,7 +53,7 @@ async def _async_close_db() -> None: async def _get_cursor( - sql: LiteralString, params: Iterable[Any] | None + sql: LiteralString, params: Iterable[Any] | None ) -> aiosqlite.Cursor: db = await get_db() db.row_factory = aiosqlite.Row diff --git a/src/handlers/game.py b/src/handlers/game.py index ad942ba..764a3b7 100644 --- a/src/handlers/game.py +++ b/src/handlers/game.py @@ -1,9 +1,9 @@ -from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram import Update from telegram.ext import ContextTypes + from src import config from src.data_models.Playroom import Playroom from src.services.db_service import save_playroom -from src.utils import message_is_poll, is_message_from_group_chat async def game(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -21,7 +21,7 @@ async def game(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: ) # Save some info about the poll the bot_data for later use in receive_poll_answer - game_metadata = { + game_metadata = { # TODO write it to DB message.poll.id: { "questions": questions, "message_id": message.id, # will be game_id diff --git a/src/handlers/save.py b/src/handlers/save.py index 93892eb..96e8769 100644 --- a/src/handlers/save.py +++ b/src/handlers/save.py @@ -5,9 +5,9 @@ from src.data_models.Game import Game from src.data_models.Record import Record -from src.utils import message_is_poll, is_message_from_group_chat -from src import db from src.services.db_service import save_record, save_game +from src.services.draw_result_image import draw_result_image +from src.utils import message_is_poll, is_message_from_group_chat async def _pass_checks( @@ -67,41 +67,51 @@ async def save(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: msg_with_poll = ( update.effective_message.reply_to_message ) # get a poll from reply message - if await _pass_checks(msg_with_poll, update, context): + if await _pass_checks(msg_with_poll=msg_with_poll, update=update, context=context): await context.bot.stop_poll(update.effective_chat.id, msg_with_poll.id) poll_data = context.bot_data[msg_with_poll.poll.id] - await asyncio.gather( - *[ - save_record( - Record( - creator_id=poll_data["creator_id"], - player_id=player_id, - playroom_id=poll_data["chat_id"], - game_id=poll_data["message_id"], - role=result, - ), - ) - for player_id, result in poll_data["results"].items() - ] - ) + records = [ + Record( + creator_id=poll_data["creator_id"], + player_id=player_id, + playroom_id=poll_data["chat_id"], + game_id=poll_data["message_id"], + role=result, + ) + for player_id, result in poll_data["results"].items() + ] + # await asyncio.gather(*[save_record(record) for record in records]) game = Game( poll_id=poll_data["message_id"], chat_id=poll_data["chat_id"], creator_id=poll_data["creator_id"], results=poll_data["results"].copy(), ) - await save_game(game) - await update.effective_message.reply_text( - "The Game has been saved!. Results: {}".format(game.results) - ) - # Delete the poll - await context.bot.delete_message( - chat_id=game.chat_id, - message_id=game.poll_id + # post-game tasks + await asyncio.gather( + *[ + *[save_record(record) for record in records], + save_game(game), + context.bot.delete_message( + chat_id=game.chat_id, message_id=game.poll_id + ), + update.effective_message.delete(), + context.bot.send_photo( + photo=( + await draw_result_image( + records=records, + result=game.results, + update=update, + context=context, + ) + ), + chat_id=game.chat_id, + caption="The Game has been saved!", + disable_notification=True, + ), + ] ) - # Delete this callback /save message - await update.effective_message.delete() else: await update.effective_message.reply_text( "Something went wrong. Can't process your request." diff --git a/src/services/db_service.py b/src/services/db_service.py index 5d9a422..c01f984 100644 --- a/src/services/db_service.py +++ b/src/services/db_service.py @@ -1,12 +1,6 @@ import logging import sqlite3 -from src.data_models.Playroom import Playroom -import logging -import sqlite3 - -from src.data_models.Game import Game -from src.data_models.Player import Player from src.data_models.Game import Game from src.data_models.Player import Player from src.data_models.Playroom import Playroom diff --git a/src/services/draw_result_image.py b/src/services/draw_result_image.py new file mode 100644 index 0000000..95883c3 --- /dev/null +++ b/src/services/draw_result_image.py @@ -0,0 +1,229 @@ +import asyncio +from xml.dom.minidom import parseString +from xml.etree.ElementTree import Element, SubElement, tostring + +import cairosvg +from telegram import Update +from telegram.ext import ContextTypes + +from src import db +from src.data_models.Record import Record + +LIBERAL_COLOR = "#61C8D9" +LIBERAL_COLOR_STROKE = "#38586D" +FASCIST_COLOR = "#E66443" +FASCIST_COLOR_STROKE = "#7A1E16" + +STROKE_SIZE = str(12) + + +def save_svg(svg_string, file_path): + with open(file_path, "w") as file: + file.write(svg_string) + + +def svg2png(svg_string) -> bytes: + return cairosvg.svg2png( + bytestring=svg_string, + scale=1, + output_width=1338 // 2, + output_height=926 // 2, + unsafe=True, + ) + + +def create_background(color): + svg = Element( + "svg", + width="1338", + height="926", + viewBox="0 0 1338 926", + fill=color, + rx="6", + id="background", + xmlns="http://www.w3.org/2000/svg", + ) + # background = SubElement(svg, 'rect', id="Shape", width="1338", height="926", rx="6", fill=color) + return svg + + +def create_board(svg, board_type, players): + if board_type == "Fascist": + board = SubElement(svg, "g", id="F_BOARD_GROUP") + x, y, width, height = "56", "518", "1226", "360" + fill, stroke, stroke_width = FASCIST_COLOR, FASCIST_COLOR_STROKE, STROKE_SIZE + else: # Liberal + board = SubElement(svg, "g", id="L_BOARD_GROUP") + x, y, width, height = "56", "48", "1226", "360" + fill, stroke, stroke_width = LIBERAL_COLOR, LIBERAL_COLOR_STROKE, STROKE_SIZE + + board_rect = SubElement( + board, + "rect", + id=f"{board_type.upper()}_BOARD", + x=x, + y=y, + width=width, + height=height, + rx="6", + fill=fill, + stroke=stroke, + stroke_width=stroke_width, + ) + + name_rect = SubElement( + board, + "rect", + id=f"{board_type.upper()}_BOARD_NAME", + x=x, + y=y, + width=width, + height="110", + fill=fill, + stroke=stroke, + stroke_width=str(int(stroke_width) / 2), + fill_opacity="0.6", + ) + board_name_text = SubElement( + board, + "text", + x=str(int(x) + int(width) / 2), + y=str(int(y) + 55), + font_size="60", + fill="black", + text_anchor="middle", + dominant_baseline="middle", + ) + board_name_text.text = board_type + player_x = ( + int(x) + int(width) // 2 - (len(players) * 170 + (len(players) - 1) * 30) // 2 + ) + for i, player in enumerate(players): + user_pic = SubElement( + board, + "image", + href=player["user_profile_photo"], + x=str(player_x), + y=str(int(y) + 117), + stroke=stroke, + stroke_width=str(int(stroke_width) // 2), + height="170", + width="170", + ) + username_rect = SubElement( + board, + "rect", + id=f"username_{i}", + x=str(player_x), + y=str(int(y) + 288), + stroke=stroke, + stroke_width=str(int(stroke_width) // 2), + width="170", + height="55", + fill="white", + ) + username_text = SubElement( + board, + "text", + x=str(player_x + 170 / 2), + y=str(int(y) + 288 + 55 / 2), + font_size="20", + fill="black", + text_anchor="middle", + dominant_baseline="middle", + ) + + username_text.text = player["name"] + player_x += 200 # Width 170 + 30 space + + +def create_result(svg, outcome): + x, y, width, height = "56", "414", "1226", "98" + result = SubElement( + svg, + "rect", + id="RESULT", + x=x, + y=y, + width=width, + height=height, + fill="white", + fill_opacity="0.3", + ) + result_text = SubElement( + svg, + "text", + x=str(int(x) + int(width) / 2), + y=str(int(y) + 55), + fill="black", + font_size="60", + text_anchor="middle", + dominant_baseline="middle", + ) + result_text.text = outcome + + +def fix_python_wrong_svg_string(svg_string): + """Fixes SVG string for compatibility and prettifies it.""" + replacements = { + "text_anchor": "text-anchor", + "dominant_baseline": "dominant-baseline", + "font_size": "font-size", + "stroke_width": "stroke-width", + } + for wrong, correct in replacements.items(): + svg_string = svg_string.replace(wrong, correct) + return parseString(svg_string).toprettyxml() + + +def draw_game_result(players: tuple[dict], outcome: str): + fascist_players = tuple(player for player in players if player["team"] == "Fascist") + liberal_players = tuple(player for player in players if player["team"] == "Liberal") + + svg = create_background("#363835") + create_board(svg=svg, board_type="Liberal", players=liberal_players) + create_board(svg=svg, board_type="Fascist", players=fascist_players) + create_result(svg=svg, outcome=outcome) + + svg = parseString(tostring(svg)) + svg = svg.toprettyxml() + svg = fix_python_wrong_svg_string(svg) + return svg + + +async def get_user_profile_photo(context, player_id) -> str: + photo_objects = ( + await context.bot.get_user_profile_photos(player_id, limit=1, offset=0) + ).photos[0] + return (await context.bot.get_file(photo_objects[-1])).file_path + + +async def get_player(context, record: Record): + """Get the player from the record""" + return { + "name": ( + await db.fetch_one( + """SELECT full_name FROM players WHERE id = (?)""", [record.player_id] + ) + )["full_name"], + "role": record.role, + "team": record.get_team(), + "user_profile_photo": await get_user_profile_photo(context, record.player_id), + } + + +async def draw_result_image( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + records: list[Record], + result: str, +) -> bytes: + """Send the result of the game to the chat""" + # player = namedtuple("Player", ["name", "role", "team", "user_profile_photo"]) + players = tuple( + await asyncio.gather( + *[get_player(context=context, record=record) for record in records] + ) + ) + svg = draw_game_result(players, result) + return svg2png(svg)