asyncio: WriteTransport.set_write_buffer_size to call _maybe_pause_protocol

This commit is contained in:
Yury Selivanov 2014-02-19 11:10:52 -05:00
parent 03e9cb2b0b
commit 2d01c0a080
2 changed files with 29 additions and 2 deletions

View File

@ -241,7 +241,7 @@ class _FlowControlMixin(Transport):
def __init__(self, extra=None):
super().__init__(extra)
self._protocol_paused = False
self.set_write_buffer_limits()
self._set_write_buffer_limits()
def _maybe_pause_protocol(self):
size = self.get_write_buffer_size()
@ -273,7 +273,7 @@ class _FlowControlMixin(Transport):
'protocol': self._protocol,
})
def set_write_buffer_limits(self, high=None, low=None):
def _set_write_buffer_limits(self, high=None, low=None):
if high is None:
if low is None:
high = 64*1024
@ -287,5 +287,9 @@ class _FlowControlMixin(Transport):
self._high_water = high
self._low_water = low
def set_write_buffer_limits(self, high=None, low=None):
self._set_write_buffer_limits(high=high, low=low)
self._maybe_pause_protocol()
def get_write_buffer_size(self):
raise NotImplementedError

View File

@ -4,6 +4,7 @@ import unittest
import unittest.mock
import asyncio
from asyncio import transports
class TransportTests(unittest.TestCase):
@ -60,6 +61,28 @@ class TransportTests(unittest.TestCase):
self.assertRaises(NotImplementedError, transport.terminate)
self.assertRaises(NotImplementedError, transport.kill)
def test_flowcontrol_mixin_set_write_limits(self):
class MyTransport(transports._FlowControlMixin,
transports.Transport):
def get_write_buffer_size(self):
return 512
transport = MyTransport()
transport._protocol = unittest.mock.Mock()
self.assertFalse(transport._protocol_paused)
with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
transport.set_write_buffer_limits(high=0, low=1)
transport.set_write_buffer_limits(high=1024, low=128)
self.assertFalse(transport._protocol_paused)
transport.set_write_buffer_limits(high=256, low=128)
self.assertTrue(transport._protocol_paused)
if __name__ == '__main__':
unittest.main()