add support for ALPN (closes #20188)

This commit is contained in:
Benjamin Peterson 2015-01-23 16:35:37 -05:00
parent 06140f2e04
commit cca2732a82
5 changed files with 231 additions and 29 deletions

View File

@ -673,6 +673,13 @@ Constants
.. versionadded:: 3.3 .. versionadded:: 3.3
.. data:: HAS_ALPN
Whether the OpenSSL library has built-in support for the *Application-Layer
Protocol Negotiation* TLS extension as described in :rfc:`7301`.
.. versionadded:: 3.5
.. data:: HAS_ECDH .. data:: HAS_ECDH
Whether the OpenSSL library has built-in support for Elliptic Curve-based Whether the OpenSSL library has built-in support for Elliptic Curve-based
@ -959,9 +966,18 @@ SSL sockets also have the following additional methods and attributes:
.. versionadded:: 3.3 .. versionadded:: 3.3
.. method:: SSLSocket.selected_alpn_protocol()
Return the protocol that was selected during the TLS handshake. If
:meth:`SSLContext.set_alpn_protocols` was not called, if the other party does
not support ALPN, or if the handshake has not happened yet, ``None`` is
returned.
.. versionadded:: 3.5
.. method:: SSLSocket.selected_npn_protocol() .. method:: SSLSocket.selected_npn_protocol()
Returns the higher-level protocol that was selected during the TLS/SSL Return the higher-level protocol that was selected during the TLS/SSL
handshake. If :meth:`SSLContext.set_npn_protocols` was not called, or handshake. If :meth:`SSLContext.set_npn_protocols` was not called, or
if the other party does not support NPN, or if the handshake has not yet if the other party does not support NPN, or if the handshake has not yet
happened, this will return ``None``. happened, this will return ``None``.
@ -1160,6 +1176,20 @@ to speed up repeated connections from the same clients.
when connected, the :meth:`SSLSocket.cipher` method of SSL sockets will when connected, the :meth:`SSLSocket.cipher` method of SSL sockets will
give the currently selected cipher. give the currently selected cipher.
.. method:: SSLContext.set_alpn_protocols(protocols)
Specify which protocols the socket should advertise during the SSL/TLS
handshake. It should be a list of ASCII strings, like ``['http/1.1',
'spdy/2']``, ordered by preference. The selection of a protocol will happen
during the handshake, and will play out according to :rfc:`7301`. After a
successful handshake, the :meth:`SSLSocket.selected_alpn_protocol` method will
return the agreed-upon protocol.
This method will raise :exc:`NotImplementedError` if :data:`HAS_ALPN` is
False.
.. versionadded:: 3.5
.. method:: SSLContext.set_npn_protocols(protocols) .. method:: SSLContext.set_npn_protocols(protocols)
Specify which protocols the socket should advertise during the SSL/TLS Specify which protocols the socket should advertise during the SSL/TLS
@ -1200,7 +1230,7 @@ to speed up repeated connections from the same clients.
Due to the early negotiation phase of the TLS connection, only limited Due to the early negotiation phase of the TLS connection, only limited
methods and attributes are usable like methods and attributes are usable like
:meth:`SSLSocket.selected_npn_protocol` and :attr:`SSLSocket.context`. :meth:`SSLSocket.selected_alpn_protocol` and :attr:`SSLSocket.context`.
:meth:`SSLSocket.getpeercert`, :meth:`SSLSocket.getpeercert`, :meth:`SSLSocket.getpeercert`, :meth:`SSLSocket.getpeercert`,
:meth:`SSLSocket.cipher` and :meth:`SSLSocket.compress` methods require that :meth:`SSLSocket.cipher` and :meth:`SSLSocket.compress` methods require that
the TLS connection has progressed beyond the TLS Client Hello and therefore the TLS connection has progressed beyond the TLS Client Hello and therefore

View File

@ -122,7 +122,7 @@ _import_symbols('OP_')
_import_symbols('ALERT_DESCRIPTION_') _import_symbols('ALERT_DESCRIPTION_')
_import_symbols('SSL_ERROR_') _import_symbols('SSL_ERROR_')
from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN
from _ssl import _OPENSSL_API_VERSION from _ssl import _OPENSSL_API_VERSION
@ -374,6 +374,17 @@ class SSLContext(_SSLContext):
self._set_npn_protocols(protos) self._set_npn_protocols(protos)
def set_alpn_protocols(self, alpn_protocols):
protos = bytearray()
for protocol in alpn_protocols:
b = bytes(protocol, 'ascii')
if len(b) == 0 or len(b) > 255:
raise SSLError('ALPN protocols must be 1 to 255 in length')
protos.append(len(b))
protos.extend(b)
self._set_alpn_protocols(protos)
def _load_windows_store_certs(self, storename, purpose): def _load_windows_store_certs(self, storename, purpose):
certs = bytearray() certs = bytearray()
for cert, encoding, trust in enum_certificates(storename): for cert, encoding, trust in enum_certificates(storename):
@ -567,6 +578,13 @@ class SSLObject:
if _ssl.HAS_NPN: if _ssl.HAS_NPN:
return self._sslobj.selected_npn_protocol() return self._sslobj.selected_npn_protocol()
def selected_alpn_protocol(self):
"""Return the currently selected ALPN protocol as a string, or ``None``
if a next protocol was not negotiated or if ALPN is not supported by one
of the peers."""
if _ssl.HAS_ALPN:
return self._sslobj.selected_alpn_protocol()
def cipher(self): def cipher(self):
"""Return the currently selected cipher as a 3-tuple ``(name, """Return the currently selected cipher as a 3-tuple ``(name,
ssl_version, secret_bits)``.""" ssl_version, secret_bits)``."""
@ -783,6 +801,13 @@ class SSLSocket(socket):
else: else:
return self._sslobj.selected_npn_protocol() return self._sslobj.selected_npn_protocol()
def selected_alpn_protocol(self):
self._checkClosed()
if not self._sslobj or not _ssl.HAS_ALPN:
return None
else:
return self._sslobj.selected_alpn_protocol()
def cipher(self): def cipher(self):
self._checkClosed() self._checkClosed()
if not self._sslobj: if not self._sslobj:

View File

@ -1761,7 +1761,8 @@ else:
try: try:
self.sslconn = self.server.context.wrap_socket( self.sslconn = self.server.context.wrap_socket(
self.sock, server_side=True) self.sock, server_side=True)
self.server.selected_protocols.append(self.sslconn.selected_npn_protocol()) self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
except (ssl.SSLError, ConnectionResetError) as e: except (ssl.SSLError, ConnectionResetError) as e:
# We treat ConnectionResetError as though it were an # We treat ConnectionResetError as though it were an
# SSLError - OpenSSL on Ubuntu abruptly closes the # SSLError - OpenSSL on Ubuntu abruptly closes the
@ -1869,7 +1870,8 @@ else:
def __init__(self, certificate=None, ssl_version=None, def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None, certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False, chatty=True, connectionchatty=False, starttls_server=False,
npn_protocols=None, ciphers=None, context=None): npn_protocols=None, alpn_protocols=None,
ciphers=None, context=None):
if context: if context:
self.context = context self.context = context
else: else:
@ -1884,6 +1886,8 @@ else:
self.context.load_cert_chain(certificate) self.context.load_cert_chain(certificate)
if npn_protocols: if npn_protocols:
self.context.set_npn_protocols(npn_protocols) self.context.set_npn_protocols(npn_protocols)
if alpn_protocols:
self.context.set_alpn_protocols(alpn_protocols)
if ciphers: if ciphers:
self.context.set_ciphers(ciphers) self.context.set_ciphers(ciphers)
self.chatty = chatty self.chatty = chatty
@ -1893,7 +1897,8 @@ else:
self.port = support.bind_port(self.sock) self.port = support.bind_port(self.sock)
self.flag = None self.flag = None
self.active = False self.active = False
self.selected_protocols = [] self.selected_npn_protocols = []
self.selected_alpn_protocols = []
self.shared_ciphers = [] self.shared_ciphers = []
self.conn_errors = [] self.conn_errors = []
threading.Thread.__init__(self) threading.Thread.__init__(self)
@ -2120,11 +2125,13 @@ else:
'compression': s.compression(), 'compression': s.compression(),
'cipher': s.cipher(), 'cipher': s.cipher(),
'peercert': s.getpeercert(), 'peercert': s.getpeercert(),
'client_alpn_protocol': s.selected_alpn_protocol(),
'client_npn_protocol': s.selected_npn_protocol(), 'client_npn_protocol': s.selected_npn_protocol(),
'version': s.version(), 'version': s.version(),
}) })
s.close() s.close()
stats['server_npn_protocols'] = server.selected_protocols stats['server_alpn_protocols'] = server.selected_alpn_protocols
stats['server_npn_protocols'] = server.selected_npn_protocols
stats['server_shared_ciphers'] = server.shared_ciphers stats['server_shared_ciphers'] = server.shared_ciphers
return stats return stats
@ -3022,6 +3029,55 @@ else:
if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
self.fail("Non-DH cipher: " + cipher[0]) self.fail("Non-DH cipher: " + cipher[0])
def test_selected_alpn_protocol(self):
# selected_alpn_protocol() is None unless ALPN is used.
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
context.load_cert_chain(CERTFILE)
stats = server_params_test(context, context,
chatty=True, connectionchatty=True)
self.assertIs(stats['client_alpn_protocol'], None)
@unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
def test_selected_alpn_protocol_if_server_uses_alpn(self):
# selected_alpn_protocol() is None unless ALPN is used by the client.
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
client_context.load_verify_locations(CERTFILE)
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
server_context.load_cert_chain(CERTFILE)
server_context.set_alpn_protocols(['foo', 'bar'])
stats = server_params_test(client_context, server_context,
chatty=True, connectionchatty=True)
self.assertIs(stats['client_alpn_protocol'], None)
@unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
def test_alpn_protocols(self):
server_protocols = ['foo', 'bar', 'milkshake']
protocol_tests = [
(['foo', 'bar'], 'foo'),
(['bar', 'foo'], 'bar'),
(['milkshake'], 'milkshake'),
(['http/3.0', 'http/4.0'], 'foo')
]
for client_protocols, expected in protocol_tests:
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
server_context.load_cert_chain(CERTFILE)
server_context.set_alpn_protocols(server_protocols)
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
client_context.load_cert_chain(CERTFILE)
client_context.set_alpn_protocols(client_protocols)
stats = server_params_test(client_context, server_context,
chatty=True, connectionchatty=True)
msg = "failed trying %s (s) and %s (c).\n" \
"was expecting %s, but got %%s from the %%s" \
% (str(server_protocols), str(client_protocols),
str(expected))
client_result = stats['client_alpn_protocol']
self.assertEqual(client_result, expected, msg % (client_result, "client"))
server_result = stats['server_alpn_protocols'][-1] \
if len(stats['server_alpn_protocols']) else 'nothing'
self.assertEqual(server_result, expected, msg % (server_result, "server"))
def test_selected_npn_protocol(self): def test_selected_npn_protocol(self):
# selected_npn_protocol() is None unless NPN is used # selected_npn_protocol() is None unless NPN is used
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)

View File

@ -203,6 +203,9 @@ Core and Builtins
Library Library
------- -------
- Issue #20188: Support Application-Layer Protocol Negotiation (ALPN) in the ssl
module.
- Issue #23133: Pickling of ipaddress objects now produces more compact and - Issue #23133: Pickling of ipaddress objects now produces more compact and
portable representation. portable representation.

View File

@ -109,6 +109,11 @@ struct py_ssl_library_code {
# define HAVE_SNI 0 # define HAVE_SNI 0
#endif #endif
/* ALPN added in OpenSSL 1.0.2 */
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
# define HAVE_ALPN
#endif
enum py_ssl_error { enum py_ssl_error {
/* these mirror ssl.h */ /* these mirror ssl.h */
PY_SSL_ERROR_NONE, PY_SSL_ERROR_NONE,
@ -180,9 +185,13 @@ typedef struct {
PyObject_HEAD PyObject_HEAD
SSL_CTX *ctx; SSL_CTX *ctx;
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
char *npn_protocols; unsigned char *npn_protocols;
int npn_protocols_len; int npn_protocols_len;
#endif #endif
#ifdef HAVE_ALPN
unsigned char *alpn_protocols;
int alpn_protocols_len;
#endif
#ifndef OPENSSL_NO_TLSEXT #ifndef OPENSSL_NO_TLSEXT
PyObject *set_hostname; PyObject *set_hostname;
#endif #endif
@ -1460,7 +1469,20 @@ static PyObject *PySSL_selected_npn_protocol(PySSLSocket *self) {
if (out == NULL) if (out == NULL)
Py_RETURN_NONE; Py_RETURN_NONE;
return PyUnicode_FromStringAndSize((char *) out, outlen); return PyUnicode_FromStringAndSize((char *)out, outlen);
}
#endif
#ifdef HAVE_ALPN
static PyObject *PySSL_selected_alpn_protocol(PySSLSocket *self) {
const unsigned char *out;
unsigned int outlen;
SSL_get0_alpn_selected(self->ssl, &out, &outlen);
if (out == NULL)
Py_RETURN_NONE;
return PyUnicode_FromStringAndSize((char *)out, outlen);
} }
#endif #endif
@ -2053,6 +2075,9 @@ static PyMethodDef PySSLMethods[] = {
{"version", (PyCFunction)PySSL_version, METH_NOARGS}, {"version", (PyCFunction)PySSL_version, METH_NOARGS},
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
{"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS}, {"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS},
#endif
#ifdef HAVE_ALPN
{"selected_alpn_protocol", (PyCFunction)PySSL_selected_alpn_protocol, METH_NOARGS},
#endif #endif
{"compression", (PyCFunction)PySSL_compression, METH_NOARGS}, {"compression", (PyCFunction)PySSL_compression, METH_NOARGS},
{"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS, {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
@ -2159,6 +2184,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
self->npn_protocols = NULL; self->npn_protocols = NULL;
#endif #endif
#ifdef HAVE_ALPN
self->alpn_protocols = NULL;
#endif
#ifndef OPENSSL_NO_TLSEXT #ifndef OPENSSL_NO_TLSEXT
self->set_hostname = NULL; self->set_hostname = NULL;
#endif #endif
@ -2218,7 +2246,10 @@ context_dealloc(PySSLContext *self)
context_clear(self); context_clear(self);
SSL_CTX_free(self->ctx); SSL_CTX_free(self->ctx);
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
PyMem_Free(self->npn_protocols); PyMem_FREE(self->npn_protocols);
#endif
#ifdef HAVE_ALPN
PyMem_FREE(self->alpn_protocols);
#endif #endif
Py_TYPE(self)->tp_free(self); Py_TYPE(self)->tp_free(self);
} }
@ -2244,6 +2275,23 @@ set_ciphers(PySSLContext *self, PyObject *args)
Py_RETURN_NONE; Py_RETURN_NONE;
} }
static int
do_protocol_selection(unsigned char **out, unsigned char *outlen,
const unsigned char *remote_protocols, unsigned int remote_protocols_len,
unsigned char *our_protocols, unsigned int our_protocols_len)
{
if (our_protocols == NULL) {
our_protocols = (unsigned char*)"";
our_protocols_len = 0;
}
SSL_select_next_proto(out, outlen,
remote_protocols, remote_protocols_len,
our_protocols, our_protocols_len);
return SSL_TLSEXT_ERR_OK;
}
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
/* this callback gets passed to SSL_CTX_set_next_protos_advertise_cb */ /* this callback gets passed to SSL_CTX_set_next_protos_advertise_cb */
static int static int
@ -2254,10 +2302,10 @@ _advertiseNPN_cb(SSL *s,
PySSLContext *ssl_ctx = (PySSLContext *) args; PySSLContext *ssl_ctx = (PySSLContext *) args;
if (ssl_ctx->npn_protocols == NULL) { if (ssl_ctx->npn_protocols == NULL) {
*data = (unsigned char *) ""; *data = (unsigned char *)"";
*len = 0; *len = 0;
} else { } else {
*data = (unsigned char *) ssl_ctx->npn_protocols; *data = ssl_ctx->npn_protocols;
*len = ssl_ctx->npn_protocols_len; *len = ssl_ctx->npn_protocols_len;
} }
@ -2270,23 +2318,9 @@ _selectNPN_cb(SSL *s,
const unsigned char *server, unsigned int server_len, const unsigned char *server, unsigned int server_len,
void *args) void *args)
{ {
PySSLContext *ssl_ctx = (PySSLContext *) args; PySSLContext *ctx = (PySSLContext *)args;
return do_protocol_selection(out, outlen, server, server_len,
unsigned char *client = (unsigned char *) ssl_ctx->npn_protocols; ctx->npn_protocols, ctx->npn_protocols_len);
int client_len;
if (client == NULL) {
client = (unsigned char *) "";
client_len = 0;
} else {
client_len = ssl_ctx->npn_protocols_len;
}
SSL_select_next_proto(out, outlen,
server, server_len,
client, client_len);
return SSL_TLSEXT_ERR_OK;
} }
#endif #endif
@ -2329,6 +2363,50 @@ _set_npn_protocols(PySSLContext *self, PyObject *args)
#endif #endif
} }
#ifdef HAVE_ALPN
static int
_selectALPN_cb(SSL *s,
const unsigned char **out, unsigned char *outlen,
const unsigned char *client_protocols, unsigned int client_protocols_len,
void *args)
{
PySSLContext *ctx = (PySSLContext *)args;
return do_protocol_selection((unsigned char **)out, outlen,
client_protocols, client_protocols_len,
ctx->alpn_protocols, ctx->alpn_protocols_len);
}
#endif
static PyObject *
_set_alpn_protocols(PySSLContext *self, PyObject *args)
{
#ifdef HAVE_ALPN
Py_buffer protos;
if (!PyArg_ParseTuple(args, "y*:set_npn_protocols", &protos))
return NULL;
PyMem_FREE(self->alpn_protocols);
self->alpn_protocols = PyMem_Malloc(protos.len);
if (!self->alpn_protocols)
return PyErr_NoMemory();
memcpy(self->alpn_protocols, protos.buf, protos.len);
self->alpn_protocols_len = protos.len;
PyBuffer_Release(&protos);
if (SSL_CTX_set_alpn_protos(self->ctx, self->alpn_protocols, self->alpn_protocols_len))
return PyErr_NoMemory();
SSL_CTX_set_alpn_select_cb(self->ctx, _selectALPN_cb, self);
PyBuffer_Release(&protos);
Py_RETURN_NONE;
#else
PyErr_SetString(PyExc_NotImplementedError,
"The ALPN extension requires OpenSSL 1.0.2 or later.");
return NULL;
#endif
}
static PyObject * static PyObject *
get_verify_mode(PySSLContext *self, void *c) get_verify_mode(PySSLContext *self, void *c)
{ {
@ -3307,6 +3385,8 @@ static struct PyMethodDef context_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"set_ciphers", (PyCFunction) set_ciphers, {"set_ciphers", (PyCFunction) set_ciphers,
METH_VARARGS, NULL}, METH_VARARGS, NULL},
{"_set_alpn_protocols", (PyCFunction) _set_alpn_protocols,
METH_VARARGS, NULL},
{"_set_npn_protocols", (PyCFunction) _set_npn_protocols, {"_set_npn_protocols", (PyCFunction) _set_npn_protocols,
METH_VARARGS, NULL}, METH_VARARGS, NULL},
{"load_cert_chain", (PyCFunction) load_cert_chain, {"load_cert_chain", (PyCFunction) load_cert_chain,
@ -4502,6 +4582,14 @@ PyInit__ssl(void)
Py_INCREF(r); Py_INCREF(r);
PyModule_AddObject(m, "HAS_NPN", r); PyModule_AddObject(m, "HAS_NPN", r);
#ifdef HAVE_ALPN
r = Py_True;
#else
r = Py_False;
#endif
Py_INCREF(r);
PyModule_AddObject(m, "HAS_ALPN", r);
/* Mappings for error codes */ /* Mappings for error codes */
err_codes_to_names = PyDict_New(); err_codes_to_names = PyDict_New();
err_names_to_codes = PyDict_New(); err_names_to_codes = PyDict_New();