diff --git a/examples/crud.py b/examples/crud.py index 769c311..4bc45f2 100644 --- a/examples/crud.py +++ b/examples/crud.py @@ -2,15 +2,16 @@ from contextlib import asynccontextmanager from typing import Annotated -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, status from pydantic import BaseModel, PositiveInt, RootModel from sqlalchemy import String, select +from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column from fastapi_batteries.crud import CRUD from fastapi_batteries.fastapi.exceptions import get_api_exception_handler -from fastapi_batteries.fastapi.exceptions.api_exception import APIException +from fastapi_batteries.fastapi.exceptions.api_exception import APIException, get_api_exception_handler from fastapi_batteries.fastapi.middlewares import QueryCountMiddleware from fastapi_batteries.pydantic.schemas import Paginated, PaginationOffsetLimit from fastapi_batteries.sa.mixins import MixinId @@ -131,6 +132,45 @@ async def get_users( } +@app.get("/users/count") +async def get_users_count( + db: Annotated[AsyncSession, Depends(get_db)], + first_name: str = "", + first_name__contains: str = "", +): + select_statement = select(User) + if first_name: + select_statement = select_statement.where(User.first_name == first_name) + if first_name__contains: + select_statement = select_statement.where(User.first_name.contains(first_name__contains)) + + return await user_crud.count(db, select_statement=select_statement) + + +@app.get("/users/one") +async def get_one_user( + db: Annotated[AsyncSession, Depends(get_db)], + user_id: PositiveInt | None = None, + first_name: str = "", + first_name__contains: str = "", +): + select_statement = select(User) + if user_id: + select_statement = select_statement.where(User.id == user_id) + if first_name: + select_statement = select_statement.where(User.first_name == first_name) + if first_name__contains: + select_statement = select_statement.where(User.first_name.contains(first_name__contains)) + + try: + return await user_crud.get_one_or_none(db, select_statement=lambda _: select_statement) + except MultipleResultsFound as e: + raise APIException( + title="Multiple results found", + status=status.HTTP_400_BAD_REQUEST, + ) from e + + @app.get("/users/{user_id}") async def get_user(user_id: PositiveInt, db: Annotated[AsyncSession, Depends(get_db)]): return await user_crud.get_or_404(db, user_id) diff --git a/src/fastapi_batteries/crud/__init__.py b/src/fastapi_batteries/crud/__init__.py index 75ef6b6..5463f6e 100644 --- a/src/fastapi_batteries/crud/__init__.py +++ b/src/fastapi_batteries/crud/__init__.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Sequence +from contextlib import suppress from logging import Logger from typing import Any, Literal, overload @@ -6,6 +7,7 @@ from pydantic import BaseModel, RootModel from sqlalchemy import ScalarResult, Select, delete, func, insert, select from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase @@ -166,6 +168,7 @@ async def get_or_404( title=msg_404 or self.err_messages[404], ) + # TODO: Instead of all columns, fetch specific columns # TODO: Add overload for pagination conditional return async def get_multi( self, @@ -174,10 +177,11 @@ async def get_multi( pagination: PaginationPageSize | PaginationOffsetLimit | None = None, select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s, ) -> Sequence[ModelType] | tuple[Sequence[ModelType], int]: + # --- Initialize statements _select_statement = select_statement(select(self.model)) paginated_statement: Select[tuple[ModelType]] | None = None - # Pagination + # --- Pagination if pagination: if isinstance(pagination, PaginationPageSize): offset, limit = page_size_to_offset_limit(page=pagination.page, size=pagination.size) @@ -186,16 +190,47 @@ async def get_multi( paginated_statement = _select_statement.limit(limit).offset(offset) + # --- Fetch records result = await db.scalars( paginated_statement if paginated_statement is not None else _select_statement, ) records = result.unique().all() + # --- Return records if pagination: total = await self.count(db, select_statement=_select_statement) return records, total return records + async def get_one_or_none( + self, + db: AsyncSession, + *, + select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s, + suppress_multiple_result_exc: bool = False, + ): + """Get one item or None based on select statement. + + Args: + db: SQLAlchemy AsyncSession + select_statement: Select statement to fetch the item + suppress_multiple_result_exc: Whether to suppress `MultipleResultsFound` exception + + Returns: + Queried item or None + + Raises: + MultipleResultsFound: If multiple results are found and `suppress_multiple_result_exc` is False + + """ + result = await db.scalars(select_statement(select(self.model))) + + try: + return result.unique().one_or_none() + except MultipleResultsFound: + if not suppress_multiple_result_exc: + raise + # TODO: Can we fetch TypedDict from SchemaPatch? Using `dict[str, Any]` is not good. async def patch( self, @@ -262,6 +297,7 @@ async def delete(self, db: AsyncSession, item_id: int, *, commit: bool = True) - return result.rowcount + # TODO: Use callable for select_statement like other methods async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelType]] | None = None) -> int: count_select_from = select_statement.subquery() if select_statement is not None else self.model count_statement = select(func.count()).select_from(count_select_from)