Skip to content

Commit

Permalink
perf(crud): combined create & create_multi & reduced the query count …
Browse files Browse the repository at this point in the history
…from 2 to 1
  • Loading branch information
jd-solanki committed Dec 16, 2024
1 parent bb7dd15 commit 9c21c5e
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 34 deletions.
35 changes: 35 additions & 0 deletions .vscode/docstring-template-google.mustach
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{{! Google Docstring Template }}
{{summaryPlaceholder}}

{{extendedSummaryPlaceholder}}
{{#parametersExist}}

Args:
{{#args}}
{{var}}: {{descriptionPlaceholder}}
{{/args}}
{{#kwargs}}
{{var}}: {{descriptionPlaceholder}}.
{{/kwargs}}
{{/parametersExist}}
{{#exceptionsExist}}

Raises:
{{#exceptions}}
{{type}}: {{descriptionPlaceholder}}
{{/exceptions}}
{{/exceptionsExist}}
{{#returnsExist}}

Returns:
{{#returns}}
{{descriptionPlaceholder}}
{{/returns}}
{{/returnsExist}}
{{#yieldsExist}}

Yields:
{{#yields}}
{{descriptionPlaceholder}}
{{/yields}}
{{/yieldsExist}}
9 changes: 7 additions & 2 deletions examples/crud.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Depends, FastAPI
from pydantic import BaseModel, PositiveInt
from pydantic import BaseModel, PositiveInt, RootModel
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
Expand Down Expand Up @@ -80,6 +80,11 @@ async def create_user(user: UserCreate, db: Annotated[AsyncSession, Depends(get_
return await user_crud.create(db, user)


@app.post("/users/multi")
async def create_users(users: RootModel[Sequence[UserCreate]], db: Annotated[AsyncSession, Depends(get_db)]):
return await user_crud.create(db, users)


@app.get("/users/")
async def get_users(db: Annotated[AsyncSession, Depends(get_db)]):
return await user_crud.get_multi(db)
Expand Down
114 changes: 82 additions & 32 deletions src/fastapi_batteries/crud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections.abc import Callable, Sequence
from logging import Logger
from typing import Any
from typing import Any, Literal, overload

from fastapi import HTTPException, status
from fastapi import status
from pydantic import BaseModel, RootModel
from sqlalchemy import Select, delete, select
from sqlalchemy import ScalarResult, Select, delete, insert, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
Expand Down Expand Up @@ -36,50 +36,100 @@ def __init__(
}
self.logger = logger

# TODO: Add proper types for refresh_kwargs
@overload
async def create(
self,
db: AsyncSession,
new_item: SchemaCreate,
new_data: SchemaCreate,
*,
refresh_kwargs: dict[str, Any] | None = None,
commit: bool = True,
) -> ModelType:
# ! Don't use `jsonable_encoder`` because it can cause issue like converting datetime to string.
# Converting date to string will cause error when inserting to database.
item_db = self.model(**new_item.model_dump())
db.add(item_db)
returning: Literal[True] = True,
) -> ModelType: ...

if commit:
await db.commit()
await db.refresh(item_db, **(refresh_kwargs or {}))
elif refresh_kwargs and self.logger:
self.logger.warning("refresh_kwargs is ignored because commit is False")
@overload
async def create(
self,
db: AsyncSession,
new_data: SchemaCreate,
*,
commit: bool = True,
returning: Literal[False],
) -> None: ...

return item_db
# ---

# TODO: Check how many insert statements gets executed when inserting multiple items.
async def create_multi(
@overload
async def create(
self,
db: AsyncSession,
new_items: Sequence[SchemaCreate],
new_data: RootModel[Sequence[SchemaCreate]],
*,
refresh_kwargs: dict[str, Any] | None = None,
commit: bool = True,
) -> Sequence[ModelType]:
items_db = [self.model(**new_item.model_dump()) for new_item in new_items]
db.add_all(items_db)
returning: Literal[True] = True,
) -> Sequence[ModelType]: ...

if commit:
await db.commit()
@overload
async def create(
self,
db: AsyncSession,
new_data: RootModel[Sequence[SchemaCreate]],
*,
commit: bool = True,
returning: Literal[False],
) -> None: ...

async def create(
self,
db: AsyncSession,
new_data: SchemaCreate | RootModel[Sequence[SchemaCreate]],
*,
commit: bool = True,
returning: bool = True,
) -> Sequence[ModelType] | ModelType | None:
"""Create single or multiple items using insert statement.
Args:
db: SQLAlchemy AsyncSession
new_data: New data to insert in the database
commit: Whether to commit the transaction
returning: Whether to return the inserted item(s) via `returning` clause
Returns:
Inserted item(s) if `returning` is True else None
"""
# ! Don't use `jsonable_encoder`` because it can cause issue like converting datetime to string.
# Converting date to string will cause error when inserting to database.
statement = insert(self.model).values(new_data.model_dump())

if returning:
statement = statement.returning(self.model)

# If multiple items are provided use `scalars` else use `scalar`
if isinstance(new_data, RootModel):
result = await db.scalars(statement)
else:
result = await db.scalar(statement)

if commit:
await db.commit()

# TODO: Improve this to perform all await db.refresh in simultaneously
for item_db in items_db:
await db.refresh(item_db, **(refresh_kwargs or {}))
elif refresh_kwargs and self.logger:
self.logger.warning("refresh_kwargs is ignored because commit is False")
"""
If result is `ScalarResult` then we need to call `all()` to get the list of items.
If result is not `ScalarResult` then we can return the result as it is.
return items_db
NOTE: We can determine same thing via `isinstance(new_data, RootModel)`
but mypy won't be aware of result type.
"""
if isinstance(result, ScalarResult):
return result.all()
return result

# If returning is False
await db.execute(statement)
if commit:
await db.commit()
return None

async def get(
self,
Expand Down

0 comments on commit 9c21c5e

Please sign in to comment.