Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add typing to _ConnectionParameters and related functions #1199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
DanielNoord wants to merge 8 commits into MagicStack:master
base: master
Choose a base branch
Loading
from DanielNoord:type-parse-config
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 60 additions & 37 deletions asyncpg/connect_utils.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import asyncio
import collections
from collections.abc import Callable
from collections.abc import Callable, Sequence
import enum
import functools
import getpass
Expand Down Expand Up @@ -41,31 +41,28 @@ class SSLMode(enum.IntEnum):
verify_full = 5

@classmethod
def parse(cls, sslmode):
if isinstance(sslmode, cls):
return sslmode
return getattr(cls, sslmode.replace('-', '_'))
def parse(cls, sslmode: typing.Union[str, SSLMode]) -> SSLMode:
if isinstance(sslmode, str):
return getattr(cls, sslmode.replace('-', '_'))
return sslmode


class SSLNegotiation(compat.StrEnum):
postgres = "postgres"
direct = "direct"


_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
'user',
'password',
'database',
'ssl',
'sslmode',
'ssl_negotiation',
'server_settings',
'target_session_attrs',
'krbsrvname',
'gsslib',
])
class _ConnectionParameters(typing.NamedTuple):
user: str
password: typing.Optional[str]
database: str
ssl: typing.Union[ssl_module.SSLContext, bool, str, SSLMode, None]
sslmode: SSLMode
ssl_negotiation: SSLNegotiation
server_settings: typing.Optional[typing.Dict[str, str]]
target_session_attrs: SessionAttribute
krbsrvname: typing.Optional[str]
gsslib: str


_ClientConfiguration = collections.namedtuple(
Expand Down Expand Up @@ -131,11 +128,13 @@ def _read_password_file(passfile: pathlib.Path) \


def _read_password_from_pgpass(
*, passfile: typing.Optional[pathlib.Path],
hosts: typing.List[str],
ports: typing.List[int],
database: str,
user: str):
*,
passfile: pathlib.Path,
hosts: Sequence[str],
ports: typing.List[int],
database: str,
user: str
) -> typing.Optional[str]:
"""Parse the pgpass file and return the matching password.

:return:
Expand Down Expand Up @@ -167,7 +166,9 @@ def _read_password_from_pgpass(
return None


def _validate_port_spec(hosts, port):
def _validate_port_spec(
hosts: Sequence[object], port: typing.Union[int, typing.List[int]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hosts: Sequence[object], port: typing.Union[int, typing.List[int]]
hosts: typing.List[str], port: typing.Union[int, typing.List[int]]

Copy link
Contributor Author

@DanielNoord DanielNoord Dec 18, 2024
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this because of line 172. If your suggestion is correct we can remove the else branch. Since I didn't want to change the functionality of the code too much I just added what the code is able to handle instead of what it likely should be, if that makes sense.

Would you want me to change this? Or is keeping as is fine? If it is the latter, could you press the Merge button? :)

) -> typing.List[int]:
if isinstance(port, list):
# If there is a list of ports, its length must
# match that of the host list.
Expand All @@ -181,15 +182,20 @@ def _validate_port_spec(hosts, port):
return port


def _parse_hostlist(hostlist, port, *, unquote=False):
def _parse_hostlist(
hostlist: str,
port: typing.Union[int, typing.List[int]],
*,
unquote: bool = False,
) -> typing.Tuple[typing.List[str], typing.List[int]]:
if ',' in hostlist:
# A comma-separated list of host addresses.
hostspecs = hostlist.split(',')
else:
hostspecs = [hostlist]

hosts = []
hostlist_ports = []
hosts: typing.List[str] = []
hostlist_ports: typing.List[int] = []

if not port:
portspec = os.environ.get('PGPORT')
Expand Down Expand Up @@ -267,10 +273,25 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
return (homedir / '.postgresql' / filename).resolve()


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
def _parse_connect_dsn_and_args(
*,
dsn: str,
host: typing.Union[str, typing.List[str], typing.Tuple[str]],
port: typing.Union[int, typing.List[int]],
user: typing.Optional[str],
password: typing.Optional[str],
passfile: typing.Union[str, pathlib.Path, None],
database: typing.Optional[str],
ssl: typing.Union[bool, None, str, SSLMode],
direct_tls: typing.Optional[bool],
server_settings: typing.Optional[typing.Dict[str, str]],
target_session_attrs: typing.Optional[str],
krbsrvname: typing.Optional[str],
gsslib: typing.Optional[str],
) -> typing.Tuple[
typing.List[typing.Union[str, typing.Tuple[str, int]]],
_ConnectionParameters,
]:
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -316,10 +337,12 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password = urllib.parse.unquote(dsn_password)

if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]
query = {
key: val[-1]
for key, val in urllib.parse.parse_qs(
parsed.query, strict_parsing=True
).items()
}

if 'port' in query:
val = query.pop('port')
Expand Down Expand Up @@ -491,7 +514,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
database=database, user=user,
passfile=passfile)

addrs = []
addrs: typing.List[typing.Union[str, typing.Tuple[str, int]]] = []
have_tcp_addrs = False
for h, p in zip(host, port):
if h.startswith('/'):
Expand Down
Loading

AltStyle によって変換されたページ (->オリジナル) /