diff --git a/docs/src/piccolo/query_clauses/index.rst b/docs/src/piccolo/query_clauses/index.rst index f9167ff63..60a539cfc 100644 --- a/docs/src/piccolo/query_clauses/index.rst +++ b/docs/src/piccolo/query_clauses/index.rst @@ -29,6 +29,7 @@ by modifying the return values. ./on_conflict ./output ./returning + ./lock_for .. toctree:: :maxdepth: 1 diff --git a/docs/src/piccolo/query_clauses/lock_for.rst b/docs/src/piccolo/query_clauses/lock_for.rst new file mode 100644 index 000000000..f7a855e95 --- /dev/null +++ b/docs/src/piccolo/query_clauses/lock_for.rst @@ -0,0 +1,93 @@ +.. _lock_for: + +lock_for +======== + +You can use ``lock_for`` clauses with the following queries: + +* :ref:`Objects` +* :ref:`Select` + +Returns a query that locks rows until the end of the transaction, generating a SELECT ... FOR UPDATE SQL statement or +similar with other lock strengths. + +.. note:: Postgres and CockroachDB only. + +------------------------------------------------------------------------------- + +Basic usage without parameters: + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_for() + +Equivalent to: + +.. code-block:: sql + + SELECT ... FOR UPDATE + + +lock_strength +------------- + +The parameter ``lock_strength`` controls the strength of the row lock when performing an operation in PostgreSQL. +The value can be a predefined constant from the ``LockStrength`` enum or one of the following strings (case-insensitive): + + - ``UPDATE`` (default): Acquires an exclusive lock on the selected rows, preventing other transactions from modifying or locking them until the current transaction is complete. + - ``NO KEY UPDATE`` (Postgres only): Similar to UPDATE, but allows other transactions to insert or delete rows that do not affect the primary key or unique constraints. + - ``KEY SHARE`` (Postgres only): Permits other transactions to acquire key-share or share locks, allowing non-key modifications while preventing updates or deletes. + - ``SHARE``: Acquires a shared lock, allowing other transactions to read the rows but not modify or lock them. + + +You can specify a different lock strength: + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_for('share') + +Which is equivalent to: + +.. code-block:: sql + + SELECT ... FOR SHARE + + +nowait +------ + +If another transaction has already acquired a lock on one or more selected rows, an exception will be raised instead of +waiting for the other transaction to release the lock. + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_for('update', nowait=True) + + +skip_locked +----------- + +Ignore locked rows. + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_for('update', skip_locked=True) + + + +of +-- + +By default, if there are many tables in a query (e.g., when joining), all tables will be locked. +Using ``of``, you can specify which tables should be locked. + +.. code-block:: python + + await Band.select().where(Band.manager.name == 'Guido').lock_for('update', of=(Band, )) + + +Learn more +---------- + +* `Postgres docs `_ +* `CockroachDB docs `_ diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 7f2b5aaed..8cd66b408 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -13,6 +13,8 @@ CallbackDelegate, CallbackType, LimitDelegate, + LockForDelegate, + LockStrength, OffsetDelegate, OrderByDelegate, OrderByRaw, @@ -27,6 +29,7 @@ if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns import Column + from piccolo.table import Table ############################################################################### @@ -194,6 +197,7 @@ class Objects( "callback_delegate", "prefetch_delegate", "where_delegate", + "lock_for_delegate", ) def __init__( @@ -213,6 +217,7 @@ def __init__( self.prefetch_delegate = PrefetchDelegate() self.prefetch(*prefetch) self.where_delegate = WhereDelegate() + self.lock_for_delegate = LockForDelegate() def output(self: Self, load_json: bool = False) -> Self: self.output_delegate.output( @@ -272,6 +277,28 @@ def first(self) -> First[TableInstance]: self.limit_delegate.limit(1) return First[TableInstance](query=self) + def lock_for( + self: Self, + lock_strength: t.Union[ + LockStrength, + t.Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + "update", + "no key update", + "key share", + "share", + ], + ] = LockStrength.update, + nowait: bool = False, + skip_locked: bool = False, + of: t.Tuple[type[Table], ...] = (), + ) -> Self: + self.lock_for_delegate.lock_for(lock_strength, nowait, skip_locked, of) + return self + def get(self, where: Combinable) -> Get[TableInstance]: self.where_delegate.where(where) self.limit_delegate.limit(1) @@ -322,6 +349,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]: "offset_delegate", "output_delegate", "order_by_delegate", + "lock_for_delegate", ): setattr(select, attr, getattr(self, attr)) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 0c590918b..4d5788938 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -19,6 +19,8 @@ DistinctDelegate, GroupByDelegate, LimitDelegate, + LockForDelegate, + LockStrength, OffsetDelegate, OrderByDelegate, OrderByRaw, @@ -150,6 +152,7 @@ class Select(Query[TableInstance, t.List[t.Dict[str, t.Any]]]): "output_delegate", "callback_delegate", "where_delegate", + "lock_for_delegate", ) def __init__( @@ -174,6 +177,7 @@ def __init__( self.output_delegate = OutputDelegate() self.callback_delegate = CallbackDelegate() self.where_delegate = WhereDelegate() + self.lock_for_delegate = LockForDelegate() self.columns(*columns_list) @@ -219,6 +223,28 @@ def offset(self: Self, number: int) -> Self: self.offset_delegate.offset(number) return self + def lock_for( + self: Self, + lock_strength: t.Union[ + LockStrength, + t.Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + "update", + "no key update", + "key share", + "share", + ], + ] = LockStrength.update, + nowait: bool = False, + skip_locked: bool = False, + of: t.Tuple[type[Table], ...] = (), + ) -> Self: + self.lock_for_delegate.lock_for(lock_strength, nowait, skip_locked, of) + return self + async def _splice_m2m_rows( self, response: t.List[t.Dict[str, t.Any]], @@ -618,6 +644,14 @@ def default_querystrings(self) -> t.Sequence[QueryString]: query += "{}" args.append(self.offset_delegate._offset.querystring) + if engine_type == "sqlite" and self.lock_for_delegate._lock_for: + raise NotImplementedError( + "SQLite doesn't support SELECT .. FOR UPDATE" + ) + + if self.lock_for_delegate._lock_for: + args.append(self.lock_for_delegate._lock_for.querystring) + querystring = QueryString(query, *args) return [querystring] diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index d9d5f84ca..b11946cb7 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -784,3 +784,93 @@ def on_conflict( target=target, action=action_, values=values, where=where ) ) + + +class LockStrength(str, Enum): + """ + Specify lock strength + + https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE + """ + + update = "UPDATE" + no_key_update = "NO KEY UPDATE" + share = "SHARE" + key_share = "KEY SHARE" + + +@dataclass +class LockFor: + __slots__ = ("lock_strength", "nowait", "skip_locked", "of") + + lock_strength: LockStrength + nowait: bool + skip_locked: bool + of: tuple[type[Table], ...] + + def __post_init__(self): + if not isinstance(self.lock_strength, LockStrength): + raise TypeError("lock_strength must be a LockStrength") + if not isinstance(self.nowait, bool): + raise TypeError("nowait must be a bool") + if not isinstance(self.skip_locked, bool): + raise TypeError("skip_locked must be a bool") + if not isinstance(self.of, tuple) or not all( + hasattr(x, "_meta") for x in self.of + ): + raise TypeError("of must be a tuple of Table") + if self.nowait and self.skip_locked: + raise TypeError( + "The nowait option cannot be used with skip_locked" + ) + + @property + def querystring(self) -> QueryString: + sql = f" FOR {self.lock_strength.value}" + if self.of: + tables = ", ".join(x._meta.tablename for x in self.of) + sql += " OF " + tables + if self.nowait: + sql += " NOWAIT" + if self.skip_locked: + sql += " SKIP LOCKED" + + return QueryString(sql) + + def __str__(self) -> str: + return self.querystring.__str__() + + +@dataclass +class LockForDelegate: + + _lock_for: t.Optional[LockFor] = None + + def lock_for( + self, + lock_strength: t.Union[ + LockStrength, + t.Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + "update", + "no key update", + "key share", + "share", + ], + ] = LockStrength.update, + nowait=False, + skip_locked=False, + of: t.Tuple[type[Table], ...] = (), + ): + lock_strength_: LockStrength + if isinstance(lock_strength, LockStrength): + lock_strength_ = lock_strength + elif isinstance(lock_strength, str): + lock_strength_ = LockStrength(lock_strength.upper()) + else: + raise ValueError("Unrecognised `lock_strength` value.") + + self._lock_for = LockFor(lock_strength_, nowait, skip_locked, of) diff --git a/tests/table/test_select.py b/tests/table/test_select.py index ebf2c3ff8..0e2ac2396 100644 --- a/tests/table/test_select.py +++ b/tests/table/test_select.py @@ -1028,6 +1028,40 @@ def test_select_raw(self): response, [{"name": "Pythonistas", "popularity_log": 3.0}] ) + @pytest.mark.skipif( + is_running_sqlite(), + reason="SQLite doesn't support SELECT .. FOR UPDATE.", + ) + def test_lock_for(self): + """ + Make sure the for_update clause works. + """ + self.insert_rows() + + query = Band.select() + self.assertNotIn("FOR UPDATE", query.__str__()) + + query = query.lock_for() + self.assertTrue(query.__str__().endswith("FOR UPDATE")) + + query = query.lock_for(lock_strength="key share") + self.assertTrue(query.__str__().endswith("FOR KEY SHARE")) + + query = query.lock_for(skip_locked=True) + self.assertTrue(query.__str__().endswith("FOR UPDATE SKIP LOCKED")) + + query = query.lock_for(nowait=True) + self.assertTrue(query.__str__().endswith("FOR UPDATE NOWAIT")) + + query = query.lock_for(of=(Band,)) + self.assertTrue(query.__str__().endswith("FOR UPDATE OF band")) + + with self.assertRaises(TypeError): + query = query.lock_for(skip_locked=True, nowait=True) + + response = query.run_sync() + assert response is not None + class TestSelectSecret(TestCase): def setUp(self):