Skip to content

Commit

Permalink
wip(crud): added get_one_or_none method
Browse files Browse the repository at this point in the history
  • Loading branch information
jd-solanki committed Dec 16, 2024
1 parent 900adb4 commit 16dc2fd
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 3 deletions.
44 changes: 42 additions & 2 deletions examples/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, status
from pydantic import BaseModel, PositiveInt, RootModel
from sqlalchemy import String, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

from fastapi_batteries.crud import CRUD
from fastapi_batteries.fastapi.exceptions import get_api_exception_handler
from fastapi_batteries.fastapi.exceptions.api_exception import APIException
from fastapi_batteries.fastapi.exceptions.api_exception import APIException, get_api_exception_handler
from fastapi_batteries.fastapi.middlewares import QueryCountMiddleware
from fastapi_batteries.pydantic.schemas import Paginated, PaginationOffsetLimit
from fastapi_batteries.sa.mixins import MixinId
Expand Down Expand Up @@ -131,6 +132,45 @@ async def get_users(
}


@app.get("/users/count")
async def get_users_count(
db: Annotated[AsyncSession, Depends(get_db)],
first_name: str = "",
first_name__contains: str = "",
):
select_statement = select(User)
if first_name:
select_statement = select_statement.where(User.first_name == first_name)
if first_name__contains:
select_statement = select_statement.where(User.first_name.contains(first_name__contains))

return await user_crud.count(db, select_statement=select_statement)


@app.get("/users/one")
async def get_one_user(
db: Annotated[AsyncSession, Depends(get_db)],
user_id: PositiveInt | None = None,
first_name: str = "",
first_name__contains: str = "",
):
select_statement = select(User)
if user_id:
select_statement = select_statement.where(User.id == user_id)
if first_name:
select_statement = select_statement.where(User.first_name == first_name)
if first_name__contains:
select_statement = select_statement.where(User.first_name.contains(first_name__contains))

try:
return await user_crud.get_one_or_none(db, select_statement=lambda _: select_statement)
except MultipleResultsFound as e:
raise APIException(
title="Multiple results found",
status=status.HTTP_400_BAD_REQUEST,
) from e


@app.get("/users/{user_id}")
async def get_user(user_id: PositiveInt, db: Annotated[AsyncSession, Depends(get_db)]):
return await user_crud.get_or_404(db, user_id)
Expand Down
38 changes: 37 additions & 1 deletion src/fastapi_batteries/crud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections.abc import Callable, Sequence
from contextlib import suppress
from logging import Logger
from typing import Any, Literal, overload

from fastapi import status
from pydantic import BaseModel, RootModel
from sqlalchemy import ScalarResult, Select, delete, func, insert, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase

Expand Down Expand Up @@ -166,6 +168,7 @@ async def get_or_404(
title=msg_404 or self.err_messages[404],
)

# TODO: Instead of all columns, fetch specific columns
# TODO: Add overload for pagination conditional return
async def get_multi(
self,
Expand All @@ -174,10 +177,11 @@ async def get_multi(
pagination: PaginationPageSize | PaginationOffsetLimit | None = None,
select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s,
) -> Sequence[ModelType] | tuple[Sequence[ModelType], int]:
# --- Initialize statements
_select_statement = select_statement(select(self.model))
paginated_statement: Select[tuple[ModelType]] | None = None

# Pagination
# --- Pagination
if pagination:
if isinstance(pagination, PaginationPageSize):
offset, limit = page_size_to_offset_limit(page=pagination.page, size=pagination.size)
Expand All @@ -186,16 +190,47 @@ async def get_multi(

paginated_statement = _select_statement.limit(limit).offset(offset)

# --- Fetch records
result = await db.scalars(
paginated_statement if paginated_statement is not None else _select_statement,
)
records = result.unique().all()

# --- Return records
if pagination:
total = await self.count(db, select_statement=_select_statement)
return records, total
return records

async def get_one_or_none(
self,
db: AsyncSession,
*,
select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s,
suppress_multiple_result_exc: bool = False,
):
"""Get one item or None based on select statement.
Args:
db: SQLAlchemy AsyncSession
select_statement: Select statement to fetch the item
suppress_multiple_result_exc: Whether to suppress `MultipleResultsFound` exception
Returns:
Queried item or None
Raises:
MultipleResultsFound: If multiple results are found and `suppress_multiple_result_exc` is False
"""
result = await db.scalars(select_statement(select(self.model)))

try:
return result.unique().one_or_none()
except MultipleResultsFound:
if not suppress_multiple_result_exc:
raise

# TODO: Can we fetch TypedDict from SchemaPatch? Using `dict[str, Any]` is not good.
async def patch(
self,
Expand Down Expand Up @@ -262,6 +297,7 @@ async def delete(self, db: AsyncSession, item_id: int, *, commit: bool = True) -

return result.rowcount

# TODO: Use callable for select_statement like other methods
async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelType]] | None = None) -> int:
count_select_from = select_statement.subquery() if select_statement is not None else self.model
count_statement = select(func.count()).select_from(count_select_from)
Expand Down

0 comments on commit 16dc2fd

Please sign in to comment.