Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1073 Add update_self method #1081

Merged
merged 6 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions piccolo/query/methods/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

if t.TYPE_CHECKING: # pragma: no cover
from piccolo.columns import Column
from piccolo.table import Table


###############################################################################
Expand Down Expand Up @@ -173,6 +174,93 @@ def run_sync(self, *args, **kwargs) -> TableInstance:
return run_sync(self.run(*args, **kwargs))


class UpdateSelf:
"""
This allows the user to update a single object - useful when the values
are derived from the database in some way.

For example, if we have the following table::

class Concert(Table):
name = Varchar(unique=True)
tickets_available = Integer()

And we fetch an object::

>>> concert = await Concert.objects().get(name="Amazing concert")

We could use the typical syntax for updating the object::

>>> concert.tickets_available += -1
>>> await concert.save()
dantownsend marked this conversation as resolved.
Show resolved Hide resolved

The problem with this, is what if another object has already decremented
``tickets_available``? It would overide the value.

Instead we can do this:

>>> await concert.update_self({
... Concert.tickets_available: Concert.tickets_available - 1
... })

This updates ``tickets_available`` in the database, and also sets the
new value for ``tickets_available`` on the object.

"""

def __init__(
self,
row: Table,
values: t.Dict[t.Union[Column, str], t.Any],
):
self.row = row
self.values = values

async def run(
self,
node: t.Optional[str] = None,
in_pool: bool = True,
) -> None:
if not self.row._exists_in_db:
raise ValueError("This row doesn't exist in the database.")

TableClass = self.row.__class__

primary_key = TableClass._meta.primary_key
primary_key_value = getattr(self.row, primary_key._meta.name)

if primary_key_value is None:
raise ValueError("The primary key is None")

columns = [
TableClass._meta.get_column_by_name(i) if isinstance(i, str) else i
for i in self.values.keys()
]

response = (
await TableClass.update(self.values)
.where(primary_key == primary_key_value)
.returning(*columns)
.run(
node=node,
in_pool=in_pool,
)
)

for key, value in response[0].items():
setattr(self.row, key, value)

def __await__(self) -> t.Generator[None, None, None]:
"""
If the user doesn't explicity call .run(), proxy to it as a
convenience.
"""
return self.run().__await__()

def run_sync(self, *args, **kwargs) -> None:
return run_sync(self.run(*args, **kwargs))


###############################################################################


Expand Down
7 changes: 6 additions & 1 deletion piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from piccolo.query.methods.create_index import CreateIndex
from piccolo.query.methods.indexes import Indexes
from piccolo.query.methods.objects import First
from piccolo.query.methods.objects import First, UpdateSelf
from piccolo.query.methods.refresh import Refresh
from piccolo.querystring import QueryString
from piccolo.utils import _camel_to_snake
Expand Down Expand Up @@ -525,6 +525,11 @@ def save(
== getattr(self, self._meta.primary_key._meta.name)
)

def update_self(
self, values: t.Dict[t.Union[Column, str], t.Any]
) -> UpdateSelf:
return UpdateSelf(row=self, values=values)

def remove(self) -> Delete:
"""
A proxy to a delete query.
Expand Down
27 changes: 27 additions & 0 deletions tests/table/test_update_self.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from piccolo.testing.test_case import AsyncTableTest
from tests.example_apps.music.tables import Band, Manager


class TestUpdateSelf(AsyncTableTest):

tables = [Band, Manager]

async def test_update_self(self):
band = Band({Band.name: "Pythonistas", Band.popularity: 1000})

# Make sure we get a ValueError if it's not in the database yet.
with self.assertRaises(ValueError):
await band.update_self({Band.popularity: Band.popularity + 1})

# Save it, so it's in the database
await band.save()

# Make sure we can successfully update the object
await band.update_self({Band.popularity: Band.popularity + 1})

# Make sure the value was updated on the object
assert band.popularity == 1001

# Make sure the value was updated in the database
await band.refresh()
assert band.popularity == 1001
Loading