Skip to content

Commit 4d209b7

Browse files
committed
Implement support for pool connection rotation
The new `Pool.expire_connections()` method expires all currently open connections, so they would be replaced with fresh ones on the next `acquire()` attempt. The new `Pool.set_connect_args()` allows changing the connection arguments for an existing pool instance. Coupled with `expire_connections()`, it allows adapting the pool to the new environment conditions without having to replace the pool instance. Fixes: #291
1 parent a6fa7a3 commit 4d209b7

File tree

3 files changed

+183
-51
lines changed

3 files changed

+183
-51
lines changed

asyncpg/_testbase/__init__.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,17 +328,29 @@ def get_connection_spec(cls, kwargs={}):
328328
conn_spec['user'] = 'postgres'
329329
return conn_spec
330330

331-
def create_pool(self, pool_class=pg_pool.Pool,
332-
connection_class=pg_connection.Connection, **kwargs):
333-
conn_spec = self.get_connection_spec(kwargs)
334-
return create_pool(loop=self.loop, pool_class=pool_class,
335-
connection_class=connection_class, **conn_spec)
336-
337331
@classmethod
338332
def connect(cls, **kwargs):
339333
conn_spec = cls.get_connection_spec(kwargs)
340334
return pg_connection.connect(**conn_spec, loop=cls.loop)
341335

336+
def setUp(self):
337+
super().setUp()
338+
self._pools = []
339+
340+
def tearDown(self):
341+
super().tearDown()
342+
for pool in self._pools:
343+
pool.terminate()
344+
self._pools = []
345+
346+
def create_pool(self, pool_class=pg_pool.Pool,
347+
connection_class=pg_connection.Connection, **kwargs):
348+
conn_spec = self.get_connection_spec(kwargs)
349+
pool = create_pool(loop=self.loop, pool_class=pool_class,
350+
connection_class=connection_class, **conn_spec)
351+
self._pools.append(pool)
352+
return pool
353+
342354

343355
class ProxiedClusterTestCase(ClusterTestCase):
344356
@classmethod

asyncpg/pool.py

Lines changed: 112 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import functools
1010
import inspect
1111
import time
12+
import warnings
1213

1314
from . import connection
1415
from . import connect_utils
@@ -92,67 +93,46 @@ def __repr__(self):
9293
class PoolConnectionHolder:
9394

9495
__slots__ = ('_con', '_pool', '_loop', '_proxy',
95-
'_connect_args', '_connect_kwargs',
96-
'_max_queries', '_setup', '_init',
96+
'_max_queries', '_setup',
9797
'_max_inactive_time', '_in_use',
98-
'_inactive_callback', '_timeout')
98+
'_inactive_callback', '_timeout',
99+
'_generation')
99100

100-
def __init__(self, pool, *, connect_args, connect_kwargs,
101-
max_queries, setup, init, max_inactive_time):
101+
def __init__(self, pool, *, max_queries, setup, max_inactive_time):
102102

103103
self._pool = pool
104104
self._con = None
105105
self._proxy = None
106106

107-
self._connect_args = connect_args
108-
self._connect_kwargs = connect_kwargs
109107
self._max_queries = max_queries
110108
self._max_inactive_time = max_inactive_time
111109
self._setup = setup
112-
self._init = init
113110
self._inactive_callback = None
114111
self._in_use = None # type: asyncio.Future
115112
self._timeout = None
113+
self._generation = None
116114

117115
async def connect(self):
118116
if self._con is not None:
119117
raise exceptions.InternalClientError(
120118
'PoolConnectionHolder.connect() called while another '
121119
'connection already exists')
122120

123-
if self._pool._working_addr is None:
124-
# First connection attempt on this pool.
125-
con = await connection.connect(
126-
*self._connect_args,
127-
loop=self._pool._loop,
128-
connection_class=self._pool._connection_class,
129-
**self._connect_kwargs)
130-
131-
self._pool._working_addr = con._addr
132-
self._pool._working_config = con._config
133-
self._pool._working_params = con._params
134-
135-
else:
136-
# We've connected before and have a resolved address,
137-
# and parsed options and config.
138-
con = await connect_utils._connect_addr(
139-
loop=self._pool._loop,
140-
addr=self._pool._working_addr,
141-
timeout=self._pool._working_params.connect_timeout,
142-
config=self._pool._working_config,
143-
params=self._pool._working_params,
144-
connection_class=self._pool._connection_class)
145-
146-
if self._init is not None:
147-
await self._init(con)
148-
149-
self._con = con
121+
self._con = await self._pool._get_new_connection()
122+
self._generation = self._pool._generation
150123

151124
async def acquire(self) -> PoolConnectionProxy:
152125
if self._con is None or self._con.is_closed():
153126
self._con = None
154127
await self.connect()
155128

129+
elif self._generation != self._pool._generation:
130+
# Connections have been expired, re-connect the holder.
131+
self._pool._loop.create_task(
132+
self._con.close(timeout=self._timeout))
133+
self._con = None
134+
await self.connect()
135+
156136
self._maybe_cancel_inactive_callback()
157137

158138
self._proxy = proxy = PoolConnectionProxy(self, self._con)
@@ -197,6 +177,13 @@ async def release(self, timeout):
197177
await self._con.close(timeout=timeout)
198178
return
199179

180+
if self._generation != self._pool._generation:
181+
# The connection has expired because it belongs to
182+
# an older generation (Pool.expire_connections() has
183+
# been called.)
184+
await self._con.close(timeout=timeout)
185+
return
186+
200187
try:
201188
budget = timeout
202189

@@ -312,9 +299,10 @@ class Pool:
312299
"""
313300

314301
__slots__ = ('_queue', '_loop', '_minsize', '_maxsize',
302+
'_init', '_connect_args', '_connect_kwargs',
315303
'_working_addr', '_working_config', '_working_params',
316304
'_holders', '_initialized', '_closing', '_closed',
317-
'_connection_class')
305+
'_connection_class', '_generation')
318306

319307
def __init__(self, *connect_args,
320308
min_size,
@@ -327,6 +315,14 @@ def __init__(self, *connect_args,
327315
connection_class,
328316
**connect_kwargs):
329317

318+
if len(connect_args) > 1:
319+
warnings.warn(
320+
"Passing multiple positional arguments to asyncpg.Pool "
321+
"constructor is deprecated and will be removed in "
322+
"asyncpg 0.17.0. The non-deprecated form is "
323+
"asyncpg.Pool(<dsn>, **kwargs)",
324+
DeprecationWarning, stacklevel=2)
325+
330326
if loop is None:
331327
loop = asyncio.get_event_loop()
332328
self._loop = loop
@@ -349,6 +345,11 @@ def __init__(self, *connect_args,
349345
'max_inactive_connection_lifetime is expected to be greater '
350346
'or equal to zero')
351347

348+
if not issubclass(connection_class, connection.Connection):
349+
raise TypeError(
350+
'connection_class is expected to be a subclass of '
351+
'asyncpg.Connection, got {!r}'.format(connection_class))
352+
352353
self._minsize = min_size
353354
self._maxsize = max_size
354355

@@ -364,16 +365,17 @@ def __init__(self, *connect_args,
364365

365366
self._closing = False
366367
self._closed = False
368+
self._generation = 0
369+
self._init = init
370+
self._connect_args = connect_args
371+
self._connect_kwargs = connect_kwargs
367372

368373
for _ in range(max_size):
369374
ch = PoolConnectionHolder(
370375
self,
371-
connect_args=connect_args,
372-
connect_kwargs=connect_kwargs,
373376
max_queries=max_queries,
374377
max_inactive_time=max_inactive_connection_lifetime,
375-
setup=setup,
376-
init=init)
378+
setup=setup)
377379

378380
self._holders.append(ch)
379381
self._queue.put_nowait(ch)
@@ -409,6 +411,62 @@ async def _async__init__(self):
409411
self._initialized = True
410412
return self
411413

414+
def set_connect_args(self, dsn=None, **connect_kwargs):
415+
r"""Set the new connection arguments for this pool.
416+
417+
The new connection arguments will be used for all subsequent
418+
new connection attempts. Existing connections will remain until
419+
they expire. Use :meth:`Pool.expire_connections()
420+
<asyncpg.pool.Pool.expire_connections>` to expedite the connection
421+
expiry.
422+
423+
:param str dsn:
424+
Connection arguments specified using as a single string in
425+
the following format:
426+
``postgres://user:pass@host:port/database?option=value``.
427+
428+
:param \*\*connect_kwargs:
429+
Keyword arguments for the :func:`~asyncpg.connection.connect`
430+
function.
431+
432+
.. versionadded:: 0.16.0
433+
"""
434+
435+
self._connect_args = [dsn]
436+
self._connect_kwargs = connect_kwargs
437+
self._working_addr = None
438+
self._working_config = None
439+
self._working_params = None
440+
441+
async def _get_new_connection(self):
442+
if self._working_addr is None:
443+
# First connection attempt on this pool.
444+
con = await connection.connect(
445+
*self._connect_args,
446+
loop=self._loop,
447+
connection_class=self._connection_class,
448+
**self._connect_kwargs)
449+
450+
self._working_addr = con._addr
451+
self._working_config = con._config
452+
self._working_params = con._params
453+
454+
else:
455+
# We've connected before and have a resolved address,
456+
# and parsed options and config.
457+
con = await connect_utils._connect_addr(
458+
loop=self._loop,
459+
addr=self._working_addr,
460+
timeout=self._working_params.connect_timeout,
461+
config=self._working_config,
462+
params=self._working_params,
463+
connection_class=self._connection_class)
464+
465+
if self._init is not None:
466+
await self._init(con)
467+
468+
return con
469+
412470
async def execute(self, query: str, *args, timeout: float=None) -> str:
413471
"""Execute an SQL command (or commands).
414472
@@ -602,6 +660,16 @@ def terminate(self):
602660
ch.terminate()
603661
self._closed = True
604662

663+
async def expire_connections(self):
664+
"""Expire all currently open connections.
665+
666+
Cause all currently open connections to get replaced on the
667+
next :meth:`~asyncpg.pool.Pool.acquire()` call.
668+
669+
.. versionadded:: 0.16.0
670+
"""
671+
self._generation += 1
672+
605673
def _check_init(self):
606674
if not self._initialized:
607675
raise exceptions.InterfaceError('pool is not initialized')
@@ -708,6 +776,10 @@ def create_pool(dsn=None, *,
708776
Keyword arguments for the :func:`~asyncpg.connection.connect`
709777
function.
710778
779+
:param Connection connection_class:
780+
The class to use for connections. Must be a subclass of
781+
:class:`~asyncpg.connection.Connection`.
782+
711783
:param int min_size:
712784
Number of connection the pool will be initialized with.
713785
@@ -759,11 +831,6 @@ def create_pool(dsn=None, *,
759831
<connection.Connection.add_log_listener>`) present on the connection
760832
at the moment of its release to the pool.
761833
"""
762-
if not issubclass(connection_class, connection.Connection):
763-
raise TypeError(
764-
'connection_class is expected to be a subclass of '
765-
'asyncpg.Connection, got {!r}'.format(connection_class))
766-
767834
return Pool(
768835
dsn,
769836
connection_class=connection_class,

tests/test_pool.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,59 @@ async def worker():
794794

795795
await task
796796

797+
async def test_pool_expire_connections(self):
798+
pool = await self.create_pool(database='postgres',
799+
min_size=1, max_size=1)
800+
801+
con = await pool.acquire()
802+
try:
803+
await pool.expire_connections()
804+
finally:
805+
await pool.release(con)
806+
807+
self.assertIsNone(pool._holders[0]._con)
808+
809+
async def test_pool_set_connection_args(self):
810+
pool = await self.create_pool(database='postgres',
811+
min_size=1, max_size=1)
812+
813+
# Test that connection is expired on release.
814+
con = await pool.acquire()
815+
connspec = self.get_connection_spec()
816+
try:
817+
connspec['server_settings']['application_name'] = \
818+
'set_conn_args_test'
819+
except KeyError:
820+
connspec['server_settings'] = {
821+
'application_name': 'set_conn_args_test'
822+
}
823+
824+
pool.set_connect_args(**connspec)
825+
await pool.expire_connections()
826+
await pool.release(con)
827+
828+
con = await pool.acquire()
829+
self.assertEqual(con.get_settings().application_name,
830+
'set_conn_args_test')
831+
await pool.release(con)
832+
833+
# Test that connection is expired before acquire.
834+
connspec = self.get_connection_spec()
835+
try:
836+
connspec['server_settings']['application_name'] = \
837+
'set_conn_args_test'
838+
except KeyError:
839+
connspec['server_settings'] = {
840+
'application_name': 'set_conn_args_test_2'
841+
}
842+
843+
pool.set_connect_args(**connspec)
844+
await pool.expire_connections()
845+
846+
con = await pool.acquire()
847+
self.assertEqual(con.get_settings().application_name,
848+
'set_conn_args_test_2')
849+
797850

798851
@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
799852
class TestHotStandby(tb.ClusterTestCase):

0 commit comments

Comments
 (0)