mirror of https://github.com/python/cpython
gh-63284: Add support for TLS-PSK (pre-shared key) to the ssl module (#103181)
Add support for TLS-PSK (pre-shared key) to the ssl module. --------- Co-authored-by: Oleg Iarygin <oleg@arhadthedev.net> Co-authored-by: Gregory P. Smith <greg@krypto.org>
This commit is contained in:
parent
fb202af447
commit
e954ac7205
|
@ -2006,6 +2006,94 @@ to speed up repeated connections from the same clients.
|
|||
>>> ssl.create_default_context().verify_mode # doctest: +SKIP
|
||||
<VerifyMode.CERT_REQUIRED: 2>
|
||||
|
||||
.. method:: SSLContext.set_psk_client_callback(callback)
|
||||
|
||||
Enables TLS-PSK (pre-shared key) authentication on a client-side connection.
|
||||
|
||||
In general, certificate based authentication should be preferred over this method.
|
||||
|
||||
The parameter ``callback`` is a callable object with the signature:
|
||||
``def callback(hint: str | None) -> tuple[str | None, bytes]``.
|
||||
The ``hint`` parameter is an optional identity hint sent by the server.
|
||||
The return value is a tuple in the form (client-identity, psk).
|
||||
Client-identity is an optional string which may be used by the server to
|
||||
select a corresponding PSK for the client. The string must be less than or
|
||||
equal to ``256`` octets when UTF-8 encoded. PSK is a
|
||||
:term:`bytes-like object` representing the pre-shared key. Return a zero
|
||||
length PSK to reject the connection.
|
||||
|
||||
Setting ``callback`` to :const:`None` removes any existing callback.
|
||||
|
||||
.. note::
|
||||
When using TLS 1.3:
|
||||
|
||||
- the ``hint`` parameter is always :const:`None`.
|
||||
- client-identity must be a non-empty string.
|
||||
|
||||
Example usage::
|
||||
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
context.set_ciphers('PSK')
|
||||
|
||||
# A simple lambda:
|
||||
psk = bytes.fromhex('c0ffee')
|
||||
context.set_psk_client_callback(lambda hint: (None, psk))
|
||||
|
||||
# A table using the hint from the server:
|
||||
psk_table = { 'ServerId_1': bytes.fromhex('c0ffee'),
|
||||
'ServerId_2': bytes.fromhex('facade')
|
||||
}
|
||||
def callback(hint):
|
||||
return 'ClientId_1', psk_table.get(hint, b'')
|
||||
context.set_psk_client_callback(callback)
|
||||
|
||||
.. versionadded:: 3.13
|
||||
|
||||
.. method:: SSLContext.set_psk_server_callback(callback, identity_hint=None)
|
||||
|
||||
Enables TLS-PSK (pre-shared key) authentication on a server-side connection.
|
||||
|
||||
In general, certificate based authentication should be preferred over this method.
|
||||
|
||||
The parameter ``callback`` is a callable object with the signature:
|
||||
``def callback(identity: str | None) -> bytes``.
|
||||
The ``identity`` parameter is an optional identity sent by the client which can
|
||||
be used to select a corresponding PSK.
|
||||
The return value is a :term:`bytes-like object` representing the pre-shared key.
|
||||
Return a zero length PSK to reject the connection.
|
||||
|
||||
Setting ``callback`` to :const:`None` removes any existing callback.
|
||||
|
||||
The parameter ``identity_hint`` is an optional identity hint string sent to
|
||||
the client. The string must be less than or equal to ``256`` octets when
|
||||
UTF-8 encoded.
|
||||
|
||||
.. note::
|
||||
When using TLS 1.3 the ``identity_hint`` parameter is not sent to the client.
|
||||
|
||||
Example usage::
|
||||
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
context.set_ciphers('PSK')
|
||||
|
||||
# A simple lambda:
|
||||
psk = bytes.fromhex('c0ffee')
|
||||
context.set_psk_server_callback(lambda identity: psk)
|
||||
|
||||
# A table using the identity of the client:
|
||||
psk_table = { 'ClientId_1': bytes.fromhex('c0ffee'),
|
||||
'ClientId_2': bytes.fromhex('facade')
|
||||
}
|
||||
def callback(identity):
|
||||
return psk_table.get(identity, b'')
|
||||
context.set_psk_server_callback(callback, 'ServerId_1')
|
||||
|
||||
.. versionadded:: 3.13
|
||||
|
||||
.. index:: single: certificates
|
||||
|
||||
.. index:: single: X509 certificate
|
||||
|
|
|
@ -826,6 +826,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) {
|
|||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(call));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(call_exception_handler));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(call_soon));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(callback));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(cancel));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(capath));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(category));
|
||||
|
@ -971,6 +972,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) {
|
|||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(hook));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(id));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(ident));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(identity_hint));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(ignore));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(imag));
|
||||
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(importlib));
|
||||
|
|
|
@ -315,6 +315,7 @@ struct _Py_global_strings {
|
|||
STRUCT_FOR_ID(call)
|
||||
STRUCT_FOR_ID(call_exception_handler)
|
||||
STRUCT_FOR_ID(call_soon)
|
||||
STRUCT_FOR_ID(callback)
|
||||
STRUCT_FOR_ID(cancel)
|
||||
STRUCT_FOR_ID(capath)
|
||||
STRUCT_FOR_ID(category)
|
||||
|
@ -460,6 +461,7 @@ struct _Py_global_strings {
|
|||
STRUCT_FOR_ID(hook)
|
||||
STRUCT_FOR_ID(id)
|
||||
STRUCT_FOR_ID(ident)
|
||||
STRUCT_FOR_ID(identity_hint)
|
||||
STRUCT_FOR_ID(ignore)
|
||||
STRUCT_FOR_ID(imag)
|
||||
STRUCT_FOR_ID(importlib)
|
||||
|
|
|
@ -824,6 +824,7 @@ extern "C" {
|
|||
INIT_ID(call), \
|
||||
INIT_ID(call_exception_handler), \
|
||||
INIT_ID(call_soon), \
|
||||
INIT_ID(callback), \
|
||||
INIT_ID(cancel), \
|
||||
INIT_ID(capath), \
|
||||
INIT_ID(category), \
|
||||
|
@ -969,6 +970,7 @@ extern "C" {
|
|||
INIT_ID(hook), \
|
||||
INIT_ID(id), \
|
||||
INIT_ID(ident), \
|
||||
INIT_ID(identity_hint), \
|
||||
INIT_ID(ignore), \
|
||||
INIT_ID(imag), \
|
||||
INIT_ID(importlib), \
|
||||
|
|
|
@ -786,6 +786,9 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) {
|
|||
string = &_Py_ID(call_soon);
|
||||
assert(_PyUnicode_CheckConsistency(string, 1));
|
||||
_PyUnicode_InternInPlace(interp, &string);
|
||||
string = &_Py_ID(callback);
|
||||
assert(_PyUnicode_CheckConsistency(string, 1));
|
||||
_PyUnicode_InternInPlace(interp, &string);
|
||||
string = &_Py_ID(cancel);
|
||||
assert(_PyUnicode_CheckConsistency(string, 1));
|
||||
_PyUnicode_InternInPlace(interp, &string);
|
||||
|
@ -1221,6 +1224,9 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) {
|
|||
string = &_Py_ID(ident);
|
||||
assert(_PyUnicode_CheckConsistency(string, 1));
|
||||
_PyUnicode_InternInPlace(interp, &string);
|
||||
string = &_Py_ID(identity_hint);
|
||||
assert(_PyUnicode_CheckConsistency(string, 1));
|
||||
_PyUnicode_InternInPlace(interp, &string);
|
||||
string = &_Py_ID(ignore);
|
||||
assert(_PyUnicode_CheckConsistency(string, 1));
|
||||
_PyUnicode_InternInPlace(interp, &string);
|
||||
|
|
|
@ -4236,6 +4236,105 @@ class ThreadedTests(unittest.TestCase):
|
|||
self.assertEqual(str(e.exception),
|
||||
'Session refers to a different SSLContext.')
|
||||
|
||||
@requires_tls_version('TLSv1_2')
|
||||
def test_psk(self):
|
||||
psk = bytes.fromhex('deadbeef')
|
||||
|
||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
client_context.check_hostname = False
|
||||
client_context.verify_mode = ssl.CERT_NONE
|
||||
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
client_context.set_ciphers('PSK')
|
||||
client_context.set_psk_client_callback(lambda hint: (None, psk))
|
||||
|
||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
server_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
server_context.set_ciphers('PSK')
|
||||
server_context.set_psk_server_callback(lambda identity: psk)
|
||||
|
||||
# correct PSK should connect
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
# incorrect PSK should fail
|
||||
incorrect_psk = bytes.fromhex('cafebabe')
|
||||
client_context.set_psk_client_callback(lambda hint: (None, incorrect_psk))
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
# identity_hint and client_identity should be sent to the other side
|
||||
identity_hint = 'identity-hint'
|
||||
client_identity = 'client-identity'
|
||||
|
||||
def client_callback(hint):
|
||||
self.assertEqual(hint, identity_hint)
|
||||
return client_identity, psk
|
||||
|
||||
def server_callback(identity):
|
||||
self.assertEqual(identity, client_identity)
|
||||
return psk
|
||||
|
||||
client_context.set_psk_client_callback(client_callback)
|
||||
server_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
# adding client callback to server or vice versa raises an exception
|
||||
with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK server callback'):
|
||||
client_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK client callback'):
|
||||
server_context.set_psk_client_callback(client_callback)
|
||||
|
||||
# test with UTF-8 identities
|
||||
identity_hint = '身份暗示' # Translation: "Identity hint"
|
||||
client_identity = '客户身份' # Translation: "Customer identity"
|
||||
|
||||
client_context.set_psk_client_callback(client_callback)
|
||||
server_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
@requires_tls_version('TLSv1_3')
|
||||
def test_psk_tls1_3(self):
|
||||
psk = bytes.fromhex('deadbeef')
|
||||
identity_hint = 'identity-hint'
|
||||
client_identity = 'client-identity'
|
||||
|
||||
def client_callback(hint):
|
||||
# identity_hint is not sent to the client in TLS 1.3
|
||||
self.assertIsNone(hint)
|
||||
return client_identity, psk
|
||||
|
||||
def server_callback(identity):
|
||||
self.assertEqual(identity, client_identity)
|
||||
return psk
|
||||
|
||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
client_context.check_hostname = False
|
||||
client_context.verify_mode = ssl.CERT_NONE
|
||||
client_context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
client_context.set_ciphers('PSK')
|
||||
client_context.set_psk_client_callback(client_callback)
|
||||
|
||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
server_context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
server_context.set_ciphers('PSK')
|
||||
server_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
|
||||
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
|
||||
class TestPostHandshakeAuth(unittest.TestCase):
|
||||
|
|
|
@ -1482,6 +1482,7 @@ Ajith Ramachandran
|
|||
Dhushyanth Ramasamy
|
||||
Ashwin Ramaswami
|
||||
Jeff Ramnani
|
||||
Grant Ramsay
|
||||
Bayard Randel
|
||||
Varpu Rantala
|
||||
Brodie Rao
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Added support for TLS-PSK (pre-shared key) mode to the :mod:`ssl` module.
|
224
Modules/_ssl.c
224
Modules/_ssl.c
|
@ -301,6 +301,8 @@ typedef struct {
|
|||
BIO *keylog_bio;
|
||||
/* Cached module state, also used in SSLSocket and SSLSession code. */
|
||||
_sslmodulestate *state;
|
||||
PyObject *psk_client_callback;
|
||||
PyObject *psk_server_callback;
|
||||
} PySSLContext;
|
||||
|
||||
typedef struct {
|
||||
|
@ -3123,6 +3125,8 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
|
|||
self->alpn_protocols = NULL;
|
||||
self->set_sni_cb = NULL;
|
||||
self->state = get_ssl_state(module);
|
||||
self->psk_client_callback = NULL;
|
||||
self->psk_server_callback = NULL;
|
||||
|
||||
/* Don't check host name by default */
|
||||
if (proto_version == PY_SSL_VERSION_TLS_CLIENT) {
|
||||
|
@ -3235,6 +3239,8 @@ context_clear(PySSLContext *self)
|
|||
Py_CLEAR(self->set_sni_cb);
|
||||
Py_CLEAR(self->msg_cb);
|
||||
Py_CLEAR(self->keylog_filename);
|
||||
Py_CLEAR(self->psk_client_callback);
|
||||
Py_CLEAR(self->psk_server_callback);
|
||||
if (self->keylog_bio != NULL) {
|
||||
PySSL_BEGIN_ALLOW_THREADS
|
||||
BIO_free_all(self->keylog_bio);
|
||||
|
@ -4662,6 +4668,222 @@ _ssl__SSLContext_get_ca_certs_impl(PySSLContext *self, int binary_form)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
static unsigned int psk_client_callback(SSL *s,
|
||||
const char *hint,
|
||||
char *identity,
|
||||
unsigned int max_identity_len,
|
||||
unsigned char *psk,
|
||||
unsigned int max_psk_len)
|
||||
{
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
PyObject *callback = NULL;
|
||||
|
||||
PySSLSocket *ssl = SSL_get_app_data(s);
|
||||
if (ssl == NULL || ssl->ctx == NULL) {
|
||||
goto error;
|
||||
}
|
||||
callback = ssl->ctx->psk_client_callback;
|
||||
if (callback == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
PyObject *hint_str = (hint != NULL && hint[0] != '\0') ?
|
||||
PyUnicode_DecodeUTF8(hint, strlen(hint), "strict") :
|
||||
Py_NewRef(Py_None);
|
||||
if (hint_str == NULL) {
|
||||
/* The remote side has sent an invalid UTF-8 string
|
||||
* (breaking the standard), drop the connection without
|
||||
* raising a decode exception. */
|
||||
PyErr_Clear();
|
||||
goto error;
|
||||
}
|
||||
PyObject *result = PyObject_CallFunctionObjArgs(callback, hint_str, NULL);
|
||||
Py_DECREF(hint_str);
|
||||
|
||||
if (result == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
const char *psk_;
|
||||
const char *identity_;
|
||||
Py_ssize_t psk_len_;
|
||||
Py_ssize_t identity_len_ = 0;
|
||||
if (!PyArg_ParseTuple(result, "z#y#", &identity_, &identity_len_, &psk_, &psk_len_)) {
|
||||
Py_DECREF(result);
|
||||
goto error;
|
||||
}
|
||||
|
||||
if (identity_len_ + 1 > max_identity_len || psk_len_ > max_psk_len) {
|
||||
Py_DECREF(result);
|
||||
goto error;
|
||||
}
|
||||
memcpy(psk, psk_, psk_len_);
|
||||
if (identity_ != NULL) {
|
||||
memcpy(identity, identity_, identity_len_);
|
||||
}
|
||||
identity[identity_len_] = 0;
|
||||
|
||||
Py_DECREF(result);
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
return (unsigned int)psk_len_;
|
||||
|
||||
error:
|
||||
if (PyErr_Occurred()) {
|
||||
PyErr_WriteUnraisable(callback);
|
||||
}
|
||||
PyGILState_Release(gstate);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*[clinic input]
|
||||
_ssl._SSLContext.set_psk_client_callback
|
||||
callback: object
|
||||
|
||||
[clinic start generated code]*/
|
||||
|
||||
static PyObject *
|
||||
_ssl__SSLContext_set_psk_client_callback_impl(PySSLContext *self,
|
||||
PyObject *callback)
|
||||
/*[clinic end generated code: output=0aba86f6ed75119e input=7627bae0e5ee7635]*/
|
||||
{
|
||||
if (self->protocol == PY_SSL_VERSION_TLS_SERVER) {
|
||||
_setSSLError(get_state_ctx(self),
|
||||
"Cannot add PSK client callback to a "
|
||||
"PROTOCOL_TLS_SERVER context", 0, __FILE__, __LINE__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
SSL_psk_client_cb_func ssl_callback;
|
||||
if (callback == Py_None) {
|
||||
callback = NULL;
|
||||
// Delete the existing callback
|
||||
ssl_callback = NULL;
|
||||
} else {
|
||||
if (!PyCallable_Check(callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "callback must be callable");
|
||||
return NULL;
|
||||
}
|
||||
ssl_callback = psk_client_callback;
|
||||
}
|
||||
|
||||
Py_XDECREF(self->psk_client_callback);
|
||||
Py_XINCREF(callback);
|
||||
|
||||
self->psk_client_callback = callback;
|
||||
SSL_CTX_set_psk_client_callback(self->ctx, ssl_callback);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static unsigned int psk_server_callback(SSL *s,
|
||||
const char *identity,
|
||||
unsigned char *psk,
|
||||
unsigned int max_psk_len)
|
||||
{
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
PyObject *callback = NULL;
|
||||
|
||||
PySSLSocket *ssl = SSL_get_app_data(s);
|
||||
if (ssl == NULL || ssl->ctx == NULL) {
|
||||
goto error;
|
||||
}
|
||||
callback = ssl->ctx->psk_server_callback;
|
||||
if (callback == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
PyObject *identity_str = (identity != NULL && identity[0] != '\0') ?
|
||||
PyUnicode_DecodeUTF8(identity, strlen(identity), "strict") :
|
||||
Py_NewRef(Py_None);
|
||||
if (identity_str == NULL) {
|
||||
/* The remote side has sent an invalid UTF-8 string
|
||||
* (breaking the standard), drop the connection without
|
||||
* raising a decode exception. */
|
||||
PyErr_Clear();
|
||||
goto error;
|
||||
}
|
||||
PyObject *result = PyObject_CallFunctionObjArgs(callback, identity_str, NULL);
|
||||
Py_DECREF(identity_str);
|
||||
|
||||
if (result == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
char *psk_;
|
||||
Py_ssize_t psk_len_;
|
||||
if (PyBytes_AsStringAndSize(result, &psk_, &psk_len_) < 0) {
|
||||
Py_DECREF(result);
|
||||
goto error;
|
||||
}
|
||||
|
||||
if (psk_len_ > max_psk_len) {
|
||||
Py_DECREF(result);
|
||||
goto error;
|
||||
}
|
||||
memcpy(psk, psk_, psk_len_);
|
||||
|
||||
Py_DECREF(result);
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
return (unsigned int)psk_len_;
|
||||
|
||||
error:
|
||||
if (PyErr_Occurred()) {
|
||||
PyErr_WriteUnraisable(callback);
|
||||
}
|
||||
PyGILState_Release(gstate);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*[clinic input]
|
||||
_ssl._SSLContext.set_psk_server_callback
|
||||
callback: object
|
||||
identity_hint: str(accept={str, NoneType}) = None
|
||||
|
||||
[clinic start generated code]*/
|
||||
|
||||
static PyObject *
|
||||
_ssl__SSLContext_set_psk_server_callback_impl(PySSLContext *self,
|
||||
PyObject *callback,
|
||||
const char *identity_hint)
|
||||
/*[clinic end generated code: output=1f4d6a4e09a92b03 input=65d4b6022aa85ea3]*/
|
||||
{
|
||||
if (self->protocol == PY_SSL_VERSION_TLS_CLIENT) {
|
||||
_setSSLError(get_state_ctx(self),
|
||||
"Cannot add PSK server callback to a "
|
||||
"PROTOCOL_TLS_CLIENT context", 0, __FILE__, __LINE__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
SSL_psk_server_cb_func ssl_callback;
|
||||
if (callback == Py_None) {
|
||||
callback = NULL;
|
||||
// Delete the existing callback and hint
|
||||
ssl_callback = NULL;
|
||||
identity_hint = NULL;
|
||||
} else {
|
||||
if (!PyCallable_Check(callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "callback must be callable");
|
||||
return NULL;
|
||||
}
|
||||
ssl_callback = psk_server_callback;
|
||||
}
|
||||
|
||||
if (SSL_CTX_use_psk_identity_hint(self->ctx, identity_hint) != 1) {
|
||||
PyErr_SetString(PyExc_ValueError, "failed to set identity hint");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_XDECREF(self->psk_server_callback);
|
||||
Py_XINCREF(callback);
|
||||
|
||||
self->psk_server_callback = callback;
|
||||
SSL_CTX_set_psk_server_callback(self->ctx, ssl_callback);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
|
||||
static PyGetSetDef context_getsetlist[] = {
|
||||
{"check_hostname", (getter) get_check_hostname,
|
||||
|
@ -4716,6 +4938,8 @@ static struct PyMethodDef context_methods[] = {
|
|||
_SSL__SSLCONTEXT_CERT_STORE_STATS_METHODDEF
|
||||
_SSL__SSLCONTEXT_GET_CA_CERTS_METHODDEF
|
||||
_SSL__SSLCONTEXT_GET_CIPHERS_METHODDEF
|
||||
_SSL__SSLCONTEXT_SET_PSK_CLIENT_CALLBACK_METHODDEF
|
||||
_SSL__SSLCONTEXT_SET_PSK_SERVER_CALLBACK_METHODDEF
|
||||
{NULL, NULL} /* sentinel */
|
||||
};
|
||||
|
||||
|
|
|
@ -1014,6 +1014,141 @@ exit:
|
|||
return return_value;
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(_ssl__SSLContext_set_psk_client_callback__doc__,
|
||||
"set_psk_client_callback($self, /, callback)\n"
|
||||
"--\n"
|
||||
"\n");
|
||||
|
||||
#define _SSL__SSLCONTEXT_SET_PSK_CLIENT_CALLBACK_METHODDEF \
|
||||
{"set_psk_client_callback", _PyCFunction_CAST(_ssl__SSLContext_set_psk_client_callback), METH_FASTCALL|METH_KEYWORDS, _ssl__SSLContext_set_psk_client_callback__doc__},
|
||||
|
||||
static PyObject *
|
||||
_ssl__SSLContext_set_psk_client_callback_impl(PySSLContext *self,
|
||||
PyObject *callback);
|
||||
|
||||
static PyObject *
|
||||
_ssl__SSLContext_set_psk_client_callback(PySSLContext *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
|
||||
{
|
||||
PyObject *return_value = NULL;
|
||||
#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
|
||||
|
||||
#define NUM_KEYWORDS 1
|
||||
static struct {
|
||||
PyGC_Head _this_is_not_used;
|
||||
PyObject_VAR_HEAD
|
||||
PyObject *ob_item[NUM_KEYWORDS];
|
||||
} _kwtuple = {
|
||||
.ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
|
||||
.ob_item = { &_Py_ID(callback), },
|
||||
};
|
||||
#undef NUM_KEYWORDS
|
||||
#define KWTUPLE (&_kwtuple.ob_base.ob_base)
|
||||
|
||||
#else // !Py_BUILD_CORE
|
||||
# define KWTUPLE NULL
|
||||
#endif // !Py_BUILD_CORE
|
||||
|
||||
static const char * const _keywords[] = {"callback", NULL};
|
||||
static _PyArg_Parser _parser = {
|
||||
.keywords = _keywords,
|
||||
.fname = "set_psk_client_callback",
|
||||
.kwtuple = KWTUPLE,
|
||||
};
|
||||
#undef KWTUPLE
|
||||
PyObject *argsbuf[1];
|
||||
PyObject *callback;
|
||||
|
||||
args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 1, 0, argsbuf);
|
||||
if (!args) {
|
||||
goto exit;
|
||||
}
|
||||
callback = args[0];
|
||||
return_value = _ssl__SSLContext_set_psk_client_callback_impl(self, callback);
|
||||
|
||||
exit:
|
||||
return return_value;
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(_ssl__SSLContext_set_psk_server_callback__doc__,
|
||||
"set_psk_server_callback($self, /, callback, identity_hint=None)\n"
|
||||
"--\n"
|
||||
"\n");
|
||||
|
||||
#define _SSL__SSLCONTEXT_SET_PSK_SERVER_CALLBACK_METHODDEF \
|
||||
{"set_psk_server_callback", _PyCFunction_CAST(_ssl__SSLContext_set_psk_server_callback), METH_FASTCALL|METH_KEYWORDS, _ssl__SSLContext_set_psk_server_callback__doc__},
|
||||
|
||||
static PyObject *
|
||||
_ssl__SSLContext_set_psk_server_callback_impl(PySSLContext *self,
|
||||
PyObject *callback,
|
||||
const char *identity_hint);
|
||||
|
||||
static PyObject *
|
||||
_ssl__SSLContext_set_psk_server_callback(PySSLContext *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
|
||||
{
|
||||
PyObject *return_value = NULL;
|
||||
#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
|
||||
|
||||
#define NUM_KEYWORDS 2
|
||||
static struct {
|
||||
PyGC_Head _this_is_not_used;
|
||||
PyObject_VAR_HEAD
|
||||
PyObject *ob_item[NUM_KEYWORDS];
|
||||
} _kwtuple = {
|
||||
.ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
|
||||
.ob_item = { &_Py_ID(callback), &_Py_ID(identity_hint), },
|
||||
};
|
||||
#undef NUM_KEYWORDS
|
||||
#define KWTUPLE (&_kwtuple.ob_base.ob_base)
|
||||
|
||||
#else // !Py_BUILD_CORE
|
||||
# define KWTUPLE NULL
|
||||
#endif // !Py_BUILD_CORE
|
||||
|
||||
static const char * const _keywords[] = {"callback", "identity_hint", NULL};
|
||||
static _PyArg_Parser _parser = {
|
||||
.keywords = _keywords,
|
||||
.fname = "set_psk_server_callback",
|
||||
.kwtuple = KWTUPLE,
|
||||
};
|
||||
#undef KWTUPLE
|
||||
PyObject *argsbuf[2];
|
||||
Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 1;
|
||||
PyObject *callback;
|
||||
const char *identity_hint = NULL;
|
||||
|
||||
args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 2, 0, argsbuf);
|
||||
if (!args) {
|
||||
goto exit;
|
||||
}
|
||||
callback = args[0];
|
||||
if (!noptargs) {
|
||||
goto skip_optional_pos;
|
||||
}
|
||||
if (args[1] == Py_None) {
|
||||
identity_hint = NULL;
|
||||
}
|
||||
else if (PyUnicode_Check(args[1])) {
|
||||
Py_ssize_t identity_hint_length;
|
||||
identity_hint = PyUnicode_AsUTF8AndSize(args[1], &identity_hint_length);
|
||||
if (identity_hint == NULL) {
|
||||
goto exit;
|
||||
}
|
||||
if (strlen(identity_hint) != (size_t)identity_hint_length) {
|
||||
PyErr_SetString(PyExc_ValueError, "embedded null character");
|
||||
goto exit;
|
||||
}
|
||||
}
|
||||
else {
|
||||
_PyArg_BadArgument("set_psk_server_callback", "argument 'identity_hint'", "str or None", args[1]);
|
||||
goto exit;
|
||||
}
|
||||
skip_optional_pos:
|
||||
return_value = _ssl__SSLContext_set_psk_server_callback_impl(self, callback, identity_hint);
|
||||
|
||||
exit:
|
||||
return return_value;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
_ssl_MemoryBIO_impl(PyTypeObject *type);
|
||||
|
||||
|
@ -1527,4 +1662,4 @@ exit:
|
|||
#ifndef _SSL_ENUM_CRLS_METHODDEF
|
||||
#define _SSL_ENUM_CRLS_METHODDEF
|
||||
#endif /* !defined(_SSL_ENUM_CRLS_METHODDEF) */
|
||||
/*[clinic end generated code: output=aa6b0a898b6077fe input=a9049054013a1b77]*/
|
||||
/*[clinic end generated code: output=6342ea0062ab16c7 input=a9049054013a1b77]*/
|
||||
|
|
Loading…
Reference in New Issue