Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 21 additions & 78 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import socket
import select
import time
import datetime
import gc
import os
import errno
Expand All @@ -39,9 +38,7 @@

PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = socket_helper.HOST
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
IS_OPENSSL_3_0_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
IS_OPENSSL_3_0_0 = ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')

PROTOCOL_TO_TLS_VERSION = {}
Expand Down Expand Up @@ -258,53 +255,18 @@ def wrapper(*args, **kw):
return decorator


requires_minimum_version = unittest.skipUnless(
hasattr(ssl.SSLContext, 'minimum_version'),
"required OpenSSL >= 1.1.0g"
)


def handle_error(prefix):
exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
if support.verbose:
sys.stdout.write(prefix + exc_format)

def _have_secp_curves():
if not ssl.HAS_ECDH:
return False
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
try:
ctx.set_ecdh_curve("secp384r1")
except ValueError:
return False
else:
return True


HAVE_SECP_CURVES = _have_secp_curves()


def utc_offset(): #NOTE: ignore issues like #1647654
# local time = utc time + utc offset
if time.daylight and time.localtime().tm_isdst > 0:
return -time.altzone # seconds
return -time.timezone

def asn1time(cert_time):
# Some versions of OpenSSL ignore seconds, see #18207
# 0.9.8.i
if ssl._OPENSSL_API_VERSION == (0, 9, 8, 9, 15):
fmt = "%b %d %H:%M:%S %Y GMT"
dt = datetime.datetime.strptime(cert_time, fmt)
dt = dt.replace(second=0)
cert_time = dt.strftime(fmt)
# %d adds leading zero but ASN1_TIME_print() uses leading space
if cert_time[4] == "0":
cert_time = cert_time[:4] + " " + cert_time[5:]

return cert_time

needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")

ignore_deprecation = warnings_helper.ignore_warnings(
category=DeprecationWarning
Expand Down Expand Up @@ -365,11 +327,12 @@ def test_constants(self):
ssl.CERT_REQUIRED
ssl.OP_CIPHER_SERVER_PREFERENCE
ssl.OP_SINGLE_DH_USE
if ssl.HAS_ECDH:
ssl.OP_SINGLE_ECDH_USE
ssl.OP_SINGLE_ECDH_USE
ssl.OP_NO_COMPRESSION
self.assertIn(ssl.HAS_SNI, {True, False})
self.assertIn(ssl.HAS_ECDH, {True, False})
self.assertEqual(ssl.HAS_SNI, True)
self.assertEqual(ssl.HAS_ECDH, True)
self.assertEqual(ssl.HAS_TLSv1_2, True)
self.assertEqual(ssl.HAS_TLSv1_3, True)
ssl.OP_NO_SSLv2
ssl.OP_NO_SSLv3
ssl.OP_NO_TLSv1
Expand Down Expand Up @@ -537,8 +500,8 @@ def test_openssl_version(self):
self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str)
# Some sanity checks follow
# >= 0.9
self.assertGreaterEqual(n, 0x900000)
# >= 1.1.1
self.assertGreaterEqual(n, 0x10101000)
# < 4.0
self.assertLess(n, 0x40000000)
major, minor, fix, patch, status = t
Expand All @@ -552,13 +515,13 @@ def test_openssl_version(self):
self.assertLessEqual(patch, 63)
self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15)
# Version string as returned by {Open,Libre}SSL, the format might change
if IS_LIBRESSL:
self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
(s, t, hex(n)))
else:
self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
(s, t, hex(n)))

libressl_ver = f"LibreSSL {major:d}"
openssl_ver = f"OpenSSL {major:d}.{minor:d}.{fix:d}"
self.assertTrue(
s.startswith((openssl_ver, libressl_ver)),
(s, t, hex(n))
)

@support.cpython_only
def test_refcycle(self):
Expand Down Expand Up @@ -1196,8 +1159,6 @@ def test_hostname_checks_common_name(self):
with self.assertRaises(AttributeError):
ctx.hostname_checks_common_name = True

@requires_minimum_version
@unittest.skipIf(IS_LIBRESSL, "see bpo-34001")
@ignore_deprecation
def test_min_max_version(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
Expand Down Expand Up @@ -1523,7 +1484,6 @@ def test_set_ecdh_curve(self):
self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")

@needs_sni
def test_sni_callback(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

Expand All @@ -1538,7 +1498,6 @@ def dummycallback(sock, servername, ctx):
ctx.set_servername_callback(None)
ctx.set_servername_callback(dummycallback)

@needs_sni
def test_sni_callback_refcycle(self):
# Reference cycles through the servername callback are detected
# and cleared.
Expand Down Expand Up @@ -1578,8 +1537,8 @@ def test_get_ca_certs(self):
(('organizationalUnitName', 'http://www.cacert.org'),),
(('commonName', 'CA Cert Signing Authority'),),
(('emailAddress', 'support@cacert.org'),)),
'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
'notAfter': 'Mar 29 12:29:49 2033 GMT',
'notBefore': 'Mar 30 12:29:49 2003 GMT',
'serialNumber': '00',
'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
'subject': ((('organizationName', 'Root CA'),),
Expand Down Expand Up @@ -1609,7 +1568,6 @@ def test_load_default_certs(self):
self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')

@unittest.skipIf(sys.platform == "win32", "not-Windows specific")
@unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
def test_load_default_certs_env(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
with os_helper.EnvironmentVarGuard() as env:
Expand Down Expand Up @@ -2145,7 +2103,6 @@ def test_non_blocking_handshake(self):
def test_get_server_certificate(self):
_test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)

@needs_sni
def test_get_server_certificate_sni(self):
host, port = self.server_addr
server_names = []
Expand Down Expand Up @@ -2198,7 +2155,6 @@ def test_get_ca_certs_capath(self):
self.assertTrue(cert)
self.assertEqual(len(ctx.get_ca_certs()), 1)

@needs_sni
def test_context_setget(self):
# Check that the context of a connected socket can be replaced.
ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
Expand Down Expand Up @@ -3863,7 +3819,6 @@ def test_tls1_3(self):
})
self.assertEqual(s.version(), 'TLSv1.3')

@requires_minimum_version
@requires_tls_version('TLSv1_2')
@requires_tls_version('TLSv1')
@ignore_deprecation
Expand All @@ -3882,7 +3837,6 @@ def test_min_max_version_tlsv1_2(self):
s.connect((HOST, server.port))
self.assertEqual(s.version(), 'TLSv1.2')

@requires_minimum_version
@requires_tls_version('TLSv1_1')
@ignore_deprecation
def test_min_max_version_tlsv1_1(self):
Expand All @@ -3900,7 +3854,6 @@ def test_min_max_version_tlsv1_1(self):
s.connect((HOST, server.port))
self.assertEqual(s.version(), 'TLSv1.1')

@requires_minimum_version
@requires_tls_version('TLSv1_2')
@requires_tls_version('TLSv1')
@ignore_deprecation
Expand All @@ -3920,7 +3873,6 @@ def test_min_max_version_mismatch(self):
s.connect((HOST, server.port))
self.assertIn("alert", str(e.exception))

@requires_minimum_version
@requires_tls_version('SSLv3')
def test_min_max_version_sslv3(self):
client_context, server_context, hostname = testing_context()
Expand All @@ -3935,7 +3887,6 @@ def test_min_max_version_sslv3(self):
s.connect((HOST, server.port))
self.assertEqual(s.version(), 'SSLv3')

@unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
def test_default_ecdh_curve(self):
# Issue #21015: elliptic curve-based Diffie Hellman key exchange
# should be enabled by default on SSL contexts.
Expand Down Expand Up @@ -4050,15 +4001,13 @@ def test_dh_params(self):
if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
self.fail("Non-DH cipher: " + cipher[0])

@unittest.skipUnless(HAVE_SECP_CURVES, "needs secp384r1 curve support")
@unittest.skipIf(IS_OPENSSL_1_1_1, "TODO: Test doesn't work on 1.1.1")
def test_ecdh_curve(self):
# server secp384r1, client auto
client_context, server_context, hostname = testing_context()

server_context.set_ecdh_curve("secp384r1")
server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
server_context.minimum_version = ssl.TLSVersion.TLSv1_2
stats = server_params_test(client_context, server_context,
chatty=True, connectionchatty=True,
sni_name=hostname)
Expand All @@ -4067,7 +4016,7 @@ def test_ecdh_curve(self):
client_context, server_context, hostname = testing_context()
client_context.set_ecdh_curve("secp384r1")
server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
server_context.minimum_version = ssl.TLSVersion.TLSv1_2
stats = server_params_test(client_context, server_context,
chatty=True, connectionchatty=True,
sni_name=hostname)
Expand All @@ -4077,13 +4026,11 @@ def test_ecdh_curve(self):
client_context.set_ecdh_curve("prime256v1")
server_context.set_ecdh_curve("secp384r1")
server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
try:
server_context.minimum_version = ssl.TLSVersion.TLSv1_2
with self.assertRaises(ssl.SSLError):
server_params_test(client_context, server_context,
chatty=True, connectionchatty=True,
sni_name=hostname)
except ssl.SSLError:
self.fail("mismatch curve did not fail")

def test_selected_alpn_protocol(self):
# selected_alpn_protocol() is None unless ALPN is used.
Expand Down Expand Up @@ -4152,7 +4099,6 @@ def check_common_name(self, stats, name):
cert = stats['peercert']
self.assertIn((('commonName', name),), cert['subject'])

@needs_sni
def test_sni_callback(self):
calls = []
server_context, other_context, client_context = self.sni_contexts()
Expand Down Expand Up @@ -4193,7 +4139,6 @@ def servername_cb(ssl_sock, server_name, initial_context):
self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
self.assertEqual(calls, [])

@needs_sni
def test_sni_callback_alert(self):
# Returning a TLS alert is reflected to the connecting client
server_context, other_context, client_context = self.sni_contexts()
Expand All @@ -4207,7 +4152,6 @@ def cb_returning_alert(ssl_sock, server_name, initial_context):
sni_name='supermessage')
self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')

@needs_sni
def test_sni_callback_raising(self):
# Raising fails the connection with a TLS handshake failure alert.
server_context, other_context, client_context = self.sni_contexts()
Expand All @@ -4226,7 +4170,6 @@ def cb_raising(ssl_sock, server_name, initial_context):
'SSLV3_ALERT_HANDSHAKE_FAILURE')
self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)

@needs_sni
def test_sni_callback_wrong_return_type(self):
# Returning the wrong return type terminates the TLS connection
# with an internal error alert.
Expand Down