Fix some minor style nits. (I'll leave adding __all__ and making the

docstrings conform to PEP 8 to someone else.)
This commit is contained in:
Guido van Rossum 2007-11-16 00:06:11 +00:00
parent b61506989e
commit 5b8b1555de
1 changed files with 38 additions and 46 deletions

View File

@ -1,8 +1,7 @@
# Wrapper module for _ssl, providing some additional facilities # Wrapper module for _ssl, providing some additional facilities
# implemented in Python. Written by Bill Janssen. # implemented in Python. Written by Bill Janssen.
"""\ """This module provides some more Pythonic support for SSL.
This module provides some more Pythonic support for SSL.
Object types: Object types:
@ -61,18 +60,20 @@ import _ssl # if we can't import it, let the error propagate
from _ssl import SSLError from _ssl import SSLError
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,
PROTOCOL_TLSv1)
from _ssl import RAND_status, RAND_egd, RAND_add from _ssl import RAND_status, RAND_egd, RAND_add
from _ssl import \ from _ssl import (
SSL_ERROR_ZERO_RETURN, \ SSL_ERROR_ZERO_RETURN,
SSL_ERROR_WANT_READ, \ SSL_ERROR_WANT_READ,
SSL_ERROR_WANT_WRITE, \ SSL_ERROR_WANT_WRITE,
SSL_ERROR_WANT_X509_LOOKUP, \ SSL_ERROR_WANT_X509_LOOKUP,
SSL_ERROR_SYSCALL, \ SSL_ERROR_SYSCALL,
SSL_ERROR_SSL, \ SSL_ERROR_SSL,
SSL_ERROR_WANT_CONNECT, \ SSL_ERROR_WANT_CONNECT,
SSL_ERROR_EOF, \ SSL_ERROR_EOF,
SSL_ERROR_INVALID_ERROR_CODE SSL_ERROR_INVALID_ERROR_CODE,
)
from socket import socket, AF_INET, SOCK_STREAM, error from socket import socket, AF_INET, SOCK_STREAM, error
from socket import getnameinfo as _getnameinfo from socket import getnameinfo as _getnameinfo
@ -80,7 +81,7 @@ from socket import error as socket_error
import base64 # for DER-to-PEM translation import base64 # for DER-to-PEM translation
_can_dup_socket = hasattr(socket, "dup") _can_dup_socket = hasattr(socket, "dup")
class SSLSocket (socket): class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps """This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and the underlying OS socket in an SSL context when necessary, and
@ -102,7 +103,8 @@ class SSLSocket (socket):
if _can_dup_socket: if _can_dup_socket:
nfd = os.dup(fd) nfd = os.dup(fd)
try: try:
wrapper = socket.__init__(self, family=sock.family, type=sock.type, proto=sock.proto, fileno=nfd) socket.__init__(self, family=sock.family, type=sock.type,
proto=sock.proto, fileno=nfd)
except: except:
if nfd != fd: if nfd != fd:
os.close(nfd) os.close(nfd)
@ -152,7 +154,6 @@ class SSLSocket (socket):
pass pass
def read(self, len=1024, buffer=None): def read(self, len=1024, buffer=None):
"""Read up to LEN bytes and return them. """Read up to LEN bytes and return them.
Return zero-length string on EOF.""" Return zero-length string on EOF."""
@ -169,7 +170,6 @@ class SSLSocket (socket):
raise raise
def write(self, data): def write(self, data):
"""Write DATA to the underlying SSL channel. Returns """Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted.""" number of bytes of DATA actually transmitted."""
@ -177,7 +177,6 @@ class SSLSocket (socket):
return self._sslobj.write(data) return self._sslobj.write(data)
def getpeercert(self, binary_form=False): def getpeercert(self, binary_form=False):
"""Returns a formatted version of the data in the """Returns a formatted version of the data in the
certificate provided by the other end of the SSL channel. certificate provided by the other end of the SSL channel.
Return None if no certificate was provided, {} if a Return None if no certificate was provided, {} if a
@ -186,16 +185,14 @@ class SSLSocket (socket):
self._checkClosed() self._checkClosed()
return self._sslobj.peer_certificate(binary_form) return self._sslobj.peer_certificate(binary_form)
def cipher (self): def cipher(self):
self._checkClosed() self._checkClosed()
if not self._sslobj: if not self._sslobj:
return None return None
else: else:
return self._sslobj.cipher() return self._sslobj.cipher()
def send (self, data, flags=0): def send(self, data, flags=0):
self._checkClosed() self._checkClosed()
if self._sslobj: if self._sslobj:
if flags != 0: if flags != 0:
@ -217,7 +214,7 @@ class SSLSocket (socket):
else: else:
return socket.send(self, data, flags) return socket.send(self, data, flags)
def send_to (self, data, addr, flags=0): def send_to(self, data, addr, flags=0):
self._checkClosed() self._checkClosed()
if self._sslobj: if self._sslobj:
raise ValueError("send_to not allowed on instances of %s" % raise ValueError("send_to not allowed on instances of %s" %
@ -225,7 +222,7 @@ class SSLSocket (socket):
else: else:
return socket.send_to(self, data, addr, flags) return socket.send_to(self, data, addr, flags)
def sendall (self, data, flags=0): def sendall(self, data, flags=0):
self._checkClosed() self._checkClosed()
if self._sslobj: if self._sslobj:
amount = len(data) amount = len(data)
@ -237,13 +234,13 @@ class SSLSocket (socket):
else: else:
return socket.sendall(self, data, flags) return socket.sendall(self, data, flags)
def recv (self, buflen=1024, flags=0): def recv(self, buflen=1024, flags=0):
self._checkClosed() self._checkClosed()
if self._sslobj: if self._sslobj:
if flags != 0: if flags != 0:
raise ValueError( raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s" % "non-zero flags not allowed in calls to recv_into() on %s" %
self.__class__) self.__class__)
while True: while True:
try: try:
return self.read(buflen) return self.read(buflen)
@ -255,7 +252,7 @@ class SSLSocket (socket):
else: else:
return socket.recv(self, buflen, flags) return socket.recv(self, buflen, flags)
def recv_into (self, buffer, nbytes=None, flags=0): def recv_into(self, buffer, nbytes=None, flags=0):
self._checkClosed() self._checkClosed()
if buffer and (nbytes is None): if buffer and (nbytes is None):
nbytes = len(buffer) nbytes = len(buffer)
@ -264,8 +261,8 @@ class SSLSocket (socket):
if self._sslobj: if self._sslobj:
if flags != 0: if flags != 0:
raise ValueError( raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s" % "non-zero flags not allowed in calls to recv_into() on %s" %
self.__class__) self.__class__)
while True: while True:
try: try:
v = self.read(nbytes, buffer) v = self.read(nbytes, buffer)
@ -279,7 +276,7 @@ class SSLSocket (socket):
else: else:
return socket.recv_into(self, buffer, nbytes, flags) return socket.recv_into(self, buffer, nbytes, flags)
def recv_from (self, addr, buflen=1024, flags=0): def recv_from(self, addr, buflen=1024, flags=0):
self._checkClosed() self._checkClosed()
if self._sslobj: if self._sslobj:
raise ValueError("recv_from not allowed on instances of %s" % raise ValueError("recv_from not allowed on instances of %s" %
@ -287,27 +284,26 @@ class SSLSocket (socket):
else: else:
return socket.recv_from(self, addr, buflen, flags) return socket.recv_from(self, addr, buflen, flags)
def pending (self): def pending(self):
self._checkClosed() self._checkClosed()
if self._sslobj: if self._sslobj:
return self._sslobj.pending() return self._sslobj.pending()
else: else:
return 0 return 0
def shutdown (self, how): def shutdown(self, how):
self._checkClosed() self._checkClosed()
self._sslobj = None self._sslobj = None
socket.shutdown(self, how) socket.shutdown(self, how)
def _real_close (self): def _real_close(self):
self._sslobj = None self._sslobj = None
# self._closed = True # self._closed = True
if self._base: if self._base:
self._base.close() self._base.close()
socket._real_close(self) socket._real_close(self)
def do_handshake (self): def do_handshake(self):
"""Perform a TLS/SSL handshake.""" """Perform a TLS/SSL handshake."""
try: try:
@ -317,7 +313,6 @@ class SSLSocket (socket):
raise raise
def connect(self, addr): def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in """Connects to remote ADDR, and then wraps the connection in
an SSL channel.""" an SSL channel."""
@ -333,7 +328,6 @@ class SSLSocket (socket):
self.do_handshake() self.do_handshake()
def accept(self): def accept(self):
"""Accepts a new connection from a remote client, and returns """Accepts a new connection from a remote client, and returns
a tuple containing that new connection wrapped with a server-side a tuple containing that new connection wrapped with a server-side
SSL channel, and the address of the remote client.""" SSL channel, and the address of the remote client."""
@ -342,9 +336,11 @@ class SSLSocket (socket):
return (SSLSocket(sock=newsock, return (SSLSocket(sock=newsock,
keyfile=self.keyfile, certfile=self.certfile, keyfile=self.keyfile, certfile=self.certfile,
server_side=True, server_side=True,
cert_reqs=self.cert_reqs, ssl_version=self.ssl_version, cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version,
ca_certs=self.ca_certs, ca_certs=self.ca_certs,
do_handshake_on_connect=self.do_handshake_on_connect), do_handshake_on_connect=
self.do_handshake_on_connect),
addr) addr)
@ -361,7 +357,6 @@ def wrap_socket(sock, keyfile=None, certfile=None,
# some utility functions # some utility functions
def cert_time_to_seconds(cert_time): def cert_time_to_seconds(cert_time):
"""Takes a date-time string in standard ASN1_print form """Takes a date-time string in standard ASN1_print form
("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
a Python time value in seconds past the epoch.""" a Python time value in seconds past the epoch."""
@ -373,7 +368,6 @@ PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----" PEM_FOOTER = "-----END CERTIFICATE-----"
def DER_cert_to_PEM_cert(der_cert_bytes): def DER_cert_to_PEM_cert(der_cert_bytes):
"""Takes a certificate in binary DER format and returns the """Takes a certificate in binary DER format and returns the
PEM version of it as a string.""" PEM version of it as a string."""
@ -383,7 +377,6 @@ def DER_cert_to_PEM_cert(der_cert_bytes):
PEM_FOOTER + '\n') PEM_FOOTER + '\n')
def PEM_cert_to_DER_cert(pem_cert_string): def PEM_cert_to_DER_cert(pem_cert_string):
"""Takes a certificate in ASCII PEM format and returns the """Takes a certificate in ASCII PEM format and returns the
DER-encoded version of it as a byte sequence""" DER-encoded version of it as a byte sequence"""
@ -396,8 +389,7 @@ def PEM_cert_to_DER_cert(pem_cert_string):
d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
return base64.decodestring(d.encode('ASCII', 'strict')) return base64.decodestring(d.encode('ASCII', 'strict'))
def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
"""Retrieve the certificate from the server at the specified address, """Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string. and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it. If 'ca_certs' is specified, validate the server cert against it.
@ -415,7 +407,7 @@ def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
s.close() s.close()
return DER_cert_to_PEM_cert(dercert) return DER_cert_to_PEM_cert(dercert)
def get_protocol_name (protocol_code): def get_protocol_name(protocol_code):
if protocol_code == PROTOCOL_TLSv1: if protocol_code == PROTOCOL_TLSv1:
return "TLSv1" return "TLSv1"
elif protocol_code == PROTOCOL_SSLv23: elif protocol_code == PROTOCOL_SSLv23: