update to fix leak in SSL code

This commit is contained in:
Bill Janssen 2007-12-14 22:08:56 +00:00
parent 517b9ddda2
commit 54cc54c1fe
4 changed files with 225 additions and 68 deletions

View File

@ -174,11 +174,13 @@ class socket(_socket.socket):
if self._closed:
self.close()
def _real_close(self):
_socket.socket.close(self)
def close(self):
self._closed = True
if self._io_refs <= 0:
_socket.socket.close(self)
self._real_close()
def fromfd(fd, family, type, proto=0):
""" fromfd(fd, family, type[, proto]) -> socket object

View File

@ -80,6 +80,7 @@ from socket import getnameinfo as _getnameinfo
from socket import error as socket_error
from socket import dup as _dup
import base64 # for DER-to-PEM translation
import traceback
class SSLSocket(socket):
@ -94,16 +95,13 @@ class SSLSocket(socket):
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True):
self._base = None
if sock is not None:
# copied this code from socket.accept()
fd = sock.fileno()
nfd = _dup(fd)
socket.__init__(self, family=sock.family, type=sock.type,
proto=sock.proto, fileno=nfd)
socket.__init__(self,
family=sock.family,
type=sock.type,
proto=sock.proto,
fileno=_dup(sock.fileno()))
sock.close()
sock = None
elif fileno is not None:
socket.__init__(self, fileno=fileno)
else:
@ -136,10 +134,6 @@ class SSLSocket(socket):
self.close()
raise x
if sock and (self.fileno() != sock.fileno()):
self._base = sock
else:
self._base = None
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
@ -156,19 +150,23 @@ class SSLSocket(socket):
# raise an exception here if you wish to check for spurious closes
pass
def read(self, len=None, buffer=None):
def read(self, len=0, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
self._checkClosed()
try:
if buffer:
return self._sslobj.read(buffer, len)
v = self._sslobj.read(buffer, len)
else:
return self._sslobj.read(len or 1024)
v = self._sslobj.read(len or 1024)
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return b''
if buffer:
return 0
else:
return b''
else:
raise
@ -269,7 +267,6 @@ class SSLSocket(socket):
while True:
try:
v = self.read(nbytes, buffer)
sys.stdout.flush()
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
@ -302,9 +299,7 @@ class SSLSocket(socket):
def _real_close(self):
self._sslobj = None
# self._closed = True
if self._base:
self._base.close()
socket.close(self)
socket._real_close(self)
def do_handshake(self, block=False):
"""Perform a TLS/SSL handshake."""
@ -329,8 +324,12 @@ class SSLSocket(socket):
self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
if self.do_handshake_on_connect:
self.do_handshake()
try:
if self.do_handshake_on_connect:
self.do_handshake()
except:
self._sslobj = None
raise
def accept(self):
"""Accepts a new connection from a remote client, and returns
@ -348,10 +347,11 @@ class SSLSocket(socket):
self.do_handshake_on_connect),
addr)
def __del__(self):
# sys.stderr.write("__del__ on %s\n" % repr(self))
self._real_close()
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,

View File

@ -13,6 +13,7 @@ import pprint
import urllib, urlparse
import shutil
import traceback
import asyncore
from BaseHTTPServer import HTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
@ -79,27 +80,6 @@ class BasicTests(unittest.TestCase):
class NetworkedTests(unittest.TestCase):
def testFetchServerCert(self):
pem = ssl.get_server_certificate(("svn.python.org", 443))
if not pem:
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
try:
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
except ssl.SSLError as x:
#should fail
if test_support.verbose:
sys.stdout.write("%s\n" % x)
else:
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
if not pem:
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
if test_support.verbose:
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
def testConnect(self):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
@ -155,6 +135,29 @@ class NetworkedTests(unittest.TestCase):
if test_support.verbose:
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
def testFetchServerCert(self):
pem = ssl.get_server_certificate(("svn.python.org", 443))
if not pem:
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
return
try:
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
except ssl.SSLError as x:
#should fail
if test_support.verbose:
sys.stdout.write("%s\n" % x)
else:
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
if not pem:
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
if test_support.verbose:
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
try:
import threading
@ -333,7 +336,9 @@ else:
def stop (self):
self.active = False
class AsyncoreHTTPSServer(threading.Thread):
class OurHTTPSServer(threading.Thread):
# This one's based on HTTPServer, which is based on SocketServer
class HTTPSServer(HTTPServer):
@ -463,6 +468,92 @@ else:
self.server.server_close()
class AsyncoreEchoServer(threading.Thread):
# this one's based on asyncore.dispatcher
class EchoServer (asyncore.dispatcher):
class ConnectionHandler (asyncore.dispatcher_with_send):
def __init__(self, conn, certfile):
self.socket = ssl.wrap_socket(conn, server_side=True,
certfile=certfile,
do_handshake_on_connect=False)
asyncore.dispatcher_with_send.__init__(self, self.socket)
# now we have to do the handshake
# we'll just do it the easy way, and block the connection
# till it's finished. If we were doing it right, we'd
# do this in multiple calls to handle_read...
self.do_handshake(block=True)
def readable(self):
if isinstance(self.socket, ssl.SSLSocket):
while self.socket.pending() > 0:
self.handle_read_event()
return True
def handle_read(self):
data = self.recv(1024)
if test_support.verbose:
sys.stdout.write(" server: read %s from client\n" % repr(data))
if not data:
self.close()
else:
self.send(str(data, 'ASCII', 'strict').lower().encode('ASCII', 'strict'))
def handle_close(self):
if test_support.verbose:
sys.stdout.write(" server: closed connection %s\n" % self.socket)
def handle_error(self):
raise
def __init__(self, port, certfile):
self.port = port
self.certfile = certfile
asyncore.dispatcher.__init__(self)
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.bind(('', port))
self.listen(5)
def handle_accept(self):
sock_obj, addr = self.accept()
if test_support.verbose:
sys.stdout.write(" server: new connection from %s:%s\n" %addr)
self.ConnectionHandler(sock_obj, self.certfile)
def handle_error(self):
raise
def __init__(self, port, certfile):
self.flag = None
self.active = False
self.server = self.EchoServer(port, certfile)
threading.Thread.__init__(self)
self.setDaemon(True)
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
def start (self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run (self):
self.active = True
if self.flag:
self.flag.set()
while self.active:
try:
asyncore.loop(1)
except:
pass
def stop (self):
self.active = False
self.server.close()
def badCertTest (certfile):
server = ThreadedEchoServer(TESTPORT, CERTFILE,
certreqs=ssl.CERT_REQUIRED,
@ -509,6 +600,7 @@ else:
client_protocol = protocol
try:
s = ssl.wrap_socket(socket.socket(),
server_side=False,
certfile=client_certfile,
ca_certs=cacertsfile,
cert_reqs=certreqs,
@ -811,11 +903,9 @@ else:
server.stop()
server.join()
class AsyncoreTests(unittest.TestCase):
def testSocketServer(self):
def testAsyncore(self):
server = AsyncoreHTTPSServer(TESTPORT, CERTFILE)
server = OurHTTPSServer(TESTPORT, CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
@ -853,6 +943,47 @@ else:
server.stop()
server.join()
def testAsyncoreServer(self):
if test_support.verbose:
sys.stdout.write("\n")
indata="FOO\n"
server = AsyncoreEchoServer(TESTPORT, CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', TESTPORT))
except ssl.SSLError as x:
raise test_support.TestFailed("Unexpected SSL error: " + str(x))
except Exception as x:
raise test_support.TestFailed("Unexpected exception: " + str(x))
else:
if test_support.verbose:
sys.stdout.write(
" client: sending %s...\n" % (repr(indata)))
s.sendall(indata.encode('ASCII', 'strict'))
outdata = s.recv()
if test_support.verbose:
sys.stdout.write(" client: read %s\n" % repr(outdata))
outdata = str(outdata, 'ASCII', 'strict')
if outdata != indata.lower():
raise test_support.TestFailed(
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (repr(outdata[:min(len(outdata),20)]), len(outdata),
repr(indata[:min(len(indata),20)].lower()), len(indata)))
s.write("over\n".encode("ASCII", "strict"))
if test_support.verbose:
sys.stdout.write(" client: closing connection.\n")
s.close()
finally:
server.stop()
server.join()
def findtestsocket(start, end):
def testbind(i):
@ -900,7 +1031,6 @@ def test_main(verbose=False):
thread_info = test_support.threading_setup()
if thread_info and test_support.is_resource_enabled('network'):
tests.append(ThreadedTests)
tests.append(AsyncoreTests)
test_support.run_unittest(*tests)

View File

@ -46,6 +46,7 @@ enum py_ssl_error {
PY_SSL_ERROR_WANT_CONNECT,
/* start of non ssl.h errorcodes */
PY_SSL_ERROR_EOF, /* special case of SSL_ERROR_SYSCALL */
PY_SSL_ERROR_NO_SOCKET, /* socket has been GC'd */
PY_SSL_ERROR_INVALID_ERROR_CODE
};
@ -111,7 +112,7 @@ static unsigned int _ssl_locks_count = 0;
typedef struct {
PyObject_HEAD
PySocketSockObject *Socket; /* Socket on which we're layered */
PyObject *Socket; /* weakref to socket on which we're layered */
SSL_CTX* ctx;
SSL* ssl;
X509* peer_cert;
@ -188,13 +189,15 @@ PySSL_SetError(PySSLObject *obj, int ret, char *filename, int lineno)
{
unsigned long e = ERR_get_error();
if (e == 0) {
if (ret == 0 || !obj->Socket) {
PySocketSockObject *s
= (PySocketSockObject *) PyWeakref_GetObject(obj->Socket);
if (ret == 0 || (((PyObject *)s) == Py_None)) {
p = PY_SSL_ERROR_EOF;
errstr =
"EOF occurred in violation of protocol";
} else if (ret == -1) {
/* underlying BIO reported an I/O error */
return obj->Socket->errorhandler();
return s->errorhandler();
} else { /* possible? */
p = PY_SSL_ERROR_SYSCALL;
errstr = "Some I/O error occurred";
@ -383,8 +386,7 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
SSL_set_accept_state(self->ssl);
PySSL_END_ALLOW_THREADS
self->Socket = Sock;
Py_INCREF(self->Socket);
self->Socket = PyWeakref_NewRef((PyObject *) Sock, Py_None);
return self;
fail:
if (errstr)
@ -442,6 +444,14 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
sockstate = 0;
do {
PySocketSockObject *sock
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
if (((PyObject*)sock) == Py_None) {
_setSSLError("Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
PySSL_BEGIN_ALLOW_THREADS
ret = SSL_do_handshake(self->ssl);
err = SSL_get_error(self->ssl, ret);
@ -450,9 +460,9 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
return NULL;
}
if (err == SSL_ERROR_WANT_READ) {
sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
sockstate = check_socket_and_wait_for_timeout(sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
sockstate = check_socket_and_wait_for_timeout(sock, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
@ -1140,16 +1150,24 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
int sockstate;
int err;
int nonblocking;
PySocketSockObject *sock
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
if (((PyObject*)sock) == Py_None) {
_setSSLError("Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
if (!PyArg_ParseTuple(args, "y#:write", &data, &count))
return NULL;
/* just in case the blocking state of the socket has been changed */
nonblocking = (self->Socket->sock_timeout >= 0.0);
nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
sockstate = check_socket_and_wait_for_timeout(sock, 1);
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
"The write operation timed out");
@ -1174,10 +1192,10 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
}
if (err == SSL_ERROR_WANT_READ) {
sockstate =
check_socket_and_wait_for_timeout(self->Socket, 0);
check_socket_and_wait_for_timeout(sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate =
check_socket_and_wait_for_timeout(self->Socket, 1);
check_socket_and_wait_for_timeout(sock, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
@ -1233,10 +1251,17 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
int sockstate;
int err;
int nonblocking;
PySocketSockObject *sock
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
if (((PyObject*)sock) == Py_None) {
_setSSLError("Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
if (!PyArg_ParseTuple(args, "|Oi:read", &buf, &count))
return NULL;
if ((buf == NULL) || (buf == Py_None)) {
if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
return NULL;
@ -1254,7 +1279,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
}
/* just in case the blocking state of the socket has been changed */
nonblocking = (self->Socket->sock_timeout >= 0.0);
nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
@ -1264,7 +1289,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
PySSL_END_ALLOW_THREADS
if (!count) {
sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
sockstate = check_socket_and_wait_for_timeout(sock, 0);
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
"The read operation timed out");
@ -1299,10 +1324,10 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
}
if (err == SSL_ERROR_WANT_READ) {
sockstate =
check_socket_and_wait_for_timeout(self->Socket, 0);
check_socket_and_wait_for_timeout(sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate =
check_socket_and_wait_for_timeout(self->Socket, 1);
check_socket_and_wait_for_timeout(sock, 1);
} else if ((err == SSL_ERROR_ZERO_RETURN) &&
(SSL_get_shutdown(self->ssl) ==
SSL_RECEIVED_SHUTDOWN))