diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 3c909a6a3a5..aed4ac004c9 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -223,8 +223,7 @@ static PyTypeObject PySSLMemoryBIO_Type; static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args); static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args); -static int check_socket_and_wait_for_timeout(PySocketSockObject *s, - int writing); +static int PySSL_select(PySocketSockObject *s, int writing); static PyObject *PySSL_peercert(PySSLSocket *self, PyObject *args); static PyObject *PySSL_cipher(PySSLSocket *self); @@ -588,15 +587,18 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self) ret = SSL_do_handshake(self->ssl); err = SSL_get_error(self->ssl, ret); PySSL_END_ALLOW_THREADS + if (PyErr_CheckSignals()) goto error; + if (err == SSL_ERROR_WANT_READ) { - sockstate = check_socket_and_wait_for_timeout(sock, 0); + sockstate = PySSL_select(sock, 0); } else if (err == SSL_ERROR_WANT_WRITE) { - sockstate = check_socket_and_wait_for_timeout(sock, 1); + sockstate = PySSL_select(sock, 1); } else { sockstate = SOCKET_OPERATION_OK; } + if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySocketModule.timeout_error, ERRSTR("The handshake operation timed out")); @@ -1609,11 +1611,17 @@ static void PySSL_dealloc(PySSLSocket *self) */ static int -check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing) +PySSL_select(PySocketSockObject *s, int writing) { + int rc; +#ifdef HAVE_POLL + struct pollfd pollfd; + _PyTime_t ms; +#else + int nfds; fd_set fds; struct timeval tv; - int rc; +#endif /* Nothing to do unless we're in timeout mode (not non-blocking) */ if ((s == NULL) || (s->sock_timeout == 0)) @@ -1628,25 +1636,17 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing) /* Prefer poll, if available, since you can poll() any fd * which can't be done with select(). */ #ifdef HAVE_POLL - { - struct pollfd pollfd; - int timeout; + pollfd.fd = s->sock_fd; + pollfd.events = writing ? POLLOUT : POLLIN; - pollfd.fd = s->sock_fd; - pollfd.events = writing ? POLLOUT : POLLIN; - - /* s->sock_timeout is in seconds, timeout in ms */ - timeout = (int)_PyTime_AsMilliseconds(s->sock_timeout, - _PyTime_ROUND_CEILING); - - PySSL_BEGIN_ALLOW_THREADS - rc = poll(&pollfd, 1, timeout); - PySSL_END_ALLOW_THREADS - - goto normal_return; - } -#endif + /* s->sock_timeout is in seconds, timeout in ms */ + ms = (int)_PyTime_AsMilliseconds(s->sock_timeout, _PyTime_ROUND_CEILING); + assert(ms <= INT_MAX); + PySSL_BEGIN_ALLOW_THREADS + rc = poll(&pollfd, 1, (int)ms); + PySSL_END_ALLOW_THREADS +#else /* Guard against socket too large for select*/ if (!_PyIsSelectable_fd(s->sock_fd)) return SOCKET_TOO_LARGE_FOR_SELECT; @@ -1656,19 +1656,16 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing) FD_ZERO(&fds); FD_SET(s->sock_fd, &fds); - /* See if the socket is ready */ + /* Wait until the socket becomes ready */ PySSL_BEGIN_ALLOW_THREADS + nfds = Py_SAFE_DOWNCAST(s->sock_fd+1, SOCKET_T, int); if (writing) - rc = select(Py_SAFE_DOWNCAST(s->sock_fd+1, SOCKET_T, int), - NULL, &fds, NULL, &tv); + rc = select(nfds, NULL, &fds, NULL, &tv); else - rc = select(Py_SAFE_DOWNCAST(s->sock_fd+1, SOCKET_T, int), - &fds, NULL, NULL, &tv); + rc = select(nfds, &fds, NULL, NULL, &tv); PySSL_END_ALLOW_THREADS - -#ifdef HAVE_POLL -normal_return: #endif + /* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise (when we are able to write or when there's something to read) */ return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK; @@ -1710,7 +1707,7 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args) BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); } - sockstate = check_socket_and_wait_for_timeout(sock, 1); + sockstate = PySSL_select(sock, 1); if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySocketModule.timeout_error, "The write operation timed out"); @@ -1724,21 +1721,24 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args) "Underlying socket too large for select()."); goto error; } + do { PySSL_BEGIN_ALLOW_THREADS len = SSL_write(self->ssl, buf.buf, (int)buf.len); err = SSL_get_error(self->ssl, len); PySSL_END_ALLOW_THREADS - if (PyErr_CheckSignals()) { + + if (PyErr_CheckSignals()) goto error; - } + if (err == SSL_ERROR_WANT_READ) { - sockstate = check_socket_and_wait_for_timeout(sock, 0); + sockstate = PySSL_select(sock, 0); } else if (err == SSL_ERROR_WANT_WRITE) { - sockstate = check_socket_and_wait_for_timeout(sock, 1); + sockstate = PySSL_select(sock, 1); } else { sockstate = SOCKET_OPERATION_OK; } + if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySocketModule.timeout_error, "The write operation timed out"); @@ -1847,21 +1847,23 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args) count = SSL_read(self->ssl, mem, len); err = SSL_get_error(self->ssl, count); PySSL_END_ALLOW_THREADS + if (PyErr_CheckSignals()) goto error; + if (err == SSL_ERROR_WANT_READ) { - sockstate = check_socket_and_wait_for_timeout(sock, 0); + sockstate = PySSL_select(sock, 0); } else if (err == SSL_ERROR_WANT_WRITE) { - sockstate = check_socket_and_wait_for_timeout(sock, 1); - } else if ((err == SSL_ERROR_ZERO_RETURN) && - (SSL_get_shutdown(self->ssl) == - SSL_RECEIVED_SHUTDOWN)) + sockstate = PySSL_select(sock, 1); + } else if (err == SSL_ERROR_ZERO_RETURN && + SSL_get_shutdown(self->ssl) == SSL_RECEIVED_SHUTDOWN) { count = 0; goto done; - } else { - sockstate = SOCKET_OPERATION_OK; } + else + sockstate = SOCKET_OPERATION_OK; + if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySocketModule.timeout_error, "The read operation timed out"); @@ -1870,6 +1872,7 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args) break; } } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); + if (count <= 0) { PySSL_SetError(self, count, __FILE__, __LINE__); goto error; @@ -1935,6 +1938,7 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self) SSL_set_read_ahead(self->ssl, 0); err = SSL_shutdown(self->ssl); PySSL_END_ALLOW_THREADS + /* If err == 1, a secure shutdown with SSL_shutdown() is complete */ if (err > 0) break; @@ -1952,11 +1956,12 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self) /* Possibly retry shutdown until timeout or failure */ ssl_err = SSL_get_error(self->ssl, err); if (ssl_err == SSL_ERROR_WANT_READ) - sockstate = check_socket_and_wait_for_timeout(sock, 0); + sockstate = PySSL_select(sock, 0); else if (ssl_err == SSL_ERROR_WANT_WRITE) - sockstate = check_socket_and_wait_for_timeout(sock, 1); + sockstate = PySSL_select(sock, 1); else break; + if (sockstate == SOCKET_HAS_TIMED_OUT) { if (ssl_err == SSL_ERROR_WANT_READ) PyErr_SetString(PySocketModule.timeout_error,