Skip to content

Commit

Permalink
add support for transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
imryche committed Jul 5, 2024
1 parent 74c6cae commit b6ed89d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
2 changes: 1 addition & 1 deletion litequery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from litequery.core import setup


__version__ = "0.0.1"
__version__ = "0.1.0"
__all__ = ["setup"]
23 changes: 21 additions & 2 deletions litequery/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum
import re
from contextlib import asynccontextmanager
from dataclasses import dataclass, make_dataclass, fields
import aiosqlite

Expand Down Expand Up @@ -58,6 +59,7 @@ class Litequery:
def __init__(self, database, queries):
self._database = database
self._conn = None
self._in_transaction = False
self._create_methods(queries)

def _create_method(self, query: Query):
Expand All @@ -72,10 +74,12 @@ async def query_method(**kwargs):
row = await cur.fetchone()
return getattr(row, fields(row)[0].name) if row else None
if query.op == Op.MODIFY:
await conn.commit()
if not self._in_transaction:
await conn.commit()
return cur.rowcount
if query.op == Op.INSERT_RETURNING:
await conn.commit()
if not self._in_transaction:
await conn.commit()
return cur.lastrowid

return query_method
Expand All @@ -84,6 +88,20 @@ def _create_methods(self, queries: list[Query]):
for query in queries:
setattr(self, query.name, self._create_method(query))

@asynccontextmanager
async def transaction(self):
conn = await self.get_connection()
try:
await conn.execute("begin")
self._in_transaction = True
yield
await conn.commit()
except Exception:
await conn.rollback()
raise
finally:
self._in_transaction = False

async def connect(self):
self._conn = await aiosqlite.connect(self._database)
self._conn.row_factory = dataclass_factory
Expand All @@ -92,6 +110,7 @@ async def disconnect(self):
if self._conn is None:
return
await self._conn.close()
self._conn = None

async def get_connection(self):
if self._conn is None:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_litequery.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,27 @@ async def test_delete_all_users(lq):
assert rowcount == 2
users = await lq.get_all_users()
assert len(users) == 0


@pytest.mark.asyncio
async def test_transaction_commit(lq):
async with lq.transaction():
await lq.insert_user(name="Charlie", email="[email protected]")
await lq.insert_user(name="Eve", email="[email protected]")

users = await lq.get_all_users()
assert len(users) == 4
assert users[2].name, users[2].email == ("Charlie", "[email protected]")
assert users[3].name, users[3].email == ("Eve", "[email protected]")


@pytest.mark.asyncio
async def test_transaction_rollback(lq):
with pytest.raises(Exception):
async with lq.transaction():
await lq.insert_user(name="Charlie", email="[email protected]")
raise Exception("Force rollback")
await lq.insert_user(name="Eve", email="[email protected]")

users = await lq.get_all_users()
assert len(users) == 2

0 comments on commit b6ed89d

Please sign in to comment.