Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8951230

Browse files
Improve typing for connect(port=)
1 parent 96c0a4c commit 8951230

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

‎asyncpg/cluster.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,11 +543,14 @@ def _test_connection(self, timeout: int = 60) -> str:
543543
try:
544544
con: 'connection.Connection[typing.Any]' = \
545545
loop.run_until_complete(
546-
asyncpg.connect(# type: ignore[arg-type] # noqa: E501
546+
asyncpg.connect(
547547
database='postgres',
548548
user='postgres',
549549
timeout=5, loop=loop,
550-
**self._connection_addr
550+
**typing.cast(
551+
_ConnectionSpec,
552+
self._connection_addr
553+
)
551554
)
552555
)
553556
except (OSError, asyncio.TimeoutError,

‎asyncpg/connect_utils.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,16 @@
5858
]
5959
SSLType = typing.Union[_ParsedSSLType, SSLStringValues, bool]
6060
HostType = typing.Union[typing.List[str], str]
61-
PortType = typing.Union[typing.List[int], int]
61+
PortListType = typing.Union[
62+
typing.List[typing.Union[int, str]],
63+
typing.List[int],
64+
typing.List[str],
65+
]
66+
PortType = typing.Union[
67+
PortListType,
68+
int,
69+
str
70+
]
6271

6372

6473
class SSLMode(enum.IntEnum):
@@ -192,26 +201,42 @@ def _read_password_from_pgpass(
192201
return None
193202

194203

195-
def _validate_port_spec(hosts: typing.List[str],
196-
port: PortType) \
197-
-> typing.List[int]:
204+
@typing.overload
205+
def _validate_port_spec(
206+
hosts: typing.List[str],
207+
port: PortListType
208+
) -> typing.List[int]:
209+
...
210+
211+
212+
@typing.overload
213+
def _validate_port_spec(
214+
hosts: typing.List[str],
215+
port: typing.Union[int, str]
216+
) -> typing.List[int]:
217+
...
218+
219+
220+
def _validate_port_spec(
221+
hosts: typing.List[str],
222+
port: PortType
223+
) -> typing.List[int]:
198224
if isinstance(port, list):
199225
# If there is a list of ports, its length must
200226
# match that of the host list.
201227
if len(port) != len(hosts):
202228
raise exceptions.InterfaceError(
203229
'could not match {} port numbers to {} hosts'.format(
204230
len(port), len(hosts)))
231+
return [int(p) for p in port]
205232
else:
206-
port = [port for _ in range(len(hosts))]
207-
208-
return port
233+
return [int(port) for _ in range(len(hosts))]
209234

210235

211236
def _parse_hostlist(hostlist: str,
212237
port: typing.Optional[PortType],
213238
*, unquote: bool = False) \
214-
-> typing.Tuple[typing.List[str], typing.List[int]]:
239+
-> typing.Tuple[typing.List[str], PortListType]:
215240
if ',' in hostlist:
216241
# A comma-separated list of host addresses.
217242
hostspecs = hostlist.split(',')
@@ -242,7 +267,7 @@ def _parse_hostlist(hostlist: str,
242267
if hostspec[0] == '/':
243268
# Unix socket
244269
addr = hostspec
245-
hostspec_port = ''
270+
hostspec_port: str = ''
246271
elif hostspec[0] == '[':
247272
# IPv6 address
248273
m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
@@ -470,13 +495,10 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
470495
else:
471496
port = 5432
472497

473-
elif isinstance(port, (list, tuple)):
474-
port = [int(p) for p in port]
475-
476-
else:
498+
elif not isinstance(port, (list, tuple)):
477499
port = int(port)
478500

479-
port = _validate_port_spec(host, port)
501+
validated_ports = _validate_port_spec(host, port)
480502

481503
if user is None:
482504
user = os.getenv('PGUSER')
@@ -517,13 +539,13 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
517539

518540
if passfile_path is not None:
519541
password = _read_password_from_pgpass(
520-
hosts=auth_hosts, ports=port,
542+
hosts=auth_hosts, ports=validated_ports,
521543
database=database, user=user,
522544
passfile=passfile_path)
523545

524546
addrs: typing.List[AddrType] = []
525547
have_tcp_addrs = False
526-
for h, p in zip(host, port):
548+
for h, p in zip(host, validated_ports):
527549
if h.startswith('/'):
528550
# UNIX socket name
529551
if '.s.PGSQL.' not in h:

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /