Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 2ff5dad

Browse files
committed
Add middleware support
1 parent 851d586 commit 2ff5dad

File tree

6 files changed

+110
-12
lines changed

6 files changed

+110
-12
lines changed

‎asyncpg/_testbase/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def create_pool(dsn=None, *,
264264
setup=None,
265265
init=None,
266266
loop=None,
267+
middlewares=None,
267268
pool_class=pg_pool.Pool,
268269
connection_class=pg_connection.Connection,
269270
**connect_kwargs):
@@ -272,7 +273,7 @@ def create_pool(dsn=None, *,
272273
min_size=min_size, max_size=max_size,
273274
max_queries=max_queries, loop=loop, setup=setup, init=init,
274275
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
275-
connection_class=connection_class,
276+
connection_class=connection_class,middlewares=middlewares,
276277
**connect_kwargs)
277278

278279

‎asyncpg/connect_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
594594

595595

596596
async def _connect_addr(*, addr, loop, timeout, params, config,
597-
connection_class):
597+
middlewares, connection_class):
598598
assert loop is not None
599599

600600
if timeout <= 0:
@@ -633,12 +633,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
633633
tr.close()
634634
raise
635635

636-
con = connection_class(pr, tr, loop, addr, config, params)
636+
con = connection_class(pr, tr, loop, addr, config, params, middlewares)
637637
pr.set_connection(con)
638638
return con
639639

640640

641-
async def _connect(*, loop, timeout, connection_class, **kwargs):
641+
async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs):
642642
if loop is None:
643643
loop = asyncio.get_event_loop()
644644

@@ -652,6 +652,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
652652
con = await _connect_addr(
653653
addr=addr, loop=loop, timeout=timeout,
654654
params=params, config=config,
655+
middlewares=middlewares,
655656
connection_class=connection_class)
656657
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
657658
last_error = ex

‎asyncpg/connection.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
4141
"""
4242

4343
__slots__ = ('_protocol', '_transport', '_loop',
44-
'_top_xact', '_aborted',
44+
'_top_xact', '_aborted','_middlewares',
4545
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
4646
'_listeners', '_server_version', '_server_caps',
4747
'_intro_query', '_reset_query', '_proxy',
@@ -52,7 +52,8 @@ class Connection(metaclass=ConnectionMeta):
5252
def __init__(self, protocol, transport, loop,
5353
addr: (str, int) or str,
5454
config: connect_utils._ClientConfiguration,
55-
params: connect_utils._ConnectionParameters):
55+
params: connect_utils._ConnectionParameters,
56+
_middlewares=None):
5657
self._protocol = protocol
5758
self._transport = transport
5859
self._loop = loop
@@ -91,7 +92,7 @@ def __init__(self, protocol, transport, loop,
9192

9293
self._reset_query = None
9394
self._proxy = None
94-
95+
self._middlewares=_middlewares
9596
# Used to serialize operations that might involve anonymous
9697
# statements. Specifically, we want to make the following
9798
# operation atomic:
@@ -1399,8 +1400,13 @@ async def reload_schema_state(self):
13991400

14001401
async def _execute(self, query, args, limit, timeout, return_status=False):
14011402
with self._stmt_exclusive_section:
1402-
result, _ = await self.__execute(
1403-
query, args, limit, timeout, return_status=return_status)
1403+
wrapped = self.__execute
1404+
if self._middlewares:
1405+
for m in reversed(self._middlewares):
1406+
wrapped = await m(connection=self, handler=wrapped)
1407+
1408+
result, _ = await wrapped(query, args, limit,
1409+
timeout, return_status=return_status)
14041410
return result
14051411

14061412
async def __execute(self, query, args, limit, timeout,
@@ -1491,6 +1497,7 @@ async def connect(dsn=None, *,
14911497
max_cacheable_statement_size=1024 * 15,
14921498
command_timeout=None,
14931499
ssl=None,
1500+
middlewares=None,
14941501
connection_class=Connection,
14951502
server_settings=None):
14961503
r"""A coroutine to establish a connection to a PostgreSQL server.
@@ -1607,6 +1614,10 @@ async def connect(dsn=None, *,
16071614
PostgreSQL documentation for
16081615
a `list of supported options <server settings>`_.
16091616
1617+
:param middlewares:
1618+
An optional list of middleware functions. Refer to documentation
1619+
on create_pool.
1620+
16101621
:param Connection connection_class:
16111622
Class of the returned connection object. Must be a subclass of
16121623
:class:`~asyncpg.connection.Connection`.
@@ -1672,6 +1683,7 @@ async def connect(dsn=None, *,
16721683
ssl=ssl, database=database,
16731684
server_settings=server_settings,
16741685
command_timeout=command_timeout,
1686+
middlewares=middlewares,
16751687
statement_cache_size=statement_cache_size,
16761688
max_cached_statement_lifetime=max_cached_statement_lifetime,
16771689
max_cacheable_statement_size=max_cacheable_statement_size)

‎asyncpg/pool.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ class Pool:
305305
"""
306306

307307
__slots__ = (
308-
'_queue', '_loop', '_minsize', '_maxsize',
308+
'_queue', '_loop', '_minsize', '_maxsize','_middlewares',
309309
'_init', '_connect_args', '_connect_kwargs',
310310
'_working_addr', '_working_config', '_working_params',
311311
'_holders', '_initialized', '_initializing', '_closing',
@@ -320,6 +320,7 @@ def __init__(self, *connect_args,
320320
max_inactive_connection_lifetime,
321321
setup,
322322
init,
323+
middlewares,
323324
loop,
324325
connection_class,
325326
**connect_kwargs):
@@ -377,6 +378,7 @@ def __init__(self, *connect_args,
377378
self._closed = False
378379
self._generation = 0
379380
self._init = init
381+
self._middlewares = middlewares
380382
self._connect_args = connect_args
381383
self._connect_kwargs = connect_kwargs
382384

@@ -469,6 +471,7 @@ async def _get_new_connection(self):
469471
*self._connect_args,
470472
loop=self._loop,
471473
connection_class=self._connection_class,
474+
middlewares=self._middlewares,
472475
**self._connect_kwargs)
473476

474477
self._working_addr = con._addr
@@ -483,6 +486,7 @@ async def _get_new_connection(self):
483486
addr=self._working_addr,
484487
timeout=self._working_params.connect_timeout,
485488
config=self._working_config,
489+
middlewares=self._middlewares,
486490
params=self._working_params,
487491
connection_class=self._connection_class)
488492

@@ -784,13 +788,37 @@ def __await__(self):
784788
return self.pool._acquire(self.timeout).__await__()
785789

786790

791+
def middleware(f):
792+
"""Decorator for adding a middleware
793+
794+
Can be used like such
795+
796+
.. code-block:: python
797+
798+
@pool.middleware
799+
async def my_middleware(query, args, limit,
800+
timeout, return_status, *, handler, conn):
801+
print('do something before')
802+
result, stmt = await handler(query, args, limit,
803+
timeout, return_status)
804+
print('do something after')
805+
return result, stmt
806+
807+
my_pool = await pool.create_pool(middlewares=[my_middleware])
808+
"""
809+
async def middleware_factory(connection, handler):
810+
return functools.partial(f, connection=connection, handler=handler)
811+
return middleware_factory
812+
813+
787814
def create_pool(dsn=None, *,
788815
min_size=10,
789816
max_size=10,
790817
max_queries=50000,
791818
max_inactive_connection_lifetime=300.0,
792819
setup=None,
793820
init=None,
821+
middlewares=None,
794822
loop=None,
795823
connection_class=connection.Connection,
796824
**connect_kwargs):
@@ -866,6 +894,19 @@ def create_pool(dsn=None, *,
866894
or :meth:`Connection.set_type_codec() <\
867895
asyncpg.connection.Connection.set_type_codec>`.
868896
897+
:param middlewares:
898+
A list of middleware functions to be middleware just
899+
before a connection excecutes a statement.
900+
Syntax of a middleware is as follows:
901+
async def middleware_factory(connection, handler):
902+
async def middleware(query, args, limit, timeout, return_status):
903+
print('do something before')
904+
result, stmt = await handler(query, args, limit,
905+
timeout, return_status)
906+
print('do something after')
907+
return result, stmt
908+
return middleware
909+
869910
:param loop:
870911
An asyncio event loop instance. If ``None``, the default
871912
event loop will be used.
@@ -893,6 +934,7 @@ def create_pool(dsn=None, *,
893934
dsn,
894935
connection_class=connection_class,
895936
min_size=min_size, max_size=max_size,
896-
max_queries=max_queries, loop=loop, setup=setup, init=init,
937+
max_queries=max_queries, loop=loop, setup=setup,
938+
middlewares=middlewares, init=init,
897939
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
898940
**connect_kwargs)

‎docs/installation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ If you want to build **asyncpg** from a Git checkout you will need:
3030
* CPython header files. These can usually be obtained by installing
3131
the relevant Python development package: **python3-dev** on Debian/Ubuntu,
3232
**python3-devel** on RHEL/Fedora.
33-
33+
* Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`)
3434
Once the above requirements are satisfied, run the following command
3535
in the root of the source checkout:
3636

‎tests/test_pool.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,48 @@ async def worker():
7676
tasks = [worker() for _ in range(n)]
7777
await asyncio.gather(*tasks)
7878

79+
async def test_pool_with_middleware(self):
80+
called = False
81+
82+
async def my_middleware_factory(connection, handler):
83+
async def middleware(query, args, limit, timeout, return_status):
84+
nonlocal called
85+
called = True
86+
return await handler(query, args, limit,
87+
timeout, return_status)
88+
return middleware
89+
90+
pool = await self.create_pool(database='postgres',
91+
min_size=1, max_size=1,
92+
middlewares=[my_middleware_factory])
93+
94+
con = await pool.acquire(timeout=5)
95+
await con.fetchval('SELECT 1')
96+
assert called
97+
98+
pool.terminate()
99+
del con
100+
101+
async def test_pool_with_middleware_decorator(self):
102+
called = False
103+
104+
@pg_pool.middleware
105+
async def my_middleware(query, args, limit, timeout, return_status,
106+
*, connection, handler):
107+
nonlocal called
108+
called = True
109+
return await handler(query, args, limit,
110+
timeout, return_status)
111+
112+
pool = await self.create_pool(database='postgres', min_size=1,
113+
max_size=1, middlewares=[my_middleware])
114+
con = await pool.acquire(timeout=5)
115+
await con.fetchval('SELECT 1')
116+
assert called
117+
118+
pool.terminate()
119+
del con
120+
79121
async def test_pool_03(self):
80122
pool = await self.create_pool(database='postgres',
81123
min_size=1, max_size=1)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /