From 54cc54c1fe26328d62c70fa55350ad89868c1d61 Mon Sep 17 00:00:00 2001 From: Bill Janssen Date: Fri, 14 Dec 2007 22:08:56 +0000 Subject: [PATCH] update to fix leak in SSL code --- Lib/socket.py | 6 +- Lib/ssl.py | 46 +++++------ Lib/test/test_ssl.py | 184 ++++++++++++++++++++++++++++++++++++------- Modules/_ssl.c | 57 ++++++++++---- 4 files changed, 225 insertions(+), 68 deletions(-) diff --git a/Lib/socket.py b/Lib/socket.py index 62eb82dcd18..eb876731e35 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -174,11 +174,13 @@ class socket(_socket.socket): if self._closed: self.close() + def _real_close(self): + _socket.socket.close(self) + def close(self): self._closed = True if self._io_refs <= 0: - _socket.socket.close(self) - + self._real_close() def fromfd(fd, family, type, proto=0): """ fromfd(fd, family, type[, proto]) -> socket object diff --git a/Lib/ssl.py b/Lib/ssl.py index be138661112..c229cd3221f 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -80,6 +80,7 @@ from socket import getnameinfo as _getnameinfo from socket import error as socket_error from socket import dup as _dup import base64 # for DER-to-PEM translation +import traceback class SSLSocket(socket): @@ -94,16 +95,13 @@ class SSLSocket(socket): family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, suppress_ragged_eofs=True): - self._base = None - if sock is not None: - # copied this code from socket.accept() - fd = sock.fileno() - nfd = _dup(fd) - socket.__init__(self, family=sock.family, type=sock.type, - proto=sock.proto, fileno=nfd) + socket.__init__(self, + family=sock.family, + type=sock.type, + proto=sock.proto, + fileno=_dup(sock.fileno())) sock.close() - sock = None elif fileno is not None: socket.__init__(self, fileno=fileno) else: @@ -136,10 +134,6 @@ class SSLSocket(socket): self.close() raise x - if sock and (self.fileno() != sock.fileno()): - self._base = sock - else: - self._base = None self.keyfile = keyfile self.certfile = certfile self.cert_reqs = cert_reqs @@ -156,19 +150,23 @@ class SSLSocket(socket): # raise an exception here if you wish to check for spurious closes pass - def read(self, len=None, buffer=None): + def read(self, len=0, buffer=None): """Read up to LEN bytes and return them. Return zero-length string on EOF.""" self._checkClosed() try: if buffer: - return self._sslobj.read(buffer, len) + v = self._sslobj.read(buffer, len) else: - return self._sslobj.read(len or 1024) + v = self._sslobj.read(len or 1024) + return v except SSLError as x: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: - return b'' + if buffer: + return 0 + else: + return b'' else: raise @@ -269,7 +267,6 @@ class SSLSocket(socket): while True: try: v = self.read(nbytes, buffer) - sys.stdout.flush() return v except SSLError as x: if x.args[0] == SSL_ERROR_WANT_READ: @@ -302,9 +299,7 @@ class SSLSocket(socket): def _real_close(self): self._sslobj = None # self._closed = True - if self._base: - self._base.close() - socket.close(self) + socket._real_close(self) def do_handshake(self, block=False): """Perform a TLS/SSL handshake.""" @@ -329,8 +324,12 @@ class SSLSocket(socket): self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile, self.cert_reqs, self.ssl_version, self.ca_certs) - if self.do_handshake_on_connect: - self.do_handshake() + try: + if self.do_handshake_on_connect: + self.do_handshake() + except: + self._sslobj = None + raise def accept(self): """Accepts a new connection from a remote client, and returns @@ -348,10 +347,11 @@ class SSLSocket(socket): self.do_handshake_on_connect), addr) - def __del__(self): + # sys.stderr.write("__del__ on %s\n" % repr(self)) self._real_close() + def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 18df3f4422d..81943a5c99c 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -13,6 +13,7 @@ import pprint import urllib, urlparse import shutil import traceback +import asyncore from BaseHTTPServer import HTTPServer from SimpleHTTPServer import SimpleHTTPRequestHandler @@ -79,27 +80,6 @@ class BasicTests(unittest.TestCase): class NetworkedTests(unittest.TestCase): - def testFetchServerCert(self): - - pem = ssl.get_server_certificate(("svn.python.org", 443)) - if not pem: - raise test_support.TestFailed("No server certificate on svn.python.org:443!") - - try: - pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) - except ssl.SSLError as x: - #should fail - if test_support.verbose: - sys.stdout.write("%s\n" % x) - else: - raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem) - - pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) - if not pem: - raise test_support.TestFailed("No server certificate on svn.python.org:443!") - if test_support.verbose: - sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) - def testConnect(self): s = ssl.wrap_socket(socket.socket(socket.AF_INET), @@ -155,6 +135,29 @@ class NetworkedTests(unittest.TestCase): if test_support.verbose: sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) + def testFetchServerCert(self): + + pem = ssl.get_server_certificate(("svn.python.org", 443)) + if not pem: + raise test_support.TestFailed("No server certificate on svn.python.org:443!") + + return + + try: + pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) + except ssl.SSLError as x: + #should fail + if test_support.verbose: + sys.stdout.write("%s\n" % x) + else: + raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem) + + pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) + if not pem: + raise test_support.TestFailed("No server certificate on svn.python.org:443!") + if test_support.verbose: + sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) + try: import threading @@ -333,7 +336,9 @@ else: def stop (self): self.active = False - class AsyncoreHTTPSServer(threading.Thread): + class OurHTTPSServer(threading.Thread): + + # This one's based on HTTPServer, which is based on SocketServer class HTTPSServer(HTTPServer): @@ -463,6 +468,92 @@ else: self.server.server_close() + class AsyncoreEchoServer(threading.Thread): + + # this one's based on asyncore.dispatcher + + class EchoServer (asyncore.dispatcher): + + class ConnectionHandler (asyncore.dispatcher_with_send): + + def __init__(self, conn, certfile): + self.socket = ssl.wrap_socket(conn, server_side=True, + certfile=certfile, + do_handshake_on_connect=False) + asyncore.dispatcher_with_send.__init__(self, self.socket) + # now we have to do the handshake + # we'll just do it the easy way, and block the connection + # till it's finished. If we were doing it right, we'd + # do this in multiple calls to handle_read... + self.do_handshake(block=True) + + def readable(self): + if isinstance(self.socket, ssl.SSLSocket): + while self.socket.pending() > 0: + self.handle_read_event() + return True + + def handle_read(self): + data = self.recv(1024) + if test_support.verbose: + sys.stdout.write(" server: read %s from client\n" % repr(data)) + if not data: + self.close() + else: + self.send(str(data, 'ASCII', 'strict').lower().encode('ASCII', 'strict')) + + def handle_close(self): + if test_support.verbose: + sys.stdout.write(" server: closed connection %s\n" % self.socket) + + def handle_error(self): + raise + + def __init__(self, port, certfile): + self.port = port + self.certfile = certfile + asyncore.dispatcher.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.bind(('', port)) + self.listen(5) + + def handle_accept(self): + sock_obj, addr = self.accept() + if test_support.verbose: + sys.stdout.write(" server: new connection from %s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self.certfile) + + def handle_error(self): + raise + + def __init__(self, port, certfile): + self.flag = None + self.active = False + self.server = self.EchoServer(port, certfile) + threading.Thread.__init__(self) + self.setDaemon(True) + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start (self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run (self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + try: + asyncore.loop(1) + except: + pass + + def stop (self): + self.active = False + self.server.close() + def badCertTest (certfile): server = ThreadedEchoServer(TESTPORT, CERTFILE, certreqs=ssl.CERT_REQUIRED, @@ -509,6 +600,7 @@ else: client_protocol = protocol try: s = ssl.wrap_socket(socket.socket(), + server_side=False, certfile=client_certfile, ca_certs=cacertsfile, cert_reqs=certreqs, @@ -811,11 +903,9 @@ else: server.stop() server.join() - class AsyncoreTests(unittest.TestCase): + def testSocketServer(self): - def testAsyncore(self): - - server = AsyncoreHTTPSServer(TESTPORT, CERTFILE) + server = OurHTTPSServer(TESTPORT, CERTFILE) flag = threading.Event() server.start(flag) # wait for it to start @@ -853,6 +943,47 @@ else: server.stop() server.join() + def testAsyncoreServer(self): + + if test_support.verbose: + sys.stdout.write("\n") + + indata="FOO\n" + server = AsyncoreEchoServer(TESTPORT, CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + s = ssl.wrap_socket(socket.socket()) + s.connect(('127.0.0.1', TESTPORT)) + except ssl.SSLError as x: + raise test_support.TestFailed("Unexpected SSL error: " + str(x)) + except Exception as x: + raise test_support.TestFailed("Unexpected exception: " + str(x)) + else: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(indata))) + s.sendall(indata.encode('ASCII', 'strict')) + outdata = s.recv() + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + outdata = str(outdata, 'ASCII', 'strict') + if outdata != indata.lower(): + raise test_support.TestFailed( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (repr(outdata[:min(len(outdata),20)]), len(outdata), + repr(indata[:min(len(indata),20)].lower()), len(indata))) + s.write("over\n".encode("ASCII", "strict")) + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + finally: + server.stop() + server.join() + def findtestsocket(start, end): def testbind(i): @@ -900,7 +1031,6 @@ def test_main(verbose=False): thread_info = test_support.threading_setup() if thread_info and test_support.is_resource_enabled('network'): tests.append(ThreadedTests) - tests.append(AsyncoreTests) test_support.run_unittest(*tests) diff --git a/Modules/_ssl.c b/Modules/_ssl.c index bd3f17262ae..7ab229716f3 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -46,6 +46,7 @@ enum py_ssl_error { PY_SSL_ERROR_WANT_CONNECT, /* start of non ssl.h errorcodes */ PY_SSL_ERROR_EOF, /* special case of SSL_ERROR_SYSCALL */ + PY_SSL_ERROR_NO_SOCKET, /* socket has been GC'd */ PY_SSL_ERROR_INVALID_ERROR_CODE }; @@ -111,7 +112,7 @@ static unsigned int _ssl_locks_count = 0; typedef struct { PyObject_HEAD - PySocketSockObject *Socket; /* Socket on which we're layered */ + PyObject *Socket; /* weakref to socket on which we're layered */ SSL_CTX* ctx; SSL* ssl; X509* peer_cert; @@ -188,13 +189,15 @@ PySSL_SetError(PySSLObject *obj, int ret, char *filename, int lineno) { unsigned long e = ERR_get_error(); if (e == 0) { - if (ret == 0 || !obj->Socket) { + PySocketSockObject *s + = (PySocketSockObject *) PyWeakref_GetObject(obj->Socket); + if (ret == 0 || (((PyObject *)s) == Py_None)) { p = PY_SSL_ERROR_EOF; errstr = "EOF occurred in violation of protocol"; } else if (ret == -1) { /* underlying BIO reported an I/O error */ - return obj->Socket->errorhandler(); + return s->errorhandler(); } else { /* possible? */ p = PY_SSL_ERROR_SYSCALL; errstr = "Some I/O error occurred"; @@ -383,8 +386,7 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file, SSL_set_accept_state(self->ssl); PySSL_END_ALLOW_THREADS - self->Socket = Sock; - Py_INCREF(self->Socket); + self->Socket = PyWeakref_NewRef((PyObject *) Sock, Py_None); return self; fail: if (errstr) @@ -442,6 +444,14 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self) /* XXX If SSL_do_handshake() returns 0, it's also a failure. */ sockstate = 0; do { + PySocketSockObject *sock + = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); + if (((PyObject*)sock) == Py_None) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } + PySSL_BEGIN_ALLOW_THREADS ret = SSL_do_handshake(self->ssl); err = SSL_get_error(self->ssl, ret); @@ -450,9 +460,9 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self) return NULL; } if (err == SSL_ERROR_WANT_READ) { - sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); + sockstate = check_socket_and_wait_for_timeout(sock, 0); } else if (err == SSL_ERROR_WANT_WRITE) { - sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); + sockstate = check_socket_and_wait_for_timeout(sock, 1); } else { sockstate = SOCKET_OPERATION_OK; } @@ -1140,16 +1150,24 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) int sockstate; int err; int nonblocking; + PySocketSockObject *sock + = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); + + if (((PyObject*)sock) == Py_None) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } if (!PyArg_ParseTuple(args, "y#:write", &data, &count)) return NULL; /* just in case the blocking state of the socket has been changed */ - nonblocking = (self->Socket->sock_timeout >= 0.0); + nonblocking = (sock->sock_timeout >= 0.0); BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); - sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); + sockstate = check_socket_and_wait_for_timeout(sock, 1); if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySSLErrorObject, "The write operation timed out"); @@ -1174,10 +1192,10 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) } if (err == SSL_ERROR_WANT_READ) { sockstate = - check_socket_and_wait_for_timeout(self->Socket, 0); + check_socket_and_wait_for_timeout(sock, 0); } else if (err == SSL_ERROR_WANT_WRITE) { sockstate = - check_socket_and_wait_for_timeout(self->Socket, 1); + check_socket_and_wait_for_timeout(sock, 1); } else { sockstate = SOCKET_OPERATION_OK; } @@ -1233,10 +1251,17 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) int sockstate; int err; int nonblocking; + PySocketSockObject *sock + = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); + + if (((PyObject*)sock) == Py_None) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } if (!PyArg_ParseTuple(args, "|Oi:read", &buf, &count)) return NULL; - if ((buf == NULL) || (buf == Py_None)) { if (!(buf = PyBytes_FromStringAndSize((char *) 0, len))) return NULL; @@ -1254,7 +1279,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) } /* just in case the blocking state of the socket has been changed */ - nonblocking = (self->Socket->sock_timeout >= 0.0); + nonblocking = (sock->sock_timeout >= 0.0); BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); @@ -1264,7 +1289,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) PySSL_END_ALLOW_THREADS if (!count) { - sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); + sockstate = check_socket_and_wait_for_timeout(sock, 0); if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySSLErrorObject, "The read operation timed out"); @@ -1299,10 +1324,10 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) } if (err == SSL_ERROR_WANT_READ) { sockstate = - check_socket_and_wait_for_timeout(self->Socket, 0); + check_socket_and_wait_for_timeout(sock, 0); } else if (err == SSL_ERROR_WANT_WRITE) { sockstate = - check_socket_and_wait_for_timeout(self->Socket, 1); + check_socket_and_wait_for_timeout(sock, 1); } else if ((err == SSL_ERROR_ZERO_RETURN) && (SSL_get_shutdown(self->ssl) == SSL_RECEIVED_SHUTDOWN))