diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py index 5b975aa7315..5f674f99d77 100644 --- a/Lib/asyncio/transports.py +++ b/Lib/asyncio/transports.py @@ -241,7 +241,7 @@ class _FlowControlMixin(Transport): def __init__(self, extra=None): super().__init__(extra) self._protocol_paused = False - self.set_write_buffer_limits() + self._set_write_buffer_limits() def _maybe_pause_protocol(self): size = self.get_write_buffer_size() @@ -273,7 +273,7 @@ class _FlowControlMixin(Transport): 'protocol': self._protocol, }) - def set_write_buffer_limits(self, high=None, low=None): + def _set_write_buffer_limits(self, high=None, low=None): if high is None: if low is None: high = 64*1024 @@ -287,5 +287,9 @@ class _FlowControlMixin(Transport): self._high_water = high self._low_water = low + def set_write_buffer_limits(self, high=None, low=None): + self._set_write_buffer_limits(high=high, low=low) + self._maybe_pause_protocol() + def get_write_buffer_size(self): raise NotImplementedError diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py index d16db8074b6..4c645268d94 100644 --- a/Lib/test/test_asyncio/test_transports.py +++ b/Lib/test/test_asyncio/test_transports.py @@ -4,6 +4,7 @@ import unittest import unittest.mock import asyncio +from asyncio import transports class TransportTests(unittest.TestCase): @@ -60,6 +61,28 @@ class TransportTests(unittest.TestCase): self.assertRaises(NotImplementedError, transport.terminate) self.assertRaises(NotImplementedError, transport.kill) + def test_flowcontrol_mixin_set_write_limits(self): + + class MyTransport(transports._FlowControlMixin, + transports.Transport): + + def get_write_buffer_size(self): + return 512 + + transport = MyTransport() + transport._protocol = unittest.mock.Mock() + + self.assertFalse(transport._protocol_paused) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + transport.set_write_buffer_limits(high=0, low=1) + + transport.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(transport._protocol_paused) + + transport.set_write_buffer_limits(high=256, low=128) + self.assertTrue(transport._protocol_paused) + if __name__ == '__main__': unittest.main()