mirror of https://github.com/python/cpython
gh-103607: Fix `pause_reading` to work when called from `connection_made` in `asyncio`. (#17425)
Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
This commit is contained in:
parent
dff8e5dc8d
commit
78942ecd9b
|
@ -794,6 +794,8 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
self._buffer = collections.deque()
|
||||
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
||||
self._closing = False # Set when close() called.
|
||||
self._paused = False # Set when pause_reading() called
|
||||
|
||||
if self._server is not None:
|
||||
self._server._attach()
|
||||
loop._transports[self._sock_fd] = self
|
||||
|
@ -839,6 +841,25 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
def is_closing(self):
|
||||
return self._closing
|
||||
|
||||
def is_reading(self):
|
||||
return not self.is_closing() and not self._paused
|
||||
|
||||
def pause_reading(self):
|
||||
if not self.is_reading():
|
||||
return
|
||||
self._paused = True
|
||||
self._loop._remove_reader(self._sock_fd)
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%r pauses reading", self)
|
||||
|
||||
def resume_reading(self):
|
||||
if self._closing or not self._paused:
|
||||
return
|
||||
self._paused = False
|
||||
self._add_reader(self._sock_fd, self._read_ready)
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%r resumes reading", self)
|
||||
|
||||
def close(self):
|
||||
if self._closing:
|
||||
return
|
||||
|
@ -898,9 +919,8 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
return sum(map(len, self._buffer))
|
||||
|
||||
def _add_reader(self, fd, callback, *args):
|
||||
if self._closing:
|
||||
if not self.is_reading():
|
||||
return
|
||||
|
||||
self._loop._add_reader(fd, callback, *args)
|
||||
|
||||
|
||||
|
@ -915,7 +935,6 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
self._read_ready_cb = None
|
||||
super().__init__(loop, sock, protocol, extra, server)
|
||||
self._eof = False
|
||||
self._paused = False
|
||||
self._empty_waiter = None
|
||||
if _HAS_SENDMSG:
|
||||
self._write_ready = self._write_sendmsg
|
||||
|
@ -943,25 +962,6 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
|
||||
super().set_protocol(protocol)
|
||||
|
||||
def is_reading(self):
|
||||
return not self._paused and not self._closing
|
||||
|
||||
def pause_reading(self):
|
||||
if self._closing or self._paused:
|
||||
return
|
||||
self._paused = True
|
||||
self._loop._remove_reader(self._sock_fd)
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%r pauses reading", self)
|
||||
|
||||
def resume_reading(self):
|
||||
if self._closing or not self._paused:
|
||||
return
|
||||
self._paused = False
|
||||
self._add_reader(self._sock_fd, self._read_ready)
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%r resumes reading", self)
|
||||
|
||||
def _read_ready(self):
|
||||
self._read_ready_cb()
|
||||
|
||||
|
|
|
@ -485,13 +485,21 @@ class _UnixReadPipeTransport(transports.ReadTransport):
|
|||
|
||||
self._loop.call_soon(self._protocol.connection_made, self)
|
||||
# only start reading when connection_made() has been called
|
||||
self._loop.call_soon(self._loop._add_reader,
|
||||
self._loop.call_soon(self._add_reader,
|
||||
self._fileno, self._read_ready)
|
||||
if waiter is not None:
|
||||
# only wake up the waiter when connection_made() has been called
|
||||
self._loop.call_soon(futures._set_result_unless_cancelled,
|
||||
waiter, None)
|
||||
|
||||
def _add_reader(self, fd, callback):
|
||||
if not self.is_reading():
|
||||
return
|
||||
self._loop._add_reader(fd, callback)
|
||||
|
||||
def is_reading(self):
|
||||
return not self._paused and not self._closing
|
||||
|
||||
def __repr__(self):
|
||||
info = [self.__class__.__name__]
|
||||
if self._pipe is None:
|
||||
|
@ -532,7 +540,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
|
|||
self._loop.call_soon(self._call_connection_lost, None)
|
||||
|
||||
def pause_reading(self):
|
||||
if self._closing or self._paused:
|
||||
if not self.is_reading():
|
||||
return
|
||||
self._paused = True
|
||||
self._loop._remove_reader(self._fileno)
|
||||
|
|
|
@ -447,6 +447,19 @@ class ProactorSocketTransportTests(test_utils.TestCase):
|
|||
|
||||
self.assertFalse(tr.is_reading())
|
||||
|
||||
def test_pause_reading_connection_made(self):
|
||||
tr = self.socket_transport()
|
||||
self.protocol.connection_made.side_effect = lambda _: tr.pause_reading()
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertFalse(tr.is_reading())
|
||||
self.loop.assert_no_reader(7)
|
||||
|
||||
tr.resume_reading()
|
||||
self.assertTrue(tr.is_reading())
|
||||
|
||||
tr.close()
|
||||
self.assertFalse(tr.is_reading())
|
||||
|
||||
|
||||
def pause_writing_transport(self, high):
|
||||
tr = self.socket_transport()
|
||||
|
|
|
@ -547,6 +547,22 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
|||
self.assertFalse(tr.is_reading())
|
||||
self.loop.assert_no_reader(7)
|
||||
|
||||
def test_pause_reading_connection_made(self):
|
||||
tr = self.socket_transport()
|
||||
self.protocol.connection_made.side_effect = lambda _: tr.pause_reading()
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertFalse(tr.is_reading())
|
||||
self.loop.assert_no_reader(7)
|
||||
|
||||
tr.resume_reading()
|
||||
self.assertTrue(tr.is_reading())
|
||||
self.loop.assert_reader(7, tr._read_ready)
|
||||
|
||||
tr.close()
|
||||
self.assertFalse(tr.is_reading())
|
||||
self.loop.assert_no_reader(7)
|
||||
|
||||
|
||||
def test_read_eof_received_error(self):
|
||||
transport = self.socket_transport()
|
||||
transport.close = mock.Mock()
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Fix :func:`!pause_reading` to work when called from :func:`!connection_made` in :mod:`asyncio`.
|
Loading…
Reference in New Issue