various SSL fixes; issues 1251, 3162, 3212

This commit is contained in:
Bill Janssen 2008-06-28 22:19:33 +00:00
parent a27474c345
commit 934b16d0c2
5 changed files with 528 additions and 348 deletions

View File

@ -54,7 +54,7 @@ Functions, Constants, and Exceptions
network connection. This error is a subtype of :exc:`socket.error`, which
in turn is a subtype of :exc:`IOError`.
.. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None)
.. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None, do_handshake_on_connect=True, suppress_ragged_eofs=True)
Takes an instance ``sock`` of :class:`socket.socket`, and returns an instance of :class:`ssl.SSLSocket`, a subtype
of :class:`socket.socket`, which wraps the underlying socket in an SSL context.
@ -122,6 +122,18 @@ Functions, Constants, and Exceptions
In some older versions of OpenSSL (for instance, 0.9.7l on OS X 10.4),
an SSLv2 client could not connect to an SSLv23 server.
The parameter ``do_handshake_on_connect`` specifies whether to do the SSL
handshake automatically after doing a :meth:`socket.connect`, or whether the
application program will call it explicitly, by invoking the :meth:`SSLSocket.do_handshake`
method. Calling :meth:`SSLSocket.do_handshake` explicitly gives the program control over
the blocking behavior of the socket I/O involved in the handshake.
The parameter ``suppress_ragged_eofs`` specifies how the :meth:`SSLSocket.read`
method should signal unexpected EOF from the other end of the connection. If specified
as :const:`True` (the default), it returns a normal EOF in response to unexpected
EOF errors raised from the underlying socket; if :const:`False`, it will raise
the exceptions back to the caller.
.. function:: RAND_status()
Returns True if the SSL pseudo-random number generator has been
@ -290,6 +302,25 @@ SSLSocket Objects
number of secret bits being used. If no connection has been
established, returns ``None``.
.. method:: SSLSocket.do_handshake()
Perform a TLS/SSL handshake. If this is used with a non-blocking socket,
it may raise :exc:`SSLError` with an ``arg[0]`` of :const:`SSL_ERROR_WANT_READ`
or :const:`SSL_ERROR_WANT_WRITE`, in which case it must be called again until it
completes successfully. For example, to simulate the behavior of a blocking socket,
one might write::
while True:
try:
s.do_handshake()
break
except ssl.SSLError, err:
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
select.select([s], [], [])
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
select.select([], [s], [])
else:
raise
.. index:: single: certificates
@ -367,6 +398,7 @@ certificate, you need to provide a "CA certs" file, filled with the certificate
chains for each issuer you are willing to trust. Again, this file just
contains these chains concatenated together. For validation, Python will
use the first chain it finds in the file which matches.
Some "standard" root certificates are available from various certification
authorities:
`CACert.org <http://www.cacert.org/index.php?id=3>`_,

View File

@ -74,7 +74,7 @@ from _ssl import \
SSL_ERROR_EOF, \
SSL_ERROR_INVALID_ERROR_CODE
from socket import socket
from socket import socket, _fileobject
from socket import getnameinfo as _getnameinfo
import base64 # for DER-to-PEM translation
@ -86,8 +86,16 @@ class SSLSocket (socket):
def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None):
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True):
socket.__init__(self, _sock=sock._sock)
# the initializer for socket trashes the methods (tsk, tsk), so...
self.send = lambda x, flags=0: SSLSocket.send(self, x, flags)
self.recv = lambda x, flags=0: SSLSocket.recv(self, x, flags)
self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
self.recvfrom = lambda addr, buflen, flags: SSLSocket.recvfrom(self, addr, buflen, flags)
if certfile and not keyfile:
keyfile = certfile
# see if it's connected
@ -101,18 +109,34 @@ class SSLSocket (socket):
self._sslobj = _ssl.sslwrap(self._sock, server_side,
keyfile, certfile,
cert_reqs, ssl_version, ca_certs)
if do_handshake_on_connect:
timeout = self.gettimeout()
try:
self.settimeout(None)
self.do_handshake()
finally:
self.settimeout(timeout)
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
self.ssl_version = ssl_version
self.ca_certs = ca_certs
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0
def read(self, len=1024):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
return self._sslobj.read(len)
try:
return self._sslobj.read(len)
except SSLError, x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return ''
else:
raise
def write(self, data):
@ -143,16 +167,27 @@ class SSLSocket (socket):
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
return self._sslobj.write(data)
while True:
try:
v = self._sslobj.write(data)
except SSLError, x:
if x.args[0] == SSL_ERROR_WANT_READ:
return 0
elif x.args[0] == SSL_ERROR_WANT_WRITE:
return 0
else:
raise
else:
return v
else:
return socket.send(self, data, flags)
def send_to (self, data, addr, flags=0):
def sendto (self, data, addr, flags=0):
if self._sslobj:
raise ValueError("send_to not allowed on instances of %s" %
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
else:
return socket.send_to(self, data, addr, flags)
return socket.sendto(self, data, addr, flags)
def sendall (self, data, flags=0):
if self._sslobj:
@ -160,7 +195,12 @@ class SSLSocket (socket):
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
return self._sslobj.write(data)
amount = len(data)
count = 0
while (count < amount):
v = self.send(data[count:])
count += v
return amount
else:
return socket.sendall(self, data, flags)
@ -170,25 +210,51 @@ class SSLSocket (socket):
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
return self._sslobj.read(data, buflen)
while True:
try:
return self.read(buflen)
except SSLError, x:
if x.args[0] == SSL_ERROR_WANT_READ:
continue
else:
raise x
else:
return socket.recv(self, buflen, flags)
def recv_from (self, addr, buflen=1024, flags=0):
def recvfrom (self, addr, buflen=1024, flags=0):
if self._sslobj:
raise ValueError("recv_from not allowed on instances of %s" %
raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
return socket.recv_from(self, addr, buflen, flags)
return socket.recvfrom(self, addr, buflen, flags)
def shutdown(self, how):
def pending (self):
if self._sslobj:
return self._sslobj.pending()
else:
return 0
def shutdown (self, how):
self._sslobj = None
socket.shutdown(self, how)
def close(self):
def close (self):
self._sslobj = None
socket.close(self)
def close (self):
if self._makefile_refs < 1:
self._sslobj = None
socket.close(self)
else:
self._makefile_refs -= 1
def do_handshake (self):
"""Perform a TLS/SSL handshake."""
self._sslobj.do_handshake()
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
@ -202,6 +268,8 @@ class SSLSocket (socket):
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
if self.do_handshake_on_connect:
self.do_handshake()
def accept(self):
@ -210,260 +278,39 @@ class SSLSocket (socket):
SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self)
return (SSLSocket(newsock, True, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs), addr)
return (SSLSocket(newsock,
keyfile=self.keyfile,
certfile=self.certfile,
server_side=True,
cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version,
ca_certs=self.ca_certs,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs),
addr)
def makefile(self, mode='r', bufsize=-1):
"""Ouch. Need to make and return a file-like object that
works with the SSL connection."""
if self._sslobj:
return SSLFileStream(self._sslobj, mode, bufsize)
else:
return socket.makefile(self, mode, bufsize)
class SSLFileStream:
"""A class to simulate a file stream on top of a socket.
Most of this is just lifted from the socket module, and
adjusted to work with an SSL stream instead of a socket."""
default_bufsize = 8192
name = "<SSL stream>"
__slots__ = ["mode", "bufsize", "softspace",
# "closed" is a property, see below
"_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
"_close", "_fileno"]
def __init__(self, sslobj, mode='rb', bufsize=-1, close=False):
self._sslobj = sslobj
self.mode = mode # Not actually used in this version
if bufsize < 0:
bufsize = self.default_bufsize
self.bufsize = bufsize
self.softspace = False
if bufsize == 0:
self._rbufsize = 1
elif bufsize == 1:
self._rbufsize = self.default_bufsize
else:
self._rbufsize = bufsize
self._wbufsize = bufsize
self._rbuf = "" # A string
self._wbuf = [] # A list of strings
self._close = close
self._fileno = -1
def _getclosed(self):
return self._sslobj is None
closed = property(_getclosed, doc="True if the file is closed")
def fileno(self):
return self._fileno
def close(self):
try:
if self._sslobj:
self.flush()
finally:
if self._close and self._sslobj:
self._sslobj.close()
self._sslobj = None
def __del__(self):
try:
self.close()
except:
# close() may fail if __init__ didn't complete
pass
def flush(self):
if self._wbuf:
buffer = "".join(self._wbuf)
self._wbuf = []
count = 0
while (count < len(buffer)):
written = self._sslobj.write(buffer)
count += written
buffer = buffer[written:]
def write(self, data):
data = str(data) # XXX Should really reject non-string non-buffers
if not data:
return
self._wbuf.append(data)
if (self._wbufsize == 0 or
self._wbufsize == 1 and '\n' in data or
self._get_wbuf_len() >= self._wbufsize):
self.flush()
def writelines(self, list):
# XXX We could do better here for very long lists
# XXX Should really reject non-string non-buffers
self._wbuf.extend(filter(None, map(str, list)))
if (self._wbufsize <= 1 or
self._get_wbuf_len() >= self._wbufsize):
self.flush()
def _get_wbuf_len(self):
buf_len = 0
for x in self._wbuf:
buf_len += len(x)
return buf_len
def read(self, size=-1):
data = self._rbuf
if size < 0:
# Read until EOF
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
if self._rbufsize <= 1:
recv_size = self.default_bufsize
else:
recv_size = self._rbufsize
while True:
data = self._sslobj.read(recv_size)
if not data:
break
buffers.append(data)
return "".join(buffers)
else:
# Read until size bytes or EOF seen, whichever comes first
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
left = size - buf_len
recv_size = max(self._rbufsize, left)
data = self._sslobj.read(recv_size)
if not data:
break
buffers.append(data)
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return "".join(buffers)
def readline(self, size=-1):
data = self._rbuf
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
assert data == ""
buffers = []
while data != "\n":
data = self._sslobj.read(1)
if not data:
break
buffers.append(data)
return "".join(buffers)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
data = self._sslobj.read(self._rbufsize)
if not data:
break
buffers.append(data)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
return "".join(buffers)
else:
# Read until size bytes or \n or EOF seen, whichever comes first
nl = data.find('\n', 0, size)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
data = self._sslobj.read(self._rbufsize)
if not data:
break
buffers.append(data)
left = size - buf_len
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return "".join(buffers)
def readlines(self, sizehint=0):
total = 0
list = []
while True:
line = self.readline()
if not line:
break
list.append(line)
total += len(line)
if sizehint and total >= sizehint:
break
return list
# Iterator protocols
def __iter__(self):
return self
def next(self):
line = self.readline()
if not line:
raise StopIteration
return line
self._makefile_refs += 1
return _fileobject(self, mode, bufsize)
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None):
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True):
return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs)
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs)
# some utility functions
@ -549,5 +396,7 @@ def sslwrap_simple (sock, keyfile=None, certfile=None):
for compability with Python 2.5 and earlier. Will disappear in
Python 3.0."""
return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE,
PROTOCOL_SSLv23, None)
ssl_sock = _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE,
PROTOCOL_SSLv23, None)
ssl_sock.do_handshake()
return ssl_sock

View File

@ -3,7 +3,9 @@
import sys
import unittest
from test import test_support
import asyncore
import socket
import select
import errno
import subprocess
import time
@ -97,8 +99,7 @@ class BasicTests(unittest.TestCase):
if (d1 != d2):
raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed")
class NetworkTests(unittest.TestCase):
class NetworkedTests(unittest.TestCase):
def testConnect(self):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
@ -130,6 +131,31 @@ class NetworkTests(unittest.TestCase):
finally:
s.close()
def testNonBlockingHandshake(self):
s = socket.socket(socket.AF_INET)
s.connect(("svn.python.org", 443))
s.setblocking(False)
s = ssl.wrap_socket(s,
cert_reqs=ssl.CERT_NONE,
do_handshake_on_connect=False)
count = 0
while True:
try:
count += 1
s.do_handshake()
break
except ssl.SSLError, err:
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
select.select([s], [], [])
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
select.select([], [s], [])
else:
raise
s.close()
if test_support.verbose:
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
def testFetchServerCert(self):
pem = ssl.get_server_certificate(("svn.python.org", 443))
@ -176,6 +202,18 @@ else:
threading.Thread.__init__(self)
self.setDaemon(True)
def show_conn_details(self):
if self.server.certreqs == ssl.CERT_REQUIRED:
cert = self.sslconn.getpeercert()
if test_support.verbose and self.server.chatty:
sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
cert_binary = self.sslconn.getpeercert(True)
if test_support.verbose and self.server.chatty:
sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
cipher = self.sslconn.cipher()
if test_support.verbose and self.server.chatty:
sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
def wrap_conn (self):
try:
self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
@ -187,6 +225,7 @@ else:
if self.server.chatty:
handle_error("\n server: bad connection attempt from " +
str(self.sock.getpeername()) + ":\n")
self.close()
if not self.server.expect_bad_connects:
# here, we want to stop the server, because this shouldn't
# happen in the context of our test case
@ -197,16 +236,6 @@ else:
return False
else:
if self.server.certreqs == ssl.CERT_REQUIRED:
cert = self.sslconn.getpeercert()
if test_support.verbose and self.server.chatty:
sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
cert_binary = self.sslconn.getpeercert(True)
if test_support.verbose and self.server.chatty:
sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
cipher = self.sslconn.cipher()
if test_support.verbose and self.server.chatty:
sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
return True
def read(self):
@ -225,13 +254,16 @@ else:
if self.sslconn:
self.sslconn.close()
else:
self.sock.close()
self.sock._sock.close()
def run (self):
self.running = True
if not self.server.starttls_server:
if not self.wrap_conn():
if isinstance(self.sock, ssl.SSLSocket):
self.sslconn = self.sock
elif not self.wrap_conn():
return
self.show_conn_details()
while self.running:
try:
msg = self.read()
@ -270,7 +302,9 @@ else:
def __init__(self, certificate, ssl_version=None,
certreqs=None, cacerts=None, expect_bad_connects=False,
chatty=True, connectionchatty=False, starttls_server=False):
chatty=True, connectionchatty=False, starttls_server=False,
wrap_accepting_socket=False):
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLSv1
if certreqs is None:
@ -284,8 +318,16 @@ else:
self.connectionchatty = connectionchatty
self.starttls_server = starttls_server
self.sock = socket.socket()
self.port = test_support.bind_port(self.sock)
self.flag = None
if wrap_accepting_socket:
self.sock = ssl.wrap_socket(self.sock, server_side=True,
certfile=self.certificate,
cert_reqs = self.certreqs,
ca_certs = self.cacerts,
ssl_version = self.protocol)
if test_support.verbose and self.chatty:
sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock))
self.port = test_support.bind_port(self.sock)
self.active = False
threading.Thread.__init__(self)
self.setDaemon(False)
@ -316,13 +358,86 @@ else:
except:
if self.chatty:
handle_error("Test server failure:\n")
self.sock.close()
def stop (self):
self.active = False
self.sock.close()
class AsyncoreEchoServer(threading.Thread):
class AsyncoreHTTPSServer(threading.Thread):
class EchoServer (asyncore.dispatcher):
class ConnectionHandler (asyncore.dispatcher_with_send):
def __init__(self, conn, certfile):
asyncore.dispatcher_with_send.__init__(self, conn)
self.socket = ssl.wrap_socket(conn, server_side=True,
certfile=certfile,
do_handshake_on_connect=True)
def readable(self):
if isinstance(self.socket, ssl.SSLSocket):
while self.socket.pending() > 0:
self.handle_read_event()
return True
def handle_read(self):
data = self.recv(1024)
self.send(data.lower())
def handle_close(self):
if test_support.verbose:
sys.stdout.write(" server: closed connection %s\n" % self.socket)
def handle_error(self):
raise
def __init__(self, certfile):
self.certfile = certfile
asyncore.dispatcher.__init__(self)
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = test_support.bind_port(self.socket)
self.listen(5)
def handle_accept(self):
sock_obj, addr = self.accept()
if test_support.verbose:
sys.stdout.write(" server: new connection from %s:%s\n" %addr)
self.ConnectionHandler(sock_obj, self.certfile)
def handle_error(self):
raise
def __init__(self, certfile):
self.flag = None
self.active = False
self.server = self.EchoServer(certfile)
self.port = self.server.port
threading.Thread.__init__(self)
self.setDaemon(True)
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
def start (self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run (self):
self.active = True
if self.flag:
self.flag.set()
while self.active:
try:
asyncore.loop(1)
except:
pass
def stop (self):
self.active = False
self.server.close()
class SocketServerHTTPSServer(threading.Thread):
class HTTPSServer(HTTPServer):
@ -335,6 +450,12 @@ else:
self.active_lock = threading.Lock()
self.allow_reuse_address = True
def __str__(self):
return ('<%s %s:%s>' %
(self.__class__.__name__,
self.server_name,
self.server_port))
def get_request (self):
# override this to wrap socket with SSL
sock, addr = self.socket.accept()
@ -421,8 +542,8 @@ else:
# we override this to suppress logging unless "verbose"
if test_support.verbose:
sys.stdout.write(" server (%s, %d, %s):\n [%s] %s\n" %
(self.server.server_name,
sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
(self.server.server_address,
self.server.server_port,
self.request.cipher(),
self.log_date_time_string(),
@ -440,9 +561,7 @@ else:
self.setDaemon(True)
def __str__(self):
return '<%s %s:%d>' % (self.__class__.__name__,
self.server.server_name,
self.server.server_port)
return "<%s %s>" % (self.__class__.__name__, self.server)
def start (self, flag=None):
self.flag = flag
@ -487,14 +606,16 @@ else:
def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, indata="FOO\n",
chatty=True, connectionchatty=False):
chatty=True, connectionchatty=False,
wrap_accepting_socket=False):
server = ThreadedEchoServer(certfile,
certreqs=certreqs,
ssl_version=protocol,
cacerts=cacertsfile,
chatty=chatty,
connectionchatty=connectionchatty)
connectionchatty=connectionchatty,
wrap_accepting_socket=wrap_accepting_socket)
flag = threading.Event()
server.start(flag)
# wait for it to start
@ -572,7 +693,7 @@ else:
ssl.get_protocol_name(server_protocol)))
class ConnectedTests(unittest.TestCase):
class ThreadedTests(unittest.TestCase):
def testRudeShutdown(self):
@ -600,7 +721,7 @@ else:
listener_gone.wait()
try:
ssl_sock = ssl.wrap_socket(s)
except socket.sslerror:
except IOError:
pass
else:
raise test_support.TestFailed(
@ -680,6 +801,9 @@ else:
def testMalformedCert(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
"badcert.pem"))
def testWrongCert(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
"wrongcert.pem"))
def testMalformedKey(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
"badkey.pem"))
@ -796,9 +920,9 @@ else:
server.stop()
server.join()
def testAsyncore(self):
def testSocketServer(self):
server = AsyncoreHTTPSServer(CERTFILE)
server = SocketServerHTTPSServer(CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
@ -810,8 +934,8 @@ else:
d1 = open(CERTFILE, 'rb').read()
d2 = ''
# now fetch the same data from the HTTPS server
url = 'https://%s:%d/%s' % (
HOST, server.port, os.path.split(CERTFILE)[1])
url = 'https://127.0.0.1:%d/%s' % (
server.port, os.path.split(CERTFILE)[1])
f = urllib.urlopen(url)
dlen = f.info().getheader("content-length")
if dlen and (int(dlen) > 0):
@ -834,6 +958,58 @@ else:
server.stop()
server.join()
def testWrappedAccept (self):
if test_support.verbose:
sys.stdout.write("\n")
serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED,
CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23,
chatty=True, connectionchatty=True,
wrap_accepting_socket=True)
def testAsyncoreServer (self):
indata = "TEST MESSAGE of mixed case\n"
if test_support.verbose:
sys.stdout.write("\n")
server = AsyncoreEchoServer(CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
try:
s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', server.port))
except ssl.SSLError, x:
raise test_support.TestFailed("Unexpected SSL error: " + str(x))
except Exception, x:
raise test_support.TestFailed("Unexpected exception: " + str(x))
else:
if test_support.verbose:
sys.stdout.write(
" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if test_support.verbose:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise test_support.TestFailed(
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata),20)], len(outdata),
indata[:min(len(indata),20)].lower(), len(indata)))
s.write("over\n")
if test_support.verbose:
sys.stdout.write(" client: closing connection.\n")
s.close()
finally:
server.stop()
# wait for server thread to end
server.join()
def test_main(verbose=False):
if skip_expected:
@ -850,15 +1026,19 @@ def test_main(verbose=False):
not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
raise test_support.TestFailed("Can't read certificate files!")
TESTPORT = test_support.find_unused_port()
if not TESTPORT:
raise test_support.TestFailed("Can't find open port to test servers on!")
tests = [BasicTests]
if test_support.is_resource_enabled('network'):
tests.append(NetworkTests)
tests.append(NetworkedTests)
if _have_threads:
thread_info = test_support.threading_setup()
if thread_info and test_support.is_resource_enabled('network'):
tests.append(ConnectedTests)
tests.append(ThreadedTests)
test_support.run_unittest(*tests)

32
Lib/test/wrongcert.pem Normal file
View File

@ -0,0 +1,32 @@
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnH
FlbsVUg2Xtk6+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6T
f9lnNTwpSoeK24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQAB
AoGAQFko4uyCgzfxr4Ezb4Mp5pN3Npqny5+Jey3r8EjSAX9Ogn+CNYgoBcdtFgbq
1yif/0sK7ohGBJU9FUCAwrqNBI9ZHB6rcy7dx+gULOmRBGckln1o5S1+smVdmOsW
7zUVLBVByKuNWqTYFlzfVd6s4iiXtAE2iHn3GCyYdlICwrECQQDhMQVxHd3EFbzg
SFmJBTARlZ2GKA3c1g/h9/XbkEPQ9/RwI3vnjJ2RaSnjlfoLl8TOcf0uOGbOEyFe
19RvCLXjAkEA1s+UE5ziF+YVkW3WolDCQ2kQ5WG9+ccfNebfh6b67B7Ln5iG0Sbg
ky9cjsO3jbMJQtlzAQnH1850oRD5Gi51dQJAIbHCDLDZU9Ok1TI+I2BhVuA6F666
lEZ7TeZaJSYq34OaUYUdrwG9OdqwZ9sy9LUav4ESzu2lhEQchCJrKMn23QJAReqs
ZLHUeTjfXkVk7dHhWPWSlUZ6AhmIlA/AQ7Payg2/8wM/JkZEJEPvGVykms9iPUrv
frADRr+hAGe43IewnQJBAJWKZllPgKuEBPwoEldHNS8nRu61D7HzxEzQ2xnfj+Nk
2fgf1MAzzTRsikfGENhVsVWeqOcijWb6g5gsyCmlRpc=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICsDCCAhmgAwIBAgIJAOqYOYFJfEEoMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMDgwNjI2MTgxNTUyWhcNMDkwNjI2MTgxNTUyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
gQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnHFlbsVUg2Xtk6
+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6Tf9lnNTwpSoeK
24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQABo4GnMIGkMB0G
A1UdDgQWBBTctMtI3EO9OjLI0x9Zo2ifkwIiNjB1BgNVHSMEbjBsgBTctMtI3EO9
OjLI0x9Zo2ifkwIiNqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt
U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOqYOYFJ
fEEoMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAQwa7jya/DfhaDn7E
usPkpgIX8WCL2B1SqnRTXEZfBPPVq/cUmFGyEVRVATySRuMwi8PXbVcOhXXuocA+
43W+iIsD9pXapCZhhOerCq18TC1dWK98vLUsoK8PMjB6e5H/O8bqojv0EeC+fyCw
eSHj5jpC8iZKjCHBn+mAi4cQ514=
-----END CERTIFICATE-----

View File

@ -2,14 +2,15 @@
SSL support based on patches by Brian E Gallew and Laszlo Kovacs.
Re-worked a bit by Bill Janssen to add server-side support and
certificate decoding.
certificate decoding. Chris Stawarz contributed some non-blocking
patches.
This module is imported by ssl.py. It should *not* be used
directly.
XXX should partial writes be enabled, SSL_MODE_ENABLE_PARTIAL_WRITE?
XXX what about SSL_MODE_AUTO_RETRY
XXX what about SSL_MODE_AUTO_RETRY?
*/
#include "Python.h"
@ -265,8 +266,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
PySSLObject *self;
char *errstr = NULL;
int ret;
int err;
int sockstate;
int verification_mode;
self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */
@ -388,57 +387,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
SSL_set_accept_state(self->ssl);
PySSL_END_ALLOW_THREADS
/* Actually negotiate SSL connection */
/* XXX If SSL_connect() returns 0, it's also a failure. */
sockstate = 0;
do {
PySSL_BEGIN_ALLOW_THREADS
if (socket_type == PY_SSL_CLIENT)
ret = SSL_connect(self->ssl);
else
ret = SSL_accept(self->ssl);
err = SSL_get_error(self->ssl, ret);
PySSL_END_ALLOW_THREADS
if(PyErr_CheckSignals()) {
goto fail;
}
if (err == SSL_ERROR_WANT_READ) {
sockstate = check_socket_and_wait_for_timeout(Sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate = check_socket_and_wait_for_timeout(Sock, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("The connect operation timed out"));
goto fail;
} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket has been closed."));
goto fail;
} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket too large for select()."));
goto fail;
} else if (sockstate == SOCKET_IS_NONBLOCKING) {
break;
}
} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
if (ret < 1) {
PySSL_SetError(self, ret, __FILE__, __LINE__);
goto fail;
}
self->ssl->debug = 1;
PySSL_BEGIN_ALLOW_THREADS
if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) {
X509_NAME_oneline(X509_get_subject_name(self->peer_cert),
self->server, X509_NAME_MAXLEN);
X509_NAME_oneline(X509_get_issuer_name(self->peer_cert),
self->issuer, X509_NAME_MAXLEN);
}
PySSL_END_ALLOW_THREADS
self->Socket = Sock;
Py_INCREF(self->Socket);
return self;
@ -488,6 +436,65 @@ PyDoc_STRVAR(ssl_doc,
/* SSL object methods */
static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
{
int ret;
int err;
int sockstate;
/* Actually negotiate SSL connection */
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
sockstate = 0;
do {
PySSL_BEGIN_ALLOW_THREADS
ret = SSL_do_handshake(self->ssl);
err = SSL_get_error(self->ssl, ret);
PySSL_END_ALLOW_THREADS
if(PyErr_CheckSignals()) {
return NULL;
}
if (err == SSL_ERROR_WANT_READ) {
sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("The handshake operation timed out"));
return NULL;
} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket has been closed."));
return NULL;
} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
PyErr_SetString(PySSLErrorObject,
ERRSTR("Underlying socket too large for select()."));
return NULL;
} else if (sockstate == SOCKET_IS_NONBLOCKING) {
break;
}
} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
if (ret < 1)
return PySSL_SetError(self, ret, __FILE__, __LINE__);
self->ssl->debug = 1;
if (self->peer_cert)
X509_free (self->peer_cert);
PySSL_BEGIN_ALLOW_THREADS
if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) {
X509_NAME_oneline(X509_get_subject_name(self->peer_cert),
self->server, X509_NAME_MAXLEN);
X509_NAME_oneline(X509_get_issuer_name(self->peer_cert),
self->issuer, X509_NAME_MAXLEN);
}
PySSL_END_ALLOW_THREADS
Py_INCREF(Py_None);
return Py_None;
}
static PyObject *
PySSL_server(PySSLObject *self)
{
@ -1127,7 +1134,9 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing)
rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv);
PySSL_END_ALLOW_THREADS
#ifdef HAVE_POLL
normal_return:
#endif
/* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise
(when we are able to write or when there's something to read) */
return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK;
@ -1140,10 +1149,16 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
int count;
int sockstate;
int err;
int nonblocking;
if (!PyArg_ParseTuple(args, "s#:write", &data, &count))
return NULL;
/* just in case the blocking state of the socket has been changed */
nonblocking = (self->Socket->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
@ -1200,6 +1215,25 @@ PyDoc_STRVAR(PySSL_SSLwrite_doc,
Writes the string s into the SSL object. Returns the number\n\
of bytes written.");
static PyObject *PySSL_SSLpending(PySSLObject *self)
{
int count = 0;
PySSL_BEGIN_ALLOW_THREADS
count = SSL_pending(self->ssl);
PySSL_END_ALLOW_THREADS
if (count < 0)
return PySSL_SetError(self, count, __FILE__, __LINE__);
else
return PyInt_FromLong(count);
}
PyDoc_STRVAR(PySSL_SSLpending_doc,
"pending() -> count\n\
\n\
Returns the number of already decrypted bytes available for read,\n\
pending on the connection.\n");
static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
{
PyObject *buf;
@ -1207,6 +1241,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
int len = 1024;
int sockstate;
int err;
int nonblocking;
if (!PyArg_ParseTuple(args, "|i:read", &len))
return NULL;
@ -1214,6 +1249,11 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
if (!(buf = PyString_FromStringAndSize((char *) 0, len)))
return NULL;
/* just in case the blocking state of the socket has been changed */
nonblocking = (self->Socket->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
/* first check if there are bytes ready to be read */
PySSL_BEGIN_ALLOW_THREADS
count = SSL_pending(self->ssl);
@ -1232,9 +1272,18 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
Py_DECREF(buf);
return NULL;
} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
/* should contain a zero-length string */
_PyString_Resize(&buf, 0);
return buf;
if (SSL_get_shutdown(self->ssl) !=
SSL_RECEIVED_SHUTDOWN)
{
Py_DECREF(buf);
PyErr_SetString(PySSLErrorObject,
"Socket closed without SSL shutdown handshake");
return NULL;
} else {
/* should contain a zero-length string */
_PyString_Resize(&buf, 0);
return buf;
}
}
}
do {
@ -1285,16 +1334,54 @@ PyDoc_STRVAR(PySSL_SSLread_doc,
\n\
Read up to len bytes from the SSL socket.");
static PyObject *PySSL_SSLshutdown(PySSLObject *self)
{
int err;
/* Guard against closed socket */
if (self->Socket->sock_fd < 0) {
PyErr_SetString(PySSLErrorObject,
"Underlying socket has been closed.");
return NULL;
}
PySSL_BEGIN_ALLOW_THREADS
err = SSL_shutdown(self->ssl);
if (err == 0) {
/* we need to call it again to finish the shutdown */
err = SSL_shutdown(self->ssl);
}
PySSL_END_ALLOW_THREADS
if (err < 0)
return PySSL_SetError(self, err, __FILE__, __LINE__);
else {
Py_INCREF(self->Socket);
return (PyObject *) (self->Socket);
}
}
PyDoc_STRVAR(PySSL_SSLshutdown_doc,
"shutdown(s) -> socket\n\
\n\
Does the SSL shutdown handshake with the remote end, and returns\n\
the underlying socket object.");
static PyMethodDef PySSLMethods[] = {
{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
PySSL_SSLwrite_doc},
{"read", (PyCFunction)PySSL_SSLread, METH_VARARGS,
PySSL_SSLread_doc},
{"pending", (PyCFunction)PySSL_SSLpending, METH_NOARGS,
PySSL_SSLpending_doc},
{"server", (PyCFunction)PySSL_server, METH_NOARGS},
{"issuer", (PyCFunction)PySSL_issuer, METH_NOARGS},
{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
PySSL_peercert_doc},
{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
{"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
PySSL_SSLshutdown_doc},
{NULL, NULL}
};