From 934b16d0c2c4dcaa15051e4e7d61543f9f64fa82 Mon Sep 17 00:00:00 2001 From: Bill Janssen Date: Sat, 28 Jun 2008 22:19:33 +0000 Subject: [PATCH] various SSL fixes; issues 1251, 3162, 3212 --- Doc/library/ssl.rst | 34 +++- Lib/ssl.py | 361 ++++++++++++----------------------------- Lib/test/test_ssl.py | 246 ++++++++++++++++++++++++---- Lib/test/wrongcert.pem | 32 ++++ Modules/_ssl.c | 203 ++++++++++++++++------- 5 files changed, 528 insertions(+), 348 deletions(-) create mode 100644 Lib/test/wrongcert.pem diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst index fb41091cae6..a41c6ea6217 100644 --- a/Doc/library/ssl.rst +++ b/Doc/library/ssl.rst @@ -54,7 +54,7 @@ Functions, Constants, and Exceptions network connection. This error is a subtype of :exc:`socket.error`, which in turn is a subtype of :exc:`IOError`. -.. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None) +.. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None, do_handshake_on_connect=True, suppress_ragged_eofs=True) Takes an instance ``sock`` of :class:`socket.socket`, and returns an instance of :class:`ssl.SSLSocket`, a subtype of :class:`socket.socket`, which wraps the underlying socket in an SSL context. @@ -122,6 +122,18 @@ Functions, Constants, and Exceptions In some older versions of OpenSSL (for instance, 0.9.7l on OS X 10.4), an SSLv2 client could not connect to an SSLv23 server. + The parameter ``do_handshake_on_connect`` specifies whether to do the SSL + handshake automatically after doing a :meth:`socket.connect`, or whether the + application program will call it explicitly, by invoking the :meth:`SSLSocket.do_handshake` + method. Calling :meth:`SSLSocket.do_handshake` explicitly gives the program control over + the blocking behavior of the socket I/O involved in the handshake. + + The parameter ``suppress_ragged_eofs`` specifies how the :meth:`SSLSocket.read` + method should signal unexpected EOF from the other end of the connection. If specified + as :const:`True` (the default), it returns a normal EOF in response to unexpected + EOF errors raised from the underlying socket; if :const:`False`, it will raise + the exceptions back to the caller. + .. function:: RAND_status() Returns True if the SSL pseudo-random number generator has been @@ -290,6 +302,25 @@ SSLSocket Objects number of secret bits being used. If no connection has been established, returns ``None``. +.. method:: SSLSocket.do_handshake() + + Perform a TLS/SSL handshake. If this is used with a non-blocking socket, + it may raise :exc:`SSLError` with an ``arg[0]`` of :const:`SSL_ERROR_WANT_READ` + or :const:`SSL_ERROR_WANT_WRITE`, in which case it must be called again until it + completes successfully. For example, to simulate the behavior of a blocking socket, + one might write:: + + while True: + try: + s.do_handshake() + break + except ssl.SSLError, err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], []) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], []) + else: + raise .. index:: single: certificates @@ -367,6 +398,7 @@ certificate, you need to provide a "CA certs" file, filled with the certificate chains for each issuer you are willing to trust. Again, this file just contains these chains concatenated together. For validation, Python will use the first chain it finds in the file which matches. + Some "standard" root certificates are available from various certification authorities: `CACert.org `_, diff --git a/Lib/ssl.py b/Lib/ssl.py index 24502e44e0d..e45e16bb14a 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -74,7 +74,7 @@ from _ssl import \ SSL_ERROR_EOF, \ SSL_ERROR_INVALID_ERROR_CODE -from socket import socket +from socket import socket, _fileobject from socket import getnameinfo as _getnameinfo import base64 # for DER-to-PEM translation @@ -86,8 +86,16 @@ class SSLSocket (socket): def __init__(self, sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None): + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True): socket.__init__(self, _sock=sock._sock) + # the initializer for socket trashes the methods (tsk, tsk), so... + self.send = lambda x, flags=0: SSLSocket.send(self, x, flags) + self.recv = lambda x, flags=0: SSLSocket.recv(self, x, flags) + self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags) + self.recvfrom = lambda addr, buflen, flags: SSLSocket.recvfrom(self, addr, buflen, flags) + if certfile and not keyfile: keyfile = certfile # see if it's connected @@ -101,18 +109,34 @@ class SSLSocket (socket): self._sslobj = _ssl.sslwrap(self._sock, server_side, keyfile, certfile, cert_reqs, ssl_version, ca_certs) + if do_handshake_on_connect: + timeout = self.gettimeout() + try: + self.settimeout(None) + self.do_handshake() + finally: + self.settimeout(timeout) self.keyfile = keyfile self.certfile = certfile self.cert_reqs = cert_reqs self.ssl_version = ssl_version self.ca_certs = ca_certs + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self._makefile_refs = 0 def read(self, len=1024): """Read up to LEN bytes and return them. Return zero-length string on EOF.""" - return self._sslobj.read(len) + try: + return self._sslobj.read(len) + except SSLError, x: + if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: + return '' + else: + raise def write(self, data): @@ -143,16 +167,27 @@ class SSLSocket (socket): raise ValueError( "non-zero flags not allowed in calls to send() on %s" % self.__class__) - return self._sslobj.write(data) + while True: + try: + v = self._sslobj.write(data) + except SSLError, x: + if x.args[0] == SSL_ERROR_WANT_READ: + return 0 + elif x.args[0] == SSL_ERROR_WANT_WRITE: + return 0 + else: + raise + else: + return v else: return socket.send(self, data, flags) - def send_to (self, data, addr, flags=0): + def sendto (self, data, addr, flags=0): if self._sslobj: - raise ValueError("send_to not allowed on instances of %s" % + raise ValueError("sendto not allowed on instances of %s" % self.__class__) else: - return socket.send_to(self, data, addr, flags) + return socket.sendto(self, data, addr, flags) def sendall (self, data, flags=0): if self._sslobj: @@ -160,7 +195,12 @@ class SSLSocket (socket): raise ValueError( "non-zero flags not allowed in calls to sendall() on %s" % self.__class__) - return self._sslobj.write(data) + amount = len(data) + count = 0 + while (count < amount): + v = self.send(data[count:]) + count += v + return amount else: return socket.sendall(self, data, flags) @@ -170,25 +210,51 @@ class SSLSocket (socket): raise ValueError( "non-zero flags not allowed in calls to sendall() on %s" % self.__class__) - return self._sslobj.read(data, buflen) + while True: + try: + return self.read(buflen) + except SSLError, x: + if x.args[0] == SSL_ERROR_WANT_READ: + continue + else: + raise x else: return socket.recv(self, buflen, flags) - def recv_from (self, addr, buflen=1024, flags=0): + def recvfrom (self, addr, buflen=1024, flags=0): if self._sslobj: - raise ValueError("recv_from not allowed on instances of %s" % + raise ValueError("recvfrom not allowed on instances of %s" % self.__class__) else: - return socket.recv_from(self, addr, buflen, flags) + return socket.recvfrom(self, addr, buflen, flags) - def shutdown(self, how): + def pending (self): + if self._sslobj: + return self._sslobj.pending() + else: + return 0 + + def shutdown (self, how): self._sslobj = None socket.shutdown(self, how) - def close(self): + def close (self): self._sslobj = None socket.close(self) + def close (self): + if self._makefile_refs < 1: + self._sslobj = None + socket.close(self) + else: + self._makefile_refs -= 1 + + def do_handshake (self): + + """Perform a TLS/SSL handshake.""" + + self._sslobj.do_handshake() + def connect(self, addr): """Connects to remote ADDR, and then wraps the connection in @@ -202,6 +268,8 @@ class SSLSocket (socket): self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, self.cert_reqs, self.ssl_version, self.ca_certs) + if self.do_handshake_on_connect: + self.do_handshake() def accept(self): @@ -210,260 +278,39 @@ class SSLSocket (socket): SSL channel, and the address of the remote client.""" newsock, addr = socket.accept(self) - return (SSLSocket(newsock, True, self.keyfile, self.certfile, - self.cert_reqs, self.ssl_version, - self.ca_certs), addr) - + return (SSLSocket(newsock, + keyfile=self.keyfile, + certfile=self.certfile, + server_side=True, + cert_reqs=self.cert_reqs, + ssl_version=self.ssl_version, + ca_certs=self.ca_certs, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs), + addr) def makefile(self, mode='r', bufsize=-1): """Ouch. Need to make and return a file-like object that works with the SSL connection.""" - if self._sslobj: - return SSLFileStream(self._sslobj, mode, bufsize) - else: - return socket.makefile(self, mode, bufsize) - - -class SSLFileStream: - - """A class to simulate a file stream on top of a socket. - Most of this is just lifted from the socket module, and - adjusted to work with an SSL stream instead of a socket.""" - - - default_bufsize = 8192 - name = "" - - __slots__ = ["mode", "bufsize", "softspace", - # "closed" is a property, see below - "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", - "_close", "_fileno"] - - def __init__(self, sslobj, mode='rb', bufsize=-1, close=False): - self._sslobj = sslobj - self.mode = mode # Not actually used in this version - if bufsize < 0: - bufsize = self.default_bufsize - self.bufsize = bufsize - self.softspace = False - if bufsize == 0: - self._rbufsize = 1 - elif bufsize == 1: - self._rbufsize = self.default_bufsize - else: - self._rbufsize = bufsize - self._wbufsize = bufsize - self._rbuf = "" # A string - self._wbuf = [] # A list of strings - self._close = close - self._fileno = -1 - - def _getclosed(self): - return self._sslobj is None - closed = property(_getclosed, doc="True if the file is closed") - - def fileno(self): - return self._fileno - - def close(self): - try: - if self._sslobj: - self.flush() - finally: - if self._close and self._sslobj: - self._sslobj.close() - self._sslobj = None - - def __del__(self): - try: - self.close() - except: - # close() may fail if __init__ didn't complete - pass - - def flush(self): - if self._wbuf: - buffer = "".join(self._wbuf) - self._wbuf = [] - count = 0 - while (count < len(buffer)): - written = self._sslobj.write(buffer) - count += written - buffer = buffer[written:] - - def write(self, data): - data = str(data) # XXX Should really reject non-string non-buffers - if not data: - return - self._wbuf.append(data) - if (self._wbufsize == 0 or - self._wbufsize == 1 and '\n' in data or - self._get_wbuf_len() >= self._wbufsize): - self.flush() - - def writelines(self, list): - # XXX We could do better here for very long lists - # XXX Should really reject non-string non-buffers - self._wbuf.extend(filter(None, map(str, list))) - if (self._wbufsize <= 1 or - self._get_wbuf_len() >= self._wbufsize): - self.flush() - - def _get_wbuf_len(self): - buf_len = 0 - for x in self._wbuf: - buf_len += len(x) - return buf_len - - def read(self, size=-1): - data = self._rbuf - if size < 0: - # Read until EOF - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - if self._rbufsize <= 1: - recv_size = self.default_bufsize - else: - recv_size = self._rbufsize - while True: - data = self._sslobj.read(recv_size) - if not data: - break - buffers.append(data) - return "".join(buffers) - else: - # Read until size bytes or EOF seen, whichever comes first - buf_len = len(data) - if buf_len >= size: - self._rbuf = data[size:] - return data[:size] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - left = size - buf_len - recv_size = max(self._rbufsize, left) - data = self._sslobj.read(recv_size) - if not data: - break - buffers.append(data) - n = len(data) - if n >= left: - self._rbuf = data[left:] - buffers[-1] = data[:left] - break - buf_len += n - return "".join(buffers) - - def readline(self, size=-1): - data = self._rbuf - if size < 0: - # Read until \n or EOF, whichever comes first - if self._rbufsize <= 1: - # Speed up unbuffered case - assert data == "" - buffers = [] - while data != "\n": - data = self._sslobj.read(1) - if not data: - break - buffers.append(data) - return "".join(buffers) - nl = data.find('\n') - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - return data[:nl] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - data = self._sslobj.read(self._rbufsize) - if not data: - break - buffers.append(data) - nl = data.find('\n') - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - buffers[-1] = data[:nl] - break - return "".join(buffers) - else: - # Read until size bytes or \n or EOF seen, whichever comes first - nl = data.find('\n', 0, size) - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - return data[:nl] - buf_len = len(data) - if buf_len >= size: - self._rbuf = data[size:] - return data[:size] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - data = self._sslobj.read(self._rbufsize) - if not data: - break - buffers.append(data) - left = size - buf_len - nl = data.find('\n', 0, left) - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - buffers[-1] = data[:nl] - break - n = len(data) - if n >= left: - self._rbuf = data[left:] - buffers[-1] = data[:left] - break - buf_len += n - return "".join(buffers) - - def readlines(self, sizehint=0): - total = 0 - list = [] - while True: - line = self.readline() - if not line: - break - list.append(line) - total += len(line) - if sizehint and total >= sizehint: - break - return list - - # Iterator protocols - - def __iter__(self): - return self - - def next(self): - line = self.readline() - if not line: - raise StopIteration - return line - + self._makefile_refs += 1 + return _fileobject(self, mode, bufsize) def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None): + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True): return SSLSocket(sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, - ssl_version=ssl_version, ca_certs=ca_certs) + ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs) + # some utility functions @@ -549,5 +396,7 @@ def sslwrap_simple (sock, keyfile=None, certfile=None): for compability with Python 2.5 and earlier. Will disappear in Python 3.0.""" - return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE, - PROTOCOL_SSLv23, None) + ssl_sock = _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE, + PROTOCOL_SSLv23, None) + ssl_sock.do_handshake() + return ssl_sock diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index eb4d00ca5f3..d786154c809 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -3,7 +3,9 @@ import sys import unittest from test import test_support +import asyncore import socket +import select import errno import subprocess import time @@ -97,8 +99,7 @@ class BasicTests(unittest.TestCase): if (d1 != d2): raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed") - -class NetworkTests(unittest.TestCase): +class NetworkedTests(unittest.TestCase): def testConnect(self): s = ssl.wrap_socket(socket.socket(socket.AF_INET), @@ -130,6 +131,31 @@ class NetworkTests(unittest.TestCase): finally: s.close() + + def testNonBlockingHandshake(self): + s = socket.socket(socket.AF_INET) + s.connect(("svn.python.org", 443)) + s.setblocking(False) + s = ssl.wrap_socket(s, + cert_reqs=ssl.CERT_NONE, + do_handshake_on_connect=False) + count = 0 + while True: + try: + count += 1 + s.do_handshake() + break + except ssl.SSLError, err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], []) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], []) + else: + raise + s.close() + if test_support.verbose: + sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) + def testFetchServerCert(self): pem = ssl.get_server_certificate(("svn.python.org", 443)) @@ -176,6 +202,18 @@ else: threading.Thread.__init__(self) self.setDaemon(True) + def show_conn_details(self): + if self.server.certreqs == ssl.CERT_REQUIRED: + cert = self.sslconn.getpeercert() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") + cert_binary = self.sslconn.getpeercert(True) + if test_support.verbose and self.server.chatty: + sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") + cipher = self.sslconn.cipher() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") + def wrap_conn (self): try: self.sslconn = ssl.wrap_socket(self.sock, server_side=True, @@ -187,6 +225,7 @@ else: if self.server.chatty: handle_error("\n server: bad connection attempt from " + str(self.sock.getpeername()) + ":\n") + self.close() if not self.server.expect_bad_connects: # here, we want to stop the server, because this shouldn't # happen in the context of our test case @@ -197,16 +236,6 @@ else: return False else: - if self.server.certreqs == ssl.CERT_REQUIRED: - cert = self.sslconn.getpeercert() - if test_support.verbose and self.server.chatty: - sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") - cert_binary = self.sslconn.getpeercert(True) - if test_support.verbose and self.server.chatty: - sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") - cipher = self.sslconn.cipher() - if test_support.verbose and self.server.chatty: - sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") return True def read(self): @@ -225,13 +254,16 @@ else: if self.sslconn: self.sslconn.close() else: - self.sock.close() + self.sock._sock.close() def run (self): self.running = True if not self.server.starttls_server: - if not self.wrap_conn(): + if isinstance(self.sock, ssl.SSLSocket): + self.sslconn = self.sock + elif not self.wrap_conn(): return + self.show_conn_details() while self.running: try: msg = self.read() @@ -270,7 +302,9 @@ else: def __init__(self, certificate, ssl_version=None, certreqs=None, cacerts=None, expect_bad_connects=False, - chatty=True, connectionchatty=False, starttls_server=False): + chatty=True, connectionchatty=False, starttls_server=False, + wrap_accepting_socket=False): + if ssl_version is None: ssl_version = ssl.PROTOCOL_TLSv1 if certreqs is None: @@ -284,8 +318,16 @@ else: self.connectionchatty = connectionchatty self.starttls_server = starttls_server self.sock = socket.socket() - self.port = test_support.bind_port(self.sock) self.flag = None + if wrap_accepting_socket: + self.sock = ssl.wrap_socket(self.sock, server_side=True, + certfile=self.certificate, + cert_reqs = self.certreqs, + ca_certs = self.cacerts, + ssl_version = self.protocol) + if test_support.verbose and self.chatty: + sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock)) + self.port = test_support.bind_port(self.sock) self.active = False threading.Thread.__init__(self) self.setDaemon(False) @@ -316,13 +358,86 @@ else: except: if self.chatty: handle_error("Test server failure:\n") + self.sock.close() def stop (self): self.active = False - self.sock.close() + class AsyncoreEchoServer(threading.Thread): - class AsyncoreHTTPSServer(threading.Thread): + class EchoServer (asyncore.dispatcher): + + class ConnectionHandler (asyncore.dispatcher_with_send): + + def __init__(self, conn, certfile): + asyncore.dispatcher_with_send.__init__(self, conn) + self.socket = ssl.wrap_socket(conn, server_side=True, + certfile=certfile, + do_handshake_on_connect=True) + + def readable(self): + if isinstance(self.socket, ssl.SSLSocket): + while self.socket.pending() > 0: + self.handle_read_event() + return True + + def handle_read(self): + data = self.recv(1024) + self.send(data.lower()) + + def handle_close(self): + if test_support.verbose: + sys.stdout.write(" server: closed connection %s\n" % self.socket) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.certfile = certfile + asyncore.dispatcher.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = test_support.bind_port(self.socket) + self.listen(5) + + def handle_accept(self): + sock_obj, addr = self.accept() + if test_support.verbose: + sys.stdout.write(" server: new connection from %s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self.certfile) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.flag = None + self.active = False + self.server = self.EchoServer(certfile) + self.port = self.server.port + threading.Thread.__init__(self) + self.setDaemon(True) + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start (self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run (self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + try: + asyncore.loop(1) + except: + pass + + def stop (self): + self.active = False + self.server.close() + + class SocketServerHTTPSServer(threading.Thread): class HTTPSServer(HTTPServer): @@ -335,6 +450,12 @@ else: self.active_lock = threading.Lock() self.allow_reuse_address = True + def __str__(self): + return ('<%s %s:%s>' % + (self.__class__.__name__, + self.server_name, + self.server_port)) + def get_request (self): # override this to wrap socket with SSL sock, addr = self.socket.accept() @@ -421,8 +542,8 @@ else: # we override this to suppress logging unless "verbose" if test_support.verbose: - sys.stdout.write(" server (%s, %d, %s):\n [%s] %s\n" % - (self.server.server_name, + sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" % + (self.server.server_address, self.server.server_port, self.request.cipher(), self.log_date_time_string(), @@ -440,9 +561,7 @@ else: self.setDaemon(True) def __str__(self): - return '<%s %s:%d>' % (self.__class__.__name__, - self.server.server_name, - self.server.server_port) + return "<%s %s>" % (self.__class__.__name__, self.server) def start (self, flag=None): self.flag = flag @@ -487,14 +606,16 @@ else: def serverParamsTest (certfile, protocol, certreqs, cacertsfile, client_certfile, client_protocol=None, indata="FOO\n", - chatty=True, connectionchatty=False): + chatty=True, connectionchatty=False, + wrap_accepting_socket=False): server = ThreadedEchoServer(certfile, certreqs=certreqs, ssl_version=protocol, cacerts=cacertsfile, chatty=chatty, - connectionchatty=connectionchatty) + connectionchatty=connectionchatty, + wrap_accepting_socket=wrap_accepting_socket) flag = threading.Event() server.start(flag) # wait for it to start @@ -572,7 +693,7 @@ else: ssl.get_protocol_name(server_protocol))) - class ConnectedTests(unittest.TestCase): + class ThreadedTests(unittest.TestCase): def testRudeShutdown(self): @@ -600,7 +721,7 @@ else: listener_gone.wait() try: ssl_sock = ssl.wrap_socket(s) - except socket.sslerror: + except IOError: pass else: raise test_support.TestFailed( @@ -680,6 +801,9 @@ else: def testMalformedCert(self): badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, "badcert.pem")) + def testWrongCert(self): + badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, + "wrongcert.pem")) def testMalformedKey(self): badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, "badkey.pem")) @@ -796,9 +920,9 @@ else: server.stop() server.join() - def testAsyncore(self): + def testSocketServer(self): - server = AsyncoreHTTPSServer(CERTFILE) + server = SocketServerHTTPSServer(CERTFILE) flag = threading.Event() server.start(flag) # wait for it to start @@ -810,8 +934,8 @@ else: d1 = open(CERTFILE, 'rb').read() d2 = '' # now fetch the same data from the HTTPS server - url = 'https://%s:%d/%s' % ( - HOST, server.port, os.path.split(CERTFILE)[1]) + url = 'https://127.0.0.1:%d/%s' % ( + server.port, os.path.split(CERTFILE)[1]) f = urllib.urlopen(url) dlen = f.info().getheader("content-length") if dlen and (int(dlen) > 0): @@ -834,6 +958,58 @@ else: server.stop() server.join() + def testWrappedAccept (self): + + if test_support.verbose: + sys.stdout.write("\n") + serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED, + CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23, + chatty=True, connectionchatty=True, + wrap_accepting_socket=True) + + + def testAsyncoreServer (self): + + indata = "TEST MESSAGE of mixed case\n" + + if test_support.verbose: + sys.stdout.write("\n") + server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + try: + s = ssl.wrap_socket(socket.socket()) + s.connect(('127.0.0.1', server.port)) + except ssl.SSLError, x: + raise test_support.TestFailed("Unexpected SSL error: " + str(x)) + except Exception, x: + raise test_support.TestFailed("Unexpected exception: " + str(x)) + else: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise test_support.TestFailed( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + finally: + server.stop() + # wait for server thread to end + server.join() + def test_main(verbose=False): if skip_expected: @@ -850,15 +1026,19 @@ def test_main(verbose=False): not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)): raise test_support.TestFailed("Can't read certificate files!") + TESTPORT = test_support.find_unused_port() + if not TESTPORT: + raise test_support.TestFailed("Can't find open port to test servers on!") + tests = [BasicTests] if test_support.is_resource_enabled('network'): - tests.append(NetworkTests) + tests.append(NetworkedTests) if _have_threads: thread_info = test_support.threading_setup() if thread_info and test_support.is_resource_enabled('network'): - tests.append(ConnectedTests) + tests.append(ThreadedTests) test_support.run_unittest(*tests) diff --git a/Lib/test/wrongcert.pem b/Lib/test/wrongcert.pem new file mode 100644 index 00000000000..5f92f9bce76 --- /dev/null +++ b/Lib/test/wrongcert.pem @@ -0,0 +1,32 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnH +FlbsVUg2Xtk6+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6T +f9lnNTwpSoeK24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQAB +AoGAQFko4uyCgzfxr4Ezb4Mp5pN3Npqny5+Jey3r8EjSAX9Ogn+CNYgoBcdtFgbq +1yif/0sK7ohGBJU9FUCAwrqNBI9ZHB6rcy7dx+gULOmRBGckln1o5S1+smVdmOsW +7zUVLBVByKuNWqTYFlzfVd6s4iiXtAE2iHn3GCyYdlICwrECQQDhMQVxHd3EFbzg +SFmJBTARlZ2GKA3c1g/h9/XbkEPQ9/RwI3vnjJ2RaSnjlfoLl8TOcf0uOGbOEyFe +19RvCLXjAkEA1s+UE5ziF+YVkW3WolDCQ2kQ5WG9+ccfNebfh6b67B7Ln5iG0Sbg +ky9cjsO3jbMJQtlzAQnH1850oRD5Gi51dQJAIbHCDLDZU9Ok1TI+I2BhVuA6F666 +lEZ7TeZaJSYq34OaUYUdrwG9OdqwZ9sy9LUav4ESzu2lhEQchCJrKMn23QJAReqs +ZLHUeTjfXkVk7dHhWPWSlUZ6AhmIlA/AQ7Payg2/8wM/JkZEJEPvGVykms9iPUrv +frADRr+hAGe43IewnQJBAJWKZllPgKuEBPwoEldHNS8nRu61D7HzxEzQ2xnfj+Nk +2fgf1MAzzTRsikfGENhVsVWeqOcijWb6g5gsyCmlRpc= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICsDCCAhmgAwIBAgIJAOqYOYFJfEEoMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMDgwNjI2MTgxNTUyWhcNMDkwNjI2MTgxNTUyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB +gQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnHFlbsVUg2Xtk6 ++bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6Tf9lnNTwpSoeK +24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQABo4GnMIGkMB0G +A1UdDgQWBBTctMtI3EO9OjLI0x9Zo2ifkwIiNjB1BgNVHSMEbjBsgBTctMtI3EO9 +OjLI0x9Zo2ifkwIiNqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt +U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOqYOYFJ +fEEoMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAQwa7jya/DfhaDn7E +usPkpgIX8WCL2B1SqnRTXEZfBPPVq/cUmFGyEVRVATySRuMwi8PXbVcOhXXuocA+ +43W+iIsD9pXapCZhhOerCq18TC1dWK98vLUsoK8PMjB6e5H/O8bqojv0EeC+fyCw +eSHj5jpC8iZKjCHBn+mAi4cQ514= +-----END CERTIFICATE----- diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 3f167b32caf..8fe72a5e450 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -2,14 +2,15 @@ SSL support based on patches by Brian E Gallew and Laszlo Kovacs. Re-worked a bit by Bill Janssen to add server-side support and - certificate decoding. + certificate decoding. Chris Stawarz contributed some non-blocking + patches. This module is imported by ssl.py. It should *not* be used directly. XXX should partial writes be enabled, SSL_MODE_ENABLE_PARTIAL_WRITE? - XXX what about SSL_MODE_AUTO_RETRY + XXX what about SSL_MODE_AUTO_RETRY? */ #include "Python.h" @@ -265,8 +266,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file, PySSLObject *self; char *errstr = NULL; int ret; - int err; - int sockstate; int verification_mode; self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */ @@ -388,57 +387,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file, SSL_set_accept_state(self->ssl); PySSL_END_ALLOW_THREADS - /* Actually negotiate SSL connection */ - /* XXX If SSL_connect() returns 0, it's also a failure. */ - sockstate = 0; - do { - PySSL_BEGIN_ALLOW_THREADS - if (socket_type == PY_SSL_CLIENT) - ret = SSL_connect(self->ssl); - else - ret = SSL_accept(self->ssl); - err = SSL_get_error(self->ssl, ret); - PySSL_END_ALLOW_THREADS - if(PyErr_CheckSignals()) { - goto fail; - } - if (err == SSL_ERROR_WANT_READ) { - sockstate = check_socket_and_wait_for_timeout(Sock, 0); - } else if (err == SSL_ERROR_WANT_WRITE) { - sockstate = check_socket_and_wait_for_timeout(Sock, 1); - } else { - sockstate = SOCKET_OPERATION_OK; - } - if (sockstate == SOCKET_HAS_TIMED_OUT) { - PyErr_SetString(PySSLErrorObject, - ERRSTR("The connect operation timed out")); - goto fail; - } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { - PyErr_SetString(PySSLErrorObject, - ERRSTR("Underlying socket has been closed.")); - goto fail; - } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { - PyErr_SetString(PySSLErrorObject, - ERRSTR("Underlying socket too large for select().")); - goto fail; - } else if (sockstate == SOCKET_IS_NONBLOCKING) { - break; - } - } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); - if (ret < 1) { - PySSL_SetError(self, ret, __FILE__, __LINE__); - goto fail; - } - self->ssl->debug = 1; - - PySSL_BEGIN_ALLOW_THREADS - if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) { - X509_NAME_oneline(X509_get_subject_name(self->peer_cert), - self->server, X509_NAME_MAXLEN); - X509_NAME_oneline(X509_get_issuer_name(self->peer_cert), - self->issuer, X509_NAME_MAXLEN); - } - PySSL_END_ALLOW_THREADS self->Socket = Sock; Py_INCREF(self->Socket); return self; @@ -488,6 +436,65 @@ PyDoc_STRVAR(ssl_doc, /* SSL object methods */ +static PyObject *PySSL_SSLdo_handshake(PySSLObject *self) +{ + int ret; + int err; + int sockstate; + + /* Actually negotiate SSL connection */ + /* XXX If SSL_do_handshake() returns 0, it's also a failure. */ + sockstate = 0; + do { + PySSL_BEGIN_ALLOW_THREADS + ret = SSL_do_handshake(self->ssl); + err = SSL_get_error(self->ssl, ret); + PySSL_END_ALLOW_THREADS + if(PyErr_CheckSignals()) { + return NULL; + } + if (err == SSL_ERROR_WANT_READ) { + sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); + } else if (err == SSL_ERROR_WANT_WRITE) { + sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); + } else { + sockstate = SOCKET_OPERATION_OK; + } + if (sockstate == SOCKET_HAS_TIMED_OUT) { + PyErr_SetString(PySSLErrorObject, + ERRSTR("The handshake operation timed out")); + return NULL; + } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { + PyErr_SetString(PySSLErrorObject, + ERRSTR("Underlying socket has been closed.")); + return NULL; + } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { + PyErr_SetString(PySSLErrorObject, + ERRSTR("Underlying socket too large for select().")); + return NULL; + } else if (sockstate == SOCKET_IS_NONBLOCKING) { + break; + } + } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); + if (ret < 1) + return PySSL_SetError(self, ret, __FILE__, __LINE__); + self->ssl->debug = 1; + + if (self->peer_cert) + X509_free (self->peer_cert); + PySSL_BEGIN_ALLOW_THREADS + if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) { + X509_NAME_oneline(X509_get_subject_name(self->peer_cert), + self->server, X509_NAME_MAXLEN); + X509_NAME_oneline(X509_get_issuer_name(self->peer_cert), + self->issuer, X509_NAME_MAXLEN); + } + PySSL_END_ALLOW_THREADS + + Py_INCREF(Py_None); + return Py_None; +} + static PyObject * PySSL_server(PySSLObject *self) { @@ -1127,7 +1134,9 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing) rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv); PySSL_END_ALLOW_THREADS +#ifdef HAVE_POLL normal_return: +#endif /* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise (when we are able to write or when there's something to read) */ return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK; @@ -1140,10 +1149,16 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) int count; int sockstate; int err; + int nonblocking; if (!PyArg_ParseTuple(args, "s#:write", &data, &count)) return NULL; + /* just in case the blocking state of the socket has been changed */ + nonblocking = (self->Socket->sock_timeout >= 0.0); + BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); + BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySSLErrorObject, @@ -1200,6 +1215,25 @@ PyDoc_STRVAR(PySSL_SSLwrite_doc, Writes the string s into the SSL object. Returns the number\n\ of bytes written."); +static PyObject *PySSL_SSLpending(PySSLObject *self) +{ + int count = 0; + + PySSL_BEGIN_ALLOW_THREADS + count = SSL_pending(self->ssl); + PySSL_END_ALLOW_THREADS + if (count < 0) + return PySSL_SetError(self, count, __FILE__, __LINE__); + else + return PyInt_FromLong(count); +} + +PyDoc_STRVAR(PySSL_SSLpending_doc, +"pending() -> count\n\ +\n\ +Returns the number of already decrypted bytes available for read,\n\ +pending on the connection.\n"); + static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) { PyObject *buf; @@ -1207,6 +1241,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) int len = 1024; int sockstate; int err; + int nonblocking; if (!PyArg_ParseTuple(args, "|i:read", &len)) return NULL; @@ -1214,6 +1249,11 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) if (!(buf = PyString_FromStringAndSize((char *) 0, len))) return NULL; + /* just in case the blocking state of the socket has been changed */ + nonblocking = (self->Socket->sock_timeout >= 0.0); + BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); + BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + /* first check if there are bytes ready to be read */ PySSL_BEGIN_ALLOW_THREADS count = SSL_pending(self->ssl); @@ -1232,9 +1272,18 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) Py_DECREF(buf); return NULL; } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { - /* should contain a zero-length string */ - _PyString_Resize(&buf, 0); - return buf; + if (SSL_get_shutdown(self->ssl) != + SSL_RECEIVED_SHUTDOWN) + { + Py_DECREF(buf); + PyErr_SetString(PySSLErrorObject, + "Socket closed without SSL shutdown handshake"); + return NULL; + } else { + /* should contain a zero-length string */ + _PyString_Resize(&buf, 0); + return buf; + } } } do { @@ -1285,16 +1334,54 @@ PyDoc_STRVAR(PySSL_SSLread_doc, \n\ Read up to len bytes from the SSL socket."); +static PyObject *PySSL_SSLshutdown(PySSLObject *self) +{ + int err; + + /* Guard against closed socket */ + if (self->Socket->sock_fd < 0) { + PyErr_SetString(PySSLErrorObject, + "Underlying socket has been closed."); + return NULL; + } + + PySSL_BEGIN_ALLOW_THREADS + err = SSL_shutdown(self->ssl); + if (err == 0) { + /* we need to call it again to finish the shutdown */ + err = SSL_shutdown(self->ssl); + } + PySSL_END_ALLOW_THREADS + + if (err < 0) + return PySSL_SetError(self, err, __FILE__, __LINE__); + else { + Py_INCREF(self->Socket); + return (PyObject *) (self->Socket); + } +} + +PyDoc_STRVAR(PySSL_SSLshutdown_doc, +"shutdown(s) -> socket\n\ +\n\ +Does the SSL shutdown handshake with the remote end, and returns\n\ +the underlying socket object."); + static PyMethodDef PySSLMethods[] = { + {"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS}, {"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS, PySSL_SSLwrite_doc}, {"read", (PyCFunction)PySSL_SSLread, METH_VARARGS, PySSL_SSLread_doc}, + {"pending", (PyCFunction)PySSL_SSLpending, METH_NOARGS, + PySSL_SSLpending_doc}, {"server", (PyCFunction)PySSL_server, METH_NOARGS}, {"issuer", (PyCFunction)PySSL_issuer, METH_NOARGS}, {"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS, PySSL_peercert_doc}, {"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS}, + {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS, + PySSL_SSLshutdown_doc}, {NULL, NULL} };