diff --git a/litequery/core.py b/litequery/core.py index 632bcf9..552720f 100644 --- a/litequery/core.py +++ b/litequery/core.py @@ -1,45 +1,78 @@ +from enum import Enum, auto import re -from dataclasses import make_dataclass +from dataclasses import dataclass, make_dataclass import aiosqlite +class Op(Enum): + SELECT = auto() + INSERT = auto() + + +OP_TYPES = { + "": Op.SELECT, + "!": Op.INSERT, +} + + +@dataclass +class Query: + name: str + sql: str + args: list + op: Op = Op.SELECT + + def parse_queries(path): with open(path) as f: content = f.read() - return dict(re.findall(r"-- name: (\w+)\n(.*?);", content, re.DOTALL)) + raw_queries = re.findall(r"-- name: (.+)\n([\s\S]*?);", content) + queries = [] + for query_name, sql in raw_queries: + match = re.match(r"^([a-z_][a-z0-9_-]*)([!]?)$", query_name) + if not match: + raise NameError(f'Invalid query name: "{query_name}"') + query_name = match.group(1) + op_symbol = match.group(2) + op = OP_TYPES.get(op_symbol, Op.SELECT) + + args = re.findall(r":(\w+)", sql) + query = Query(name=query_name, sql=sql, args=args, op=op) + queries.append(query) + return queries def create(database, queries_path): queries = parse_queries(queries_path) - return LiteQuery(database, queries) + return Litequery(database, queries) def dataclass_factory(cursor, row): fields = [col[0] for col in cursor.description] - cls = make_dataclass("Row", fields) + cls = make_dataclass("Record", fields) return cls(*row) -class LiteQuery: +class Litequery: def __init__(self, database, queries): self._database = database - self._queries = queries self._conn = None - self._create_methods() + self._create_methods(queries) - def _create_method(self, name): - async def query_method(): - query = self._queries[name] + def _create_method(self, query: Query): + async def query_method(**kwargs): conn = await self.conn() - async with conn.execute(query) as cursor: - rows = await cursor.fetchall() - return rows + async with conn.execute(query.sql, kwargs) as cur: + if query.op == Op.SELECT: + return await cur.fetchall() + if query.op == Op.INSERT: + return cur.rowcount return query_method - def _create_methods(self): - for name, query in self._queries.items(): - setattr(self, name, self._create_method(name)) + def _create_methods(self, queries: list[Query]): + for query in queries: + setattr(self, query.name, self._create_method(query)) async def connect(self): self._conn = await aiosqlite.connect(self._database) diff --git a/queries.sql b/queries.sql index c012753..7e674c1 100644 --- a/queries.sql +++ b/queries.sql @@ -13,3 +13,10 @@ order by id desc limit 1; +-- name: users_insert! +insert into users (name) + values (:name); + +-- name: users_delete_all +delete from users; + diff --git a/tests/test_litequery.py b/tests/test_litequery.py index e29aa14..7ac2d69 100644 --- a/tests/test_litequery.py +++ b/tests/test_litequery.py @@ -3,9 +3,14 @@ @pytest.mark.asyncio -async def test_idea(): +async def test_queries(): lq = litequery.create("users.db", "queries.sql") await lq.connect() + print(await lq.users_delete_all()) + print(await lq.users_insert(name="kocia")) + print(await lq.users_insert(name="kot")) + print(await lq.users_insert(name="simba")) + print(await lq.users_insert(name="bunchik")) print(await lq.users_all()) print(await lq.users_first()) await lq.disconnect() diff --git a/users.db b/users.db index 4bb941c..e4401fb 100644 Binary files a/users.db and b/users.db differ