Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 292290d

Browse files
Clean up types
1 parent 73c50c4 commit 292290d

File tree

8 files changed

+179
-75
lines changed

8 files changed

+179
-75
lines changed

‎asyncpg/cluster.py‎

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ class ClusterError(Exception):
6868

6969

7070
class Cluster:
71+
_data_dir: str
72+
_pg_config_path: typing.Optional[str]
73+
_pg_bin_dir: typing.Optional[str]
74+
_pg_ctl: typing.Optional[str]
75+
_daemon_pid: typing.Optional[int]
76+
_daemon_process: typing.Optional['subprocess.Popen[bytes]']
77+
_connection_addr: typing.Optional[_ConnectionSpec]
78+
_connection_spec_override: typing.Optional[_ConnectionSpec]
79+
7180
def __init__(self, data_dir: str, *,
7281
pg_config_path: typing.Optional[str] = None) -> None:
7382
self._data_dir = data_dir
@@ -76,11 +85,11 @@ def __init__(self, data_dir: str, *,
7685
os.environ.get('PGINSTALLATION')
7786
or os.environ.get('PGBIN')
7887
)
79-
self._pg_ctl: typing.Optional[str] = None
80-
self._daemon_pid: typing.Optional[int] = None
81-
self._daemon_process: typing.Optional[subprocess.Popen[bytes]] = None
82-
self._connection_addr: typing.Optional[_ConnectionSpec] = None
83-
self._connection_spec_override: typing.Optional[_ConnectionSpec] = None
88+
self._pg_ctl = None
89+
self._daemon_pid = None
90+
self._daemon_process = None
91+
self._connection_addr = None
92+
self._connection_spec_override = None
8493

8594
def get_pg_version(self) -> 'types.ServerVersion':
8695
return self._pg_version
@@ -653,6 +662,9 @@ def __init__(self, *,
653662

654663

655664
class HotStandbyCluster(TempCluster):
665+
_master: _ConnectionSpec
666+
_repl_user: str
667+
656668
def __init__(self, *,
657669
master: _ConnectionSpec, replication_user: str,
658670
data_dir_suffix: typing.Optional[str] = None,
@@ -739,16 +751,16 @@ def get_status(self) -> str:
739751
return 'running'
740752

741753
def init(self, **settings: str) -> str:
742-
pass
754+
...
743755

744756
def start(self, wait: int = 60, **settings: typing.Any) -> None:
745-
pass
757+
...
746758

747759
def stop(self, wait: int = 60) -> None:
748-
pass
760+
...
749761

750762
def destroy(self) -> None:
751-
pass
763+
...
752764

753765
def reset_hba(self) -> None:
754766
raise ClusterError('cannot modify HBA records of unmanaged cluster')

‎asyncpg/connection.py‎

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -157,16 +157,41 @@ class Connection(typing.Generic[_Record], metaclass=ConnectionMeta):
157157
'_log_listeners', '_termination_listeners', '_cancellations',
158158
'_source_traceback', '__weakref__')
159159

160+
_protocol: '_cprotocol.BaseProtocol[_Record]'
161+
_transport: typing.Any
162+
_loop: asyncio.AbstractEventLoop
163+
_top_xact: typing.Optional[transaction.Transaction]
164+
_aborted: bool
165+
_pool_release_ctr: int
166+
_stmt_cache: '_StatementCache'
167+
_stmts_to_close: typing.Set[
168+
'_cprotocol.PreparedStatementState[typing.Any]'
169+
]
170+
_listeners: typing.Dict[str, typing.Set['_Callback']]
171+
_server_version: types.ServerVersion
172+
_server_caps: 'ServerCapabilities'
173+
_intro_query: str
174+
_reset_query: typing.Optional[str]
175+
_proxy: typing.Optional['_pool.PoolConnectionProxy[typing.Any]']
176+
_stmt_exclusive_section: '_Atomic'
177+
_config: connect_utils._ClientConfiguration
178+
_params: connect_utils._ConnectionParameters
179+
_addr: typing.Union[typing.Tuple[str, int], str]
180+
_log_listeners: typing.Set['_Callback']
181+
_termination_listeners: typing.Set['_Callback']
182+
_cancellations: typing.Set['asyncio.Task[typing.Any]']
183+
_source_traceback: typing.Optional[str]
184+
160185
def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
161186
transport: typing.Any,
162187
loop: asyncio.AbstractEventLoop,
163188
addr: typing.Union[typing.Tuple[str, int], str],
164189
config: connect_utils._ClientConfiguration,
165190
params: connect_utils._ConnectionParameters) -> None:
166-
self._protocol: '_cprotocol.BaseProtocol[_Record]' = protocol
191+
self._protocol = protocol
167192
self._transport = transport
168193
self._loop = loop
169-
self._top_xact: typing.Optional[transaction.Transaction] = None
194+
self._top_xact = None
170195
self._aborted = False
171196
# Incremented every time the connection is released back to a pool.
172197
# Used to catch invalid references to connection-related resources
@@ -184,14 +209,12 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
184209
_weak_maybe_gc_stmt, weakref.ref(self)),
185210
max_lifetime=config.max_cached_statement_lifetime)
186211

187-
self._stmts_to_close: typing.Set[
188-
'_cprotocol.PreparedStatementState[typing.Any]'
189-
] = set()
212+
self._stmts_to_close = set()
190213

191-
self._listeners: typing.Dict[str, typing.Set[_Callback]] = {}
192-
self._log_listeners: typing.Set[_Callback] = set()
193-
self._cancellations: typing.Set[asyncio.Task[typing.Any]] = set()
194-
self._termination_listeners: typing.Set[_Callback] = set()
214+
self._listeners = {}
215+
self._log_listeners = set()
216+
self._cancellations = set()
217+
self._termination_listeners = set()
195218

196219
settings = self._protocol.get_settings()
197220
ver_string = settings.server_version
@@ -206,10 +229,8 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
206229
else:
207230
self._intro_query = introspection.INTRO_LOOKUP_TYPES
208231

209-
self._reset_query: typing.Optional[str] = None
210-
self._proxy: typing.Optional[
211-
'_pool.PoolConnectionProxy[typing.Any]'
212-
] = None
232+
self._reset_query = None
233+
self._proxy = None
213234

214235
# Used to serialize operations that might involve anonymous
215236
# statements. Specifically, we want to make the following
@@ -221,7 +242,7 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
221242
self._stmt_exclusive_section = _Atomic()
222243

223244
if loop.get_debug():
224-
self._source_traceback: typing.Optional[str] = _extract_stack()
245+
self._source_traceback = _extract_stack()
225246
else:
226247
self._source_traceback = None
227248

@@ -2007,7 +2028,7 @@ def _set_proxy(
20072028
self._proxy = proxy
20082029

20092030
def _check_listeners(self,
2010-
listeners: 'typing.Sized',
2031+
listeners: typing.Sized,
20112032
listener_type: str) -> None:
20122033
if listeners:
20132034
count = len(listeners)
@@ -2927,6 +2948,11 @@ class _StatementCacheEntry(typing.Generic[_Record]):
29272948

29282949
__slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')
29292950

2951+
_query: _StatementCacheKey[_Record]
2952+
_statement: '_cprotocol.PreparedStatementState[_Record]'
2953+
_cache: '_StatementCache'
2954+
_cleanup_cb: typing.Optional[asyncio.TimerHandle]
2955+
29302956
def __init__(
29312957
self,
29322958
cache: '_StatementCache',
@@ -2936,21 +2962,27 @@ def __init__(
29362962
self._cache = cache
29372963
self._query = query
29382964
self._statement = statement
2939-
self._cleanup_cb: typing.Optional[asyncio.TimerHandle] = None
2965+
self._cleanup_cb = None
29402966

29412967

29422968
class _StatementCache:
29432969

29442970
__slots__ = ('_loop', '_entries', '_max_size', '_on_remove',
29452971
'_max_lifetime')
29462972

2973+
_loop: asyncio.AbstractEventLoop
2974+
_entries: 'collections.OrderedDict[_StatementCacheKey[typing.Any], _StatementCacheEntry[typing.Any]]' # noqa: E501
2975+
_max_size: int
2976+
_on_remove: OnRemove[typing.Any]
2977+
_max_lifetime: float
2978+
29472979
def __init__(self, *, loop: asyncio.AbstractEventLoop,
29482980
max_size: int, on_remove: OnRemove[typing.Any],
29492981
max_lifetime: float) -> None:
2950-
self._loop: asyncio.AbstractEventLoop = loop
2951-
self._max_size: int = max_size
2952-
self._on_remove: OnRemove[typing.Any] = on_remove
2953-
self._max_lifetime: float = max_lifetime
2982+
self._loop = loop
2983+
self._max_size = max_size
2984+
self._on_remove = on_remove
2985+
self._max_lifetime = max_lifetime
29542986

29552987
# We use an OrderedDict for LRU implementation. Operations:
29562988
#
@@ -2969,10 +3001,7 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop,
29693001
# So new entries and hits are always promoted to the end of the
29703002
# entries dict, whereas the unused one will group in the
29713003
# beginning of it.
2972-
self._entries: collections.OrderedDict[
2973-
_StatementCacheKey[typing.Any],
2974-
_StatementCacheEntry[typing.Any]
2975-
] = collections.OrderedDict()
3004+
self._entries = collections.OrderedDict()
29763005

29773006
def __len__(self) -> int:
29783007
return len(self._entries)
@@ -3148,6 +3177,8 @@ def from_callable(
31483177
class _Atomic:
31493178
__slots__ = ('_acquired',)
31503179

3180+
_acquired: int
3181+
31513182
def __init__(self) -> None:
31523183
self._acquired = 0
31533184

‎asyncpg/connresource.py‎

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
if typing.TYPE_CHECKING:
16-
from . import connection as _connection
16+
from . import connection as _conn
1717

1818

1919
_Callable = typing.TypeVar('_Callable', bound=typing.Callable[..., typing.Any])
@@ -35,8 +35,11 @@ def _check(self: 'ConnectionResource',
3535
class ConnectionResource:
3636
__slots__ = ('_connection', '_con_release_ctr')
3737

38+
_connection: '_conn.Connection[typing.Any]'
39+
_con_release_ctr: int
40+
3841
def __init__(
39-
self, connection: '_connection.Connection[typing.Any]'
42+
self, connection: '_conn.Connection[typing.Any]'
4043
) -> None:
4144
self._connection = connection
4245
self._con_release_ctr = connection._pool_release_ctr

‎asyncpg/cursor.py‎

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,19 @@ class CursorFactory(connresource.ConnectionResource, typing.Generic[_Record]):
3939
'_record_class',
4040
)
4141

42+
_state: typing.Optional['_cprotocol.PreparedStatementState[_Record]']
43+
_args: typing.Sequence[typing.Any]
44+
_prefetch: typing.Optional[int]
45+
_query: str
46+
_timeout: typing.Optional[float]
47+
_record_class: typing.Optional[typing.Type[_Record]]
48+
4249
@typing.overload
4350
def __init__(
4451
self: 'CursorFactory[_Record]',
4552
connection: '_connection.Connection[_Record]',
4653
query: str,
47-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
54+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
4855
args: typing.Sequence[typing.Any],
4956
prefetch: typing.Optional[int],
5057
timeout: typing.Optional[float],
@@ -57,7 +64,7 @@ def __init__(
5764
self: 'CursorFactory[_Record]',
5865
connection: '_connection.Connection[typing.Any]',
5966
query: str,
60-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
67+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
6168
args: typing.Sequence[typing.Any],
6269
prefetch: typing.Optional[int],
6370
timeout: typing.Optional[float],
@@ -69,7 +76,7 @@ def __init__(
6976
self,
7077
connection: '_connection.Connection[typing.Any]',
7178
query: str,
72-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
79+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
7380
args: typing.Sequence[typing.Any],
7481
prefetch: typing.Optional[int],
7582
timeout: typing.Optional[float],
@@ -130,12 +137,19 @@ class BaseCursor(connresource.ConnectionResource, typing.Generic[_Record]):
130137
'_record_class',
131138
)
132139

140+
_state: typing.Optional['_cprotocol.PreparedStatementState[_Record]']
141+
_args: typing.Sequence[typing.Any]
142+
_portal_name: typing.Optional[str]
143+
_exhausted: bool
144+
_query: str
145+
_record_class: typing.Optional[typing.Type[_Record]]
146+
133147
@typing.overload
134148
def __init__(
135149
self: 'BaseCursor[_Record]',
136150
connection: '_connection.Connection[_Record]',
137151
query: str,
138-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
152+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
139153
args: typing.Sequence[typing.Any],
140154
record_class: None
141155
) -> None:
@@ -146,7 +160,7 @@ def __init__(
146160
self: 'BaseCursor[_Record]',
147161
connection: '_connection.Connection[typing.Any]',
148162
query: str,
149-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
163+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
150164
args: typing.Sequence[typing.Any],
151165
record_class: typing.Type[_Record]
152166
) -> None:
@@ -156,7 +170,7 @@ def __init__(
156170
self,
157171
connection: '_connection.Connection[typing.Any]',
158172
query: str,
159-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
173+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
160174
args: typing.Sequence[typing.Any],
161175
record_class: typing.Optional[typing.Type[_Record]]
162176
) -> None:
@@ -165,7 +179,7 @@ def __init__(
165179
self._state = state
166180
if state is not None:
167181
state.attach()
168-
self._portal_name: typing.Optional[str] = None
182+
self._portal_name = None
169183
self._exhausted = False
170184
self._query = query
171185
self._record_class = record_class
@@ -260,12 +274,16 @@ class CursorIterator(BaseCursor[_Record]):
260274

261275
__slots__ = ('_buffer', '_prefetch', '_timeout')
262276

277+
_buffer: typing.Deque[_Record]
278+
_prefetch: int
279+
_timeout: typing.Optional[float]
280+
263281
@typing.overload
264282
def __init__(
265283
self: 'CursorIterator[_Record]',
266284
connection: '_connection.Connection[_Record]',
267285
query: str,
268-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
286+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
269287
args: typing.Sequence[typing.Any],
270288
record_class: None,
271289
prefetch: int,
@@ -278,7 +296,7 @@ def __init__(
278296
self: 'CursorIterator[_Record]',
279297
connection: '_connection.Connection[typing.Any]',
280298
query: str,
281-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
299+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
282300
args: typing.Sequence[typing.Any],
283301
record_class: typing.Type[_Record],
284302
prefetch: int,
@@ -290,7 +308,7 @@ def __init__(
290308
self,
291309
connection: '_connection.Connection[typing.Any]',
292310
query: str,
293-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
311+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
294312
args: typing.Sequence[typing.Any],
295313
record_class: typing.Optional[typing.Type[_Record]],
296314
prefetch: int,
@@ -302,7 +320,7 @@ def __init__(
302320
raise exceptions.InterfaceError(
303321
'prefetch argument must be greater than zero')
304322

305-
self._buffer: typing.Deque[_Record] = collections.deque()
323+
self._buffer = collections.deque()
306324
self._prefetch = prefetch
307325
self._timeout = timeout
308326

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /