Skip to content

Commit

Permalink
load queries from a directory
Browse files Browse the repository at this point in the history
  • Loading branch information
imryche committed Jul 21, 2024
1 parent 15e80e6 commit 05c1320
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 9 deletions.
2 changes: 1 addition & 1 deletion litequery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from litequery.core import setup


__version__ = "0.1.0"
__version__ = "0.2.0"
__all__ = ["setup"]
17 changes: 15 additions & 2 deletions litequery/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import glob
import os
from enum import Enum
import re
from contextlib import asynccontextmanager
Expand All @@ -21,8 +23,8 @@ class Query:
op: Op = Op.SELECT


def parse_queries(path):
with open(path) as f:
def parse_file_queries(file_path):
with open(file_path) as f:
content = f.read()
raw_queries = re.findall(r"-- name: (.+)\n([\s\S]*?);", content)

Expand All @@ -40,7 +42,18 @@ def parse_queries(path):
args = re.findall(r":(\w+)", sql)
query = Query(name=query_name, sql=sql, args=args, op=op)
queries.append(query)
return queries


def parse_queries(path):
queries = []
if os.path.isdir(path):
for file_path in glob.glob(os.path.join(path, "*.sql")):
queries.extend(parse_file_queries(file_path))
elif os.path.isfile(path):
queries.extend(parse_file_queries(path))
else:
raise ValueError(f"Path {path} is neither a file nor a directory.")
return queries


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "litequery"
version = "0.1.0"
version = "0.2.0"
authors = [{ name = "Dima Charnyshou", email = "[email protected]" }]
description = "A handy way to interact with an SQLite database from Python"
readme = "README.md"
Expand Down
6 changes: 6 additions & 0 deletions tests/queries/events.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- name: get_all_events
select
*
from
events;

29 changes: 29 additions & 0 deletions tests/queries/users.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
-- name: get_all_users
select
*
from
users;

-- name: get_user_by_id^
select
*
from
users
where
id = :id;

-- name: get_last_user_id$
select
id
from
users
order by
id desc;

-- name: insert_user<!
insert into users (name, email)
values (:name, :email);

-- name: delete_all_users!
delete from users;

48 changes: 43 additions & 5 deletions tests/test_litequery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,47 @@
import litequery
import pytest_asyncio

from litequery.core import parse_queries

DATABASE_PATH = "users.db"
QUERIES_PATH = "tests/queries.sql"
QUERIES_FILE_PATH = "tests/queries.sql"
QUERIES_DIR_PATH = "tests/queries/"


@pytest_asyncio.fixture
async def setup_database():
async with aiosqlite.connect(DATABASE_PATH) as conn:
await conn.execute(
"create table users (id integer primary key autoincrement, name text not null, email text not null)"
"""
create table users (
id integer primary key autoincrement,
name text not null,
email text not null
)
"""
)
await conn.execute(
"""
create table events (
id integer primary key autoincrement,
user_id integer not null,
name text not null,
created_at datetime not null default current_timestamp,
foreign key (user_id) references users (id)
)
"""
)
await conn.execute(
"insert into users (id, name, email) values (1, 'Alice', '[email protected]')"
)
await conn.execute(
"insert into users (id, name, email) values (2, 'Bob', '[email protected]')"
)
await conn.execute(
"insert into users (name, email) values ('Alice', '[email protected]')"
"insert into events (user_id, name) values (1, 'user_logged_in')"
)
await conn.execute(
"insert into users (name, email) values ('Bob', '[email protected]')"
"insert into events (user_id, name) values (2, 'password_changed')"
)
await conn.commit()
yield
Expand All @@ -27,12 +53,24 @@ async def setup_database():

@pytest_asyncio.fixture
async def lq(setup_database):
lq = litequery.setup(DATABASE_PATH, QUERIES_PATH)
lq = litequery.setup(DATABASE_PATH, QUERIES_DIR_PATH)
await lq.connect()
yield lq
await lq.disconnect()


@pytest.mark.asyncio
async def test_parse_queries_from_file():
queries = parse_queries(QUERIES_FILE_PATH)
assert len(queries) == 5


@pytest.mark.asyncio
async def test_parse_queries_from_directory():
queries = parse_queries(QUERIES_DIR_PATH)
assert len(queries) == 6


@pytest.mark.asyncio
async def test_get_all_users(lq):
users = await lq.get_all_users()
Expand Down

0 comments on commit 05c1320

Please sign in to comment.