clean up ssl.py; expose unwrap and add test for it
This commit is contained in:
parent
6aa2d1fec7
commit
40a0f66e95
10
Lib/ssl.py
10
Lib/ssl.py
|
@ -75,10 +75,10 @@ from _ssl import (
|
|||
SSL_ERROR_INVALID_ERROR_CODE,
|
||||
)
|
||||
|
||||
from socket import socket, AF_INET, SOCK_STREAM, error
|
||||
from socket import getnameinfo as _getnameinfo
|
||||
from socket import error as socket_error
|
||||
from socket import dup as _dup
|
||||
from socket import socket, AF_INET, SOCK_STREAM
|
||||
import base64 # for DER-to-PEM translation
|
||||
import traceback
|
||||
|
||||
|
@ -296,6 +296,14 @@ class SSLSocket(socket):
|
|||
self._sslobj = None
|
||||
socket.shutdown(self, how)
|
||||
|
||||
def unwrap (self):
|
||||
if self._sslobj:
|
||||
s = self._sslobj.shutdown()
|
||||
self._sslobj = None
|
||||
return s
|
||||
else:
|
||||
raise ValueError("No SSL wrapper around " + str(self))
|
||||
|
||||
def _real_close(self):
|
||||
self._sslobj = None
|
||||
# self._closed = True
|
||||
|
|
|
@ -279,6 +279,15 @@ else:
|
|||
self.write("OK\n".encode("ASCII", "strict"))
|
||||
if not self.wrap_conn():
|
||||
return
|
||||
elif (self.server.starttls_server and self.sslconn
|
||||
and amsg.strip() == 'ENDTLS'):
|
||||
if support.verbose and self.server.connectionchatty:
|
||||
sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
|
||||
self.write("OK\n".encode("ASCII", "strict"))
|
||||
self.sock = self.sslconn.unwrap()
|
||||
self.sslconn = None
|
||||
if support.verbose and self.server.connectionchatty:
|
||||
sys.stdout.write(" server: connection is now unencrypted...\n")
|
||||
else:
|
||||
if (support.verbose and
|
||||
self.server.connectionchatty):
|
||||
|
@ -868,7 +877,7 @@ else:
|
|||
|
||||
def testSTARTTLS (self):
|
||||
|
||||
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4")
|
||||
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6")
|
||||
|
||||
server = ThreadedEchoServer(CERTFILE,
|
||||
ssl_version=ssl.PROTOCOL_TLSv1,
|
||||
|
@ -910,8 +919,16 @@ else:
|
|||
" client: read %s from server, starting TLS...\n"
|
||||
% repr(msg))
|
||||
conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
|
||||
|
||||
wrapped = True
|
||||
elif (indata == "ENDTLS" and
|
||||
str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")):
|
||||
if support.verbose:
|
||||
msg = str(outdata, 'ASCII', 'replace')
|
||||
sys.stdout.write(
|
||||
" client: read %s from server, ending TLS...\n"
|
||||
% repr(msg))
|
||||
s = conn.unwrap()
|
||||
wrapped = False
|
||||
else:
|
||||
if support.verbose:
|
||||
msg = str(outdata, 'ASCII', 'replace')
|
||||
|
@ -922,7 +939,7 @@ else:
|
|||
if wrapped:
|
||||
conn.write("over\n".encode("ASCII", "strict"))
|
||||
else:
|
||||
s.send("over\n")
|
||||
s.send("over\n".encode("ASCII", "strict"))
|
||||
if wrapped:
|
||||
conn.close()
|
||||
else:
|
||||
|
|
|
@ -1370,6 +1370,42 @@ PyDoc_STRVAR(PySSL_SSLread_doc,
|
|||
\n\
|
||||
Read up to len bytes from the SSL socket.");
|
||||
|
||||
static PyObject *PySSL_SSLshutdown(PySSLObject *self)
|
||||
{
|
||||
int err;
|
||||
PySocketSockObject *sock
|
||||
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
|
||||
|
||||
/* Guard against closed socket */
|
||||
if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) {
|
||||
_setSSLError("Underlying socket connection gone",
|
||||
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
PySSL_BEGIN_ALLOW_THREADS
|
||||
err = SSL_shutdown(self->ssl);
|
||||
if (err == 0) {
|
||||
/* we need to call it again to finish the shutdown */
|
||||
err = SSL_shutdown(self->ssl);
|
||||
}
|
||||
PySSL_END_ALLOW_THREADS
|
||||
|
||||
if (err < 0)
|
||||
return PySSL_SetError(self, err, __FILE__, __LINE__);
|
||||
else {
|
||||
Py_INCREF(sock);
|
||||
return (PyObject *) sock;
|
||||
}
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(PySSL_SSLshutdown_doc,
|
||||
"shutdown(s) -> socket\n\
|
||||
\n\
|
||||
Does the SSL shutdown handshake with the remote end, and returns\n\
|
||||
the underlying socket object.");
|
||||
|
||||
|
||||
static PyMethodDef PySSLMethods[] = {
|
||||
{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
|
||||
{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
|
||||
|
@ -1381,6 +1417,8 @@ static PyMethodDef PySSLMethods[] = {
|
|||
{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
|
||||
PySSL_peercert_doc},
|
||||
{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
|
||||
{"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
|
||||
PySSL_SSLshutdown_doc},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
|
@ -1480,6 +1518,8 @@ fails or if it does provide enough data to seed PRNG.");
|
|||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
/* List of functions exported by this module. */
|
||||
|
||||
static PyMethodDef PySSL_methods[] = {
|
||||
|
|
Loading…
Reference in New Issue