Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 7528e6d

Browse files
Clean up TypeVar naems and mark type aliases
1 parent 8951230 commit 7528e6d

File tree

8 files changed

+358
-318
lines changed

8 files changed

+358
-318
lines changed

‎asyncpg/connect_utils.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import sys
2222
import time
2323
import typing
24+
import typing_extensions
2425
import urllib.parse
2526
import warnings
2627

@@ -37,33 +38,44 @@
3738
if typing.TYPE_CHECKING:
3839
from . import connection
3940

40-
_Connection = typing.TypeVar(
41-
'_Connection',
41+
_ConnectionT = typing.TypeVar(
42+
'_ConnectionT',
4243
bound='connection.Connection[typing.Any]'
4344
)
44-
_Protocol = typing.TypeVar('_Protocol', bound='protocol.Protocol[typing.Any]')
45-
_AsyncProtocol=typing.TypeVar(
46-
'_AsyncProtocol', bound='asyncio.protocols.Protocol'
45+
_ProtocolT = typing.TypeVar(
46+
'_ProtocolT',
47+
bound='protocol.Protocol[typing.Any]'
4748
)
48-
_Record = typing.TypeVar('_Record', bound=protocol.Record)
49-
_SSLMode = typing.TypeVar('_SSLMode', bound='SSLMode')
49+
_AsyncProtocolT = typing.TypeVar(
50+
'_AsyncProtocolT', bound='asyncio.protocols.Protocol'
51+
)
52+
_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record)
53+
_SSLModeT = typing.TypeVar('_SSLModeT', bound='SSLMode')
5054

51-
_TPTupleType = typing.Tuple[asyncio.WriteTransport, _AsyncProtocol]
52-
AddrType = typing.Union[typing.Tuple[str, int], str]
53-
SSLStringValues = compat.Literal[
55+
_TPTupleType: typing_extensions.TypeAlias = typing.Tuple[
56+
asyncio.WriteTransport,
57+
_AsyncProtocolT
58+
]
59+
_SSLStringValues = compat.Literal[
5460
'disable', 'prefer', 'allow', 'require', 'verify-ca', 'verify-full'
5561
]
56-
_ParsedSSLType = typing.Union[
62+
AddrType: typing_extensions.TypeAlias = typing.Union[
63+
typing.Tuple[str, int],
64+
str
65+
]
66+
_ParsedSSLType: typing_extensions.TypeAlias = typing.Union[
5767
ssl_module.SSLContext, compat.Literal[False]
5868
]
59-
SSLType = typing.Union[_ParsedSSLType, SSLStringValues, bool]
60-
HostType = typing.Union[typing.List[str], str]
61-
PortListType = typing.Union[
69+
SSLType: typing_extensions.TypeAlias = typing.Union[
70+
_ParsedSSLType, _SSLStringValues, bool
71+
]
72+
HostType: typing_extensions.TypeAlias = typing.Union[typing.List[str], str]
73+
PortListType: typing_extensions.TypeAlias = typing.Union[
6274
typing.List[typing.Union[int, str]],
6375
typing.List[int],
6476
typing.List[str],
6577
]
66-
PortType = typing.Union[
78+
PortType: typing_extensions.TypeAlias = typing.Union[
6779
PortListType,
6880
int,
6981
str
@@ -80,13 +92,13 @@ class SSLMode(enum.IntEnum):
8092

8193
@classmethod
8294
def parse(
83-
cls: typing.Type[_SSLMode],
84-
sslmode: typing.Union[str, _SSLMode]
85-
) -> _SSLMode:
95+
cls: typing.Type[_SSLModeT],
96+
sslmode: typing.Union[str, _SSLModeT]
97+
) -> _SSLModeT:
8698
if isinstance(sslmode, cls):
8799
return sslmode
88100
return typing.cast(
89-
_SSLMode,
101+
_SSLModeT,
90102
getattr(cls, typing.cast(str, sslmode).replace('-', '_'))
91103
)
92104

@@ -798,14 +810,14 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
798810

799811
@typing.overload
800812
async def _create_ssl_connection(
801-
protocol_factory: typing.Callable[[], _Protocol],
813+
protocol_factory: typing.Callable[[], _ProtocolT],
802814
host: str,
803815
port: int,
804816
*,
805817
loop: asyncio.AbstractEventLoop,
806818
ssl_context: ssl_module.SSLContext,
807819
ssl_is_advisory: typing.Optional[bool] = False
808-
) -> _TPTupleType[_Protocol]:
820+
) -> _TPTupleType[_ProtocolT]:
809821
...
810822

811823

@@ -824,7 +836,7 @@ async def _create_ssl_connection(
824836

825837
async def _create_ssl_connection(
826838
protocol_factory: typing.Union[
827-
typing.Callable[[], _Protocol],
839+
typing.Callable[[], _ProtocolT],
828840
typing.Callable[[], '_CancelProto']
829841
],
830842
host: str,
@@ -886,7 +898,7 @@ async def _create_ssl_connection(
886898

887899
try:
888900
new_tr, pg_proto = typing.cast(
889-
typing.Tuple[asyncio.WriteTransport, _Protocol],
901+
typing.Tuple[asyncio.WriteTransport, _ProtocolT],
890902
await conn_factory(sock=sock)
891903
)
892904
pg_proto.is_ssl = do_ssl_upgrade
@@ -903,9 +915,9 @@ async def _connect_addr(
903915
timeout: float,
904916
params: _ConnectionParameters,
905917
config: _ClientConfiguration,
906-
connection_class: typing.Type[_Connection],
907-
record_class: typing.Type[_Record]
908-
) -> _Connection:
918+
connection_class: typing.Type[_ConnectionT],
919+
record_class: typing.Type[_RecordT]
920+
) -> _ConnectionT:
909921
assert loop is not None
910922

911923
if timeout <= 0:
@@ -956,22 +968,22 @@ async def __connect_addr(
956968
addr: AddrType,
957969
loop: asyncio.AbstractEventLoop,
958970
config: _ClientConfiguration,
959-
connection_class: typing.Type[_Connection],
960-
record_class: typing.Type[_Record],
971+
connection_class: typing.Type[_ConnectionT],
972+
record_class: typing.Type[_RecordT],
961973
params_input: _ConnectionParameters,
962-
) -> _Connection:
974+
) -> _ConnectionT:
963975
connected = _create_future(loop)
964976

965977
proto_factory: typing.Callable[
966-
[], 'protocol.Protocol[_Record]'
978+
[], 'protocol.Protocol[_RecordT]'
967979
] = lambda: protocol.Protocol(
968980
addr, connected, params, record_class, loop)
969981

970982
if isinstance(addr, str):
971983
# UNIX socket
972984
connector = typing.cast(
973985
typing.Coroutine[
974-
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
986+
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
975987
],
976988
loop.create_unix_connection(proto_factory, addr)
977989
)
@@ -981,7 +993,7 @@ async def __connect_addr(
981993
# SSL connection
982994
connector = typing.cast(
983995
typing.Coroutine[
984-
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
996+
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
985997
],
986998
loop.create_connection(
987999
proto_factory, *addr, ssl=params.ssl
@@ -995,7 +1007,7 @@ async def __connect_addr(
9951007
else:
9961008
connector = typing.cast(
9971009
typing.Coroutine[
998-
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
1010+
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
9991011
],
10001012
loop.create_connection(proto_factory, *addr)
10011013
)
@@ -1053,10 +1065,10 @@ async def _connect(
10531065
*,
10541066
loop: typing.Optional[asyncio.AbstractEventLoop],
10551067
timeout: float,
1056-
connection_class: typing.Type[_Connection],
1057-
record_class: typing.Type[_Record],
1068+
connection_class: typing.Type[_ConnectionT],
1069+
record_class: typing.Type[_RecordT],
10581070
**kwargs: typing.Any
1059-
) -> _Connection:
1071+
) -> _ConnectionT:
10601072
if loop is None:
10611073
loop = asyncio.get_event_loop()
10621074

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /