diff --git a/Lib/ssl.py b/Lib/ssl.py index c072cd960bf..aa301295aee 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -75,10 +75,10 @@ from _ssl import ( SSL_ERROR_INVALID_ERROR_CODE, ) -from socket import socket, AF_INET, SOCK_STREAM, error from socket import getnameinfo as _getnameinfo from socket import error as socket_error from socket import dup as _dup +from socket import socket, AF_INET, SOCK_STREAM import base64 # for DER-to-PEM translation import traceback @@ -296,6 +296,14 @@ class SSLSocket(socket): self._sslobj = None socket.shutdown(self, how) + def unwrap (self): + if self._sslobj: + s = self._sslobj.shutdown() + self._sslobj = None + return s + else: + raise ValueError("No SSL wrapper around " + str(self)) + def _real_close(self): self._sslobj = None # self._closed = True diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 9e36e8067a8..a40a35d6f62 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -279,6 +279,15 @@ else: self.write("OK\n".encode("ASCII", "strict")) if not self.wrap_conn(): return + elif (self.server.starttls_server and self.sslconn + and amsg.strip() == 'ENDTLS'): + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read ENDTLS from client, sending OK...\n") + self.write("OK\n".encode("ASCII", "strict")) + self.sock = self.sslconn.unwrap() + self.sslconn = None + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: connection is now unencrypted...\n") else: if (support.verbose and self.server.connectionchatty): @@ -868,7 +877,7 @@ else: def testSTARTTLS (self): - msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4") + msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6") server = ThreadedEchoServer(CERTFILE, ssl_version=ssl.PROTOCOL_TLSv1, @@ -910,8 +919,16 @@ else: " client: read %s from server, starting TLS...\n" % repr(msg)) conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) - wrapped = True + elif (indata == "ENDTLS" and + str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): + if support.verbose: + msg = str(outdata, 'ASCII', 'replace') + sys.stdout.write( + " client: read %s from server, ending TLS...\n" + % repr(msg)) + s = conn.unwrap() + wrapped = False else: if support.verbose: msg = str(outdata, 'ASCII', 'replace') @@ -922,7 +939,7 @@ else: if wrapped: conn.write("over\n".encode("ASCII", "strict")) else: - s.send("over\n") + s.send("over\n".encode("ASCII", "strict")) if wrapped: conn.close() else: diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 48318a8abca..d9cbbd051fe 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -1370,6 +1370,42 @@ PyDoc_STRVAR(PySSL_SSLread_doc, \n\ Read up to len bytes from the SSL socket."); +static PyObject *PySSL_SSLshutdown(PySSLObject *self) +{ + int err; + PySocketSockObject *sock + = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); + + /* Guard against closed socket */ + if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } + + PySSL_BEGIN_ALLOW_THREADS + err = SSL_shutdown(self->ssl); + if (err == 0) { + /* we need to call it again to finish the shutdown */ + err = SSL_shutdown(self->ssl); + } + PySSL_END_ALLOW_THREADS + + if (err < 0) + return PySSL_SetError(self, err, __FILE__, __LINE__); + else { + Py_INCREF(sock); + return (PyObject *) sock; + } +} + +PyDoc_STRVAR(PySSL_SSLshutdown_doc, +"shutdown(s) -> socket\n\ +\n\ +Does the SSL shutdown handshake with the remote end, and returns\n\ +the underlying socket object."); + + static PyMethodDef PySSLMethods[] = { {"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS}, {"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS, @@ -1381,6 +1417,8 @@ static PyMethodDef PySSLMethods[] = { {"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS, PySSL_peercert_doc}, {"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS}, + {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS, + PySSL_SSLshutdown_doc}, {NULL, NULL} }; @@ -1480,6 +1518,8 @@ fails or if it does provide enough data to seed PRNG."); #endif + + /* List of functions exported by this module. */ static PyMethodDef PySSL_methods[] = {