Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f7ad73d

Browse files
committed
Add middleware support
1 parent 32fccaa commit f7ad73d

File tree

5 files changed

+59
-12
lines changed

5 files changed

+59
-12
lines changed

‎asyncpg/connect_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
505505

506506

507507
async def _connect_addr(*, addr, loop, timeout, params, config,
508-
connection_class):
508+
middlewares, connection_class):
509509
assert loop is not None
510510

511511
if timeout <= 0:
@@ -539,12 +539,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
539539
tr.close()
540540
raise
541541

542-
con = connection_class(pr, tr, loop, addr, config, params)
542+
con = connection_class(pr, tr, loop, addr, config, params, middlewares)
543543
pr.set_connection(con)
544544
return con
545545

546546

547-
async def _connect(*, loop, timeout, connection_class, **kwargs):
547+
async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs):
548548
if loop is None:
549549
loop = asyncio.get_event_loop()
550550

@@ -558,6 +558,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
558558
con = await _connect_addr(
559559
addr=addr, loop=loop, timeout=timeout,
560560
params=params, config=config,
561+
middlewares=middlewares,
561562
connection_class=connection_class)
562563
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
563564
last_error = ex

‎asyncpg/connection.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Connection(metaclass=ConnectionMeta):
4242
"""
4343

4444
__slots__ = ('_protocol', '_transport', '_loop',
45-
'_top_xact', '_aborted',
45+
'_top_xact', '_aborted','_middlewares'
4646
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
4747
'_listeners', '_server_version', '_server_caps',
4848
'_intro_query', '_reset_query', '_proxy',
@@ -53,7 +53,8 @@ class Connection(metaclass=ConnectionMeta):
5353
def __init__(self, protocol, transport, loop,
5454
addr: (str, int) or str,
5555
config: connect_utils._ClientConfiguration,
56-
params: connect_utils._ConnectionParameters):
56+
params: connect_utils._ConnectionParameters,
57+
middlewares=None):
5758
self._protocol = protocol
5859
self._transport = transport
5960
self._loop = loop
@@ -92,7 +93,7 @@ def __init__(self, protocol, transport, loop,
9293

9394
self._reset_query = None
9495
self._proxy = None
95-
96+
self._middlewares=_middlewares
9697
# Used to serialize operations that might involve anonymous
9798
# statements. Specifically, we want to make the following
9899
# operation atomic:
@@ -1410,8 +1411,12 @@ async def reload_schema_state(self):
14101411

14111412
async def _execute(self, query, args, limit, timeout, return_status=False):
14121413
with self._stmt_exclusive_section:
1413-
result, _ = await self.__execute(
1414-
query, args, limit, timeout, return_status=return_status)
1414+
wrapped = self.__execute
1415+
if self._middlewares:
1416+
for m in reversed(self._middlewares):
1417+
wrapped = await m(self, wrapped)
1418+
1419+
result, _ = await wrapped(query, args, limit, timeout, return_status=return_status)
14151420
return result
14161421

14171422
async def __execute(self, query, args, limit, timeout,
@@ -1502,6 +1507,7 @@ async def connect(dsn=None, *,
15021507
max_cacheable_statement_size=1024 * 15,
15031508
command_timeout=None,
15041509
ssl=None,
1510+
middlewares=None,
15051511
connection_class=Connection,
15061512
server_settings=None):
15071513
r"""A coroutine to establish a connection to a PostgreSQL server.
@@ -1618,6 +1624,10 @@ async def connect(dsn=None, *,
16181624
PostgreSQL documentation for
16191625
a `list of supported options <server settings>`_.
16201626
1627+
:param middlewares:
1628+
An optional list of middleware functions. Refer to documentation
1629+
on create_pool.
1630+
16211631
:param Connection connection_class:
16221632
Class of the returned connection object. Must be a subclass of
16231633
:class:`~asyncpg.connection.Connection`.
@@ -1683,6 +1693,7 @@ async def connect(dsn=None, *,
16831693
ssl=ssl, database=database,
16841694
server_settings=server_settings,
16851695
command_timeout=command_timeout,
1696+
middlewares=middlewares,
16861697
statement_cache_size=statement_cache_size,
16871698
max_cached_statement_lifetime=max_cached_statement_lifetime,
16881699
max_cacheable_statement_size=max_cacheable_statement_size)

‎asyncpg/pool.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ class Pool:
304304
Pools are created by calling :func:`~asyncpg.pool.create_pool`.
305305
"""
306306

307-
__slots__ = ('_queue', '_loop', '_minsize', '_maxsize',
307+
__slots__ = ('_queue', '_loop', '_minsize', '_maxsize','_middlewares'
308308
'_init', '_connect_args', '_connect_kwargs',
309309
'_working_addr', '_working_config', '_working_params',
310310
'_holders', '_initialized', '_initializing', '_closing',
@@ -317,6 +317,7 @@ def __init__(self, *connect_args,
317317
max_inactive_connection_lifetime,
318318
setup,
319319
init,
320+
middlewares,
320321
loop,
321322
connection_class,
322323
**connect_kwargs):
@@ -374,6 +375,7 @@ def __init__(self, *connect_args,
374375
self._closed = False
375376
self._generation = 0
376377
self._init = init
378+
self._middlewares = middlewares
377379
self._connect_args = connect_args
378380
self._connect_kwargs = connect_kwargs
379381

@@ -460,6 +462,7 @@ async def _get_new_connection(self):
460462
*self._connect_args,
461463
loop=self._loop,
462464
connection_class=self._connection_class,
465+
middlewares=self._middlewares,
463466
**self._connect_kwargs)
464467

465468
self._working_addr = con._addr
@@ -774,14 +777,14 @@ def __await__(self):
774777
self.done = True
775778
return self.pool._acquire(self.timeout).__await__()
776779

777-
778780
def create_pool(dsn=None, *,
779781
min_size=10,
780782
max_size=10,
781783
max_queries=50000,
782784
max_inactive_connection_lifetime=300.0,
783785
setup=None,
784786
init=None,
787+
middlewares=None,
785788
loop=None,
786789
connection_class=connection.Connection,
787790
**connect_kwargs):
@@ -857,6 +860,19 @@ def create_pool(dsn=None, *,
857860
or :meth:`Connection.set_type_codec() <\
858861
asyncpg.connection.Connection.set_type_codec>`.
859862
863+
:param middlewares:
864+
A list of middleware functions to be middleware just
865+
before a connection excecutes a statement.
866+
Syntax of a middleware is as follows:
867+
async def middleware_factory(connection, handler):
868+
async def middleware(query, args. limit, timeout, return_status):
869+
print('do something before')
870+
result, stmt = await handler(query, args, limit,
871+
timeout, return_status)
872+
print('do something after')
873+
return result, stmt
874+
return middleware
875+
860876
:param loop:
861877
An asyncio event loop instance. If ``None``, the default
862878
event loop will be used.
@@ -884,6 +900,8 @@ def create_pool(dsn=None, *,
884900
dsn,
885901
connection_class=connection_class,
886902
min_size=min_size, max_size=max_size,
887-
max_queries=max_queries, loop=loop, setup=setup, init=init,
903+
max_queries=max_queries, loop=loop, setup=setup,
904+
middlewares=middlewares, init=init,
888905
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
889906
**connect_kwargs)
907+

‎docs/installation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ If you want to build **asyncpg** from a Git checkout you will need:
2929
* CPython header files. These can usually be obtained by installing
3030
the relevant Python development package: **python3-dev** on Debian/Ubuntu,
3131
**python3-devel** on RHEL/Fedora.
32-
32+
* Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`)
3333
Once the above requirements are satisfied, run the following command
3434
in the root of the source checkout:
3535

‎tests/test_pool.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,23 @@ async def worker():
7676
tasks = [worker() for _ in range(n)]
7777
await asyncio.gather(*tasks, loop=self.loop)
7878

79+
async def test_pool_with_middleware(self):
80+
called = False
81+
82+
async def my_middleware_factory(connection, handler):
83+
async def middleware(query, args, limit, timeout, return_status):
84+
nonlocal called
85+
called = True
86+
return await handler(query, args, limit,
87+
timeout, return_status)
88+
async with self.create_pool(database='postgres',
89+
min_size=1, max_size=1,
90+
middlewares=[my_middleware_factory]) \
91+
as pool:
92+
con = await pool.acquire(timeout=5)
93+
await con.fetchval('SELECT 1', 1)
94+
assert called
95+
7996
async def test_pool_03(self):
8097
pool = await self.create_pool(database='postgres',
8198
min_size=1, max_size=1)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /