[Python-checkins] bpo-33654: Support protocol type switching in SSLTransport.set_protocol() (#7194)

Andrew Svetlov webhook-mailer at python.org
Tue May 29 05:02:52 EDT 2018


https://github.com/python/cpython/commit/2179022d94937d7b0600b0dc192ca6fa5f53d830
commit: 2179022d94937d7b0600b0dc192ca6fa5f53d830
branch: master
author: Yury Selivanov <yury at magic.io>
committer: Andrew Svetlov <andrew.svetlov at gmail.com>
date: 2018-05-29T12:02:40+03:00
summary:
bpo-33654: Support protocol type switching in SSLTransport.set_protocol() (#7194)
files:
A Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
M Lib/asyncio/sslproto.py
M Lib/test/test_asyncio/test_sslproto.py
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index ab43e93b28bc..a6d382ecd3de 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -295,7 +295,7 @@ def get_extra_info(self, name, default=None):
 return self._ssl_protocol._get_extra_info(name, default)
 
 def set_protocol(self, protocol):
- self._ssl_protocol._app_protocol = protocol
+ self._ssl_protocol._set_app_protocol(protocol)
 
 def get_protocol(self):
 return self._ssl_protocol._app_protocol
@@ -440,9 +440,7 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
 
 self._waiter = waiter
 self._loop = loop
- self._app_protocol = app_protocol
- self._app_protocol_is_buffer = \
- isinstance(app_protocol, protocols.BufferedProtocol)
+ self._set_app_protocol(app_protocol)
 self._app_transport = _SSLProtocolTransport(self._loop, self)
 # _SSLPipe instance (None until the connection is made)
 self._sslpipe = None
@@ -454,6 +452,11 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
 self._call_connection_made = call_connection_made
 self._ssl_handshake_timeout = ssl_handshake_timeout
 
+ def _set_app_protocol(self, app_protocol):
+ self._app_protocol = app_protocol
+ self._app_protocol_is_buffer = \
+ isinstance(app_protocol, protocols.BufferedProtocol)
+
 def _wakeup_waiter(self, exc=None):
 if self._waiter is None:
 return
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 1b2f9d2a3a2a..fa9cbd56ed42 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -302,6 +302,7 @@ def test_start_tls_client_buf_proto_1(self):
 
 server_context = test_utils.simple_server_sslcontext()
 client_context = test_utils.simple_client_sslcontext()
+ client_con_made_calls = 0
 
 def serve(sock):
 sock.settimeout(self.TIMEOUT)
@@ -315,20 +316,21 @@ def serve(sock):
 data = sock.recv_all(len(HELLO_MSG))
 self.assertEqual(len(data), len(HELLO_MSG))
 
+ sock.sendall(b'2')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
 sock.shutdown(socket.SHUT_RDWR)
 sock.close()
 
- class ClientProto(asyncio.BufferedProtocol):
- def __init__(self, on_data, on_eof):
+ class ClientProtoFirst(asyncio.BufferedProtocol):
+ def __init__(self, on_data):
 self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
 self.buf = bytearray(1)
 
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
 
 def get_buffer(self, sizehint):
 return self.buf
@@ -337,27 +339,50 @@ def buffer_updated(self, nsize):
 assert nsize == 1
 self.on_data.set_result(bytes(self.buf[:nsize]))
 
+ class ClientProtoSecond(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
 def eof_received(self):
 self.on_eof.set_result(True)
 
 async def client(addr):
 await asyncio.sleep(0.5, loop=self.loop)
 
- on_data = self.loop.create_future()
+ on_data1 = self.loop.create_future()
+ on_data2 = self.loop.create_future()
 on_eof = self.loop.create_future()
 
 tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
+ lambda: ClientProtoFirst(on_data1), *addr)
 
 tr.write(HELLO_MSG)
 new_tr = await self.loop.start_tls(tr, proto, client_context)
 
- self.assertEqual(await on_data, b'O')
+ self.assertEqual(await on_data1, b'O')
+ new_tr.write(HELLO_MSG)
+
+ new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
+ self.assertEqual(await on_data2, b'2')
 new_tr.write(HELLO_MSG)
 await on_eof
 
 new_tr.close()
 
+ # connection_made() should be called only once -- when
+ # we establish connection for the first time. Start TLS
+ # doesn't call connection_made() on application protocols.
+ self.assertEqual(client_con_made_calls, 1)
+
 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
 self.loop.run_until_complete(
 asyncio.wait_for(client(srv.addr),
diff --git a/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst b/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
new file mode 100644
index 000000000000..39e8e615d8c4
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
@@ -0,0 +1 @@
+Support protocol type switching in SSLTransport.set_protocol().


More information about the Python-checkins mailing list

Page converted by AltStyle (-> original) /