Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions Lib/test/test_asyncio/test_sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,17 +491,14 @@ async def client(addr):

def test_start_tls_server_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
ANSWER = b'answer'

server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
if sys.platform.startswith('freebsd') or sys.platform.startswith('win'):
# bpo-35031: Some FreeBSD and Windows buildbots fail to run this test
# as the eof was not being received by the server if the payload
# size is not big enough. This behaviour only appears if the
# client is using TLS1.3.
client_context.options |= ssl.OP_NO_TLSv1_3
answer = None

def client(sock, addr):
nonlocal answer
sock.settimeout(self.TIMEOUT)

sock.connect(addr)
Expand All @@ -510,33 +507,36 @@ def client(sock, addr):

sock.start_tls(client_context)
sock.sendall(HELLO_MSG)

sock.shutdown(socket.SHUT_RDWR)
answer = sock.recv_all(len(ANSWER))
sock.close()

class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_eof, on_con_lost):
def __init__(self, on_con, on_con_lost):
self.on_con = on_con
self.on_eof = on_eof
self.on_con_lost = on_con_lost
self.data = b''
self.transport = None

def connection_made(self, tr):
self.transport = tr
self.on_con.set_result(tr)

def replace_transport(self, tr):
self.transport = tr

def data_received(self, data):
self.data += data

def eof_received(self):
self.on_eof.set_result(1)
if len(self.data) >= len(HELLO_MSG):
self.transport.write(ANSWER)

def connection_lost(self, exc):
self.transport = None
if exc is None:
self.on_con_lost.set_result(None)
else:
self.on_con_lost.set_exception(exc)

async def main(proto, on_con, on_eof, on_con_lost):
async def main(proto, on_con, on_con_lost):
tr = await on_con
tr.write(HELLO_MSG)

Expand All @@ -547,16 +547,16 @@ async def main(proto, on_con, on_eof, on_con_lost):
server_side=True,
ssl_handshake_timeout=self.TIMEOUT)

await on_eof
proto.replace_transport(new_tr)

await on_con_lost
self.assertEqual(proto.data, HELLO_MSG)
new_tr.close()

async def run_main():
on_con = self.loop.create_future()
on_eof = self.loop.create_future()
on_con_lost = self.loop.create_future()
proto = ServerProto(on_con, on_eof, on_con_lost)
proto = ServerProto(on_con, on_con_lost)

server = await self.loop.create_server(
lambda: proto, '127.0.0.1', 0)
Expand All @@ -565,11 +565,12 @@ async def run_main():
with self.tcp_client(lambda sock: client(sock, addr),
timeout=self.TIMEOUT):
await asyncio.wait_for(
main(proto, on_con, on_eof, on_con_lost),
main(proto, on_con, on_con_lost),
loop=self.loop, timeout=self.TIMEOUT)

server.close()
await server.wait_closed()
self.assertEqual(answer, ANSWER)

self.loop.run_until_complete(run_main())

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid TimeoutError in test_asyncio: test_start_tls_server_1()