Issue #3476: make BufferedReader and BufferedWriter thread-safe

This commit is contained in:
Antoine Pitrou 2008-08-14 21:04:30 +00:00
parent 63d325e8c4
commit 11ec65d82b
3 changed files with 144 additions and 40 deletions

102
Lib/io.py
View File

@@ -63,6 +63,7 @@ import sys
import codecs import codecs
import _fileio import _fileio
import warnings import warnings
import threading
# open() uses st_blksize whenever we can # open() uses st_blksize whenever we can
DEFAULT_BUFFER_SIZE = 8 * 1024 # bytes DEFAULT_BUFFER_SIZE = 8 * 1024 # bytes
@@ -908,6 +909,7 @@ class BufferedReader(_BufferedIOMixin):
_BufferedIOMixin.__init__(self, raw) _BufferedIOMixin.__init__(self, raw)
self.buffer_size = buffer_size self.buffer_size = buffer_size
self._reset_read_buf() self._reset_read_buf()
self._read_lock = threading.Lock()
def _reset_read_buf(self): def _reset_read_buf(self):
self._read_buf = b"" self._read_buf = b""
@@ -921,6 +923,10 @@ class BufferedReader(_BufferedIOMixin):
mode. If n is negative, read until EOF or until read() would mode. If n is negative, read until EOF or until read() would
block. block.
""" """
with self._read_lock:
return self._read_unlocked(n)
def _read_unlocked(self, n=None):
nodata_val = b"" nodata_val = b""
empty_values = (b"", None) empty_values = (b"", None)
buf = self._read_buf buf = self._read_buf
@@ -973,6 +979,10 @@ class BufferedReader(_BufferedIOMixin):
do at most one raw read to satisfy it. We never return more do at most one raw read to satisfy it. We never return more
than self.buffer_size. than self.buffer_size.
""" """
with self._read_lock:
return self._peek_unlocked(n)
def _peek_unlocked(self, n=0):
want = min(n, self.buffer_size) want = min(n, self.buffer_size)
have = len(self._read_buf) - self._read_pos have = len(self._read_buf) - self._read_pos
if have < want: if have < want:
@@ -989,18 +999,21 @@ class BufferedReader(_BufferedIOMixin):
# only return buffered bytes. Otherwise, we do one raw read. # only return buffered bytes. Otherwise, we do one raw read.
if n <= 0: if n <= 0:
return b"" return b""
self.peek(1) with self._read_lock:
return self.read(min(n, len(self._read_buf) - self._read_pos)) self._peek_unlocked(1)
return self._read_unlocked(
min(n, len(self._read_buf) - self._read_pos))
def tell(self): def tell(self):
return self.raw.tell() - len(self._read_buf) + self._read_pos return self.raw.tell() - len(self._read_buf) + self._read_pos
def seek(self, pos, whence=0): def seek(self, pos, whence=0):
if whence == 1: with self._read_lock:
pos -= len(self._read_buf) - self._read_pos if whence == 1:
pos = self.raw.seek(pos, whence) pos -= len(self._read_buf) - self._read_pos
self._reset_read_buf() pos = self.raw.seek(pos, whence)
return pos self._reset_read_buf()
return pos
class BufferedWriter(_BufferedIOMixin): class BufferedWriter(_BufferedIOMixin):
@@ -1022,43 +1035,51 @@ class BufferedWriter(_BufferedIOMixin):
if max_buffer_size is None if max_buffer_size is None
else max_buffer_size) else max_buffer_size)
self._write_buf = bytearray() self._write_buf = bytearray()
self._write_lock = threading.Lock()
def write(self, b): def write(self, b):
if self.closed: if self.closed:
raise ValueError("write to closed file") raise ValueError("write to closed file")
if isinstance(b, unicode): if isinstance(b, unicode):
raise TypeError("can't write unicode to binary stream") raise TypeError("can't write unicode to binary stream")
# XXX we can implement some more tricks to try and avoid partial writes with self._write_lock:
if len(self._write_buf) > self.buffer_size: # XXX we can implement some more tricks to try and avoid
# We're full, so let's pre-flush the buffer # partial writes
try: if len(self._write_buf) > self.buffer_size:
self.flush() # We're full, so let's pre-flush the buffer
except BlockingIOError as e: try:
# We can't accept anything else. self._flush_unlocked()
# XXX Why not just let the exception pass through? except BlockingIOError as e:
raise BlockingIOError(e.errno, e.strerror, 0) # We can't accept anything else.
before = len(self._write_buf) # XXX Why not just let the exception pass through?
self._write_buf.extend(b) raise BlockingIOError(e.errno, e.strerror, 0)
written = len(self._write_buf) - before before = len(self._write_buf)
if len(self._write_buf) > self.buffer_size: self._write_buf.extend(b)
try: written = len(self._write_buf) - before
self.flush() if len(self._write_buf) > self.buffer_size:
except BlockingIOError as e: try:
if (len(self._write_buf) > self.max_buffer_size): self._flush_unlocked()
# We've hit max_buffer_size. We have to accept a partial except BlockingIOError as e:
# write and cut back our buffer. if len(self._write_buf) > self.max_buffer_size:
overage = len(self._write_buf) - self.max_buffer_size # We've hit max_buffer_size. We have to accept a
self._write_buf = self._write_buf[:self.max_buffer_size] # partial write and cut back our buffer.
raise BlockingIOError(e.errno, e.strerror, overage) overage = len(self._write_buf) - self.max_buffer_size
return written self._write_buf = self._write_buf[:self.max_buffer_size]
raise BlockingIOError(e.errno, e.strerror, overage)
return written
def truncate(self, pos=None): def truncate(self, pos=None):
self.flush() with self._write_lock:
if pos is None: self._flush_unlocked()
pos = self.raw.tell() if pos is None:
return self.raw.truncate(pos) pos = self.raw.tell()
return self.raw.truncate(pos)
def flush(self): def flush(self):
with self._write_lock:
self._flush_unlocked()
def _flush_unlocked(self):
if self.closed: if self.closed:
raise ValueError("flush of closed file") raise ValueError("flush of closed file")
written = 0 written = 0
@@ -1077,8 +1098,9 @@ class BufferedWriter(_BufferedIOMixin):
return self.raw.tell() + len(self._write_buf) return self.raw.tell() + len(self._write_buf)
def seek(self, pos, whence=0): def seek(self, pos, whence=0):
self.flush() with self._write_lock:
return self.raw.seek(pos, whence) self._flush_unlocked()
return self.raw.seek(pos, whence)
class BufferedRWPair(BufferedIOBase): class BufferedRWPair(BufferedIOBase):
@@ -1168,7 +1190,8 @@ class BufferedRandom(BufferedWriter, BufferedReader):
# First do the raw seek, then empty the read buffer, so that # First do the raw seek, then empty the read buffer, so that
# if the raw seek fails, we don't lose buffered data forever. # if the raw seek fails, we don't lose buffered data forever.
pos = self.raw.seek(pos, whence) pos = self.raw.seek(pos, whence)
self._reset_read_buf() with self._read_lock:
self._reset_read_buf()
return pos return pos
def tell(self): def tell(self):
@@ -1205,8 +1228,9 @@ class BufferedRandom(BufferedWriter, BufferedReader):
def write(self, b): def write(self, b):
if self._read_buf: if self._read_buf:
# Undo readahead # Undo readahead
self.raw.seek(self._read_pos - len(self._read_buf), 1) with self._read_lock:
self._reset_read_buf() self.raw.seek(self._read_pos - len(self._read_buf), 1)
self._reset_read_buf()
return BufferedWriter.write(self, b) return BufferedWriter.write(self, b)

View File

@@ -6,8 +6,10 @@ import os
import sys import sys
import time import time
import array import array
import threading
import random
import unittest import unittest
from itertools import chain from itertools import chain, cycle
from test import test_support from test import test_support
import codecs import codecs
@@ -390,6 +392,49 @@ class BufferedReaderTest(unittest.TestCase):
# this test. Else, write it. # this test. Else, write it.
pass pass
def testThreads(self):
try:
# Write out many bytes with exactly the same number of 0's,
# 1's... 255's. This will help us check that concurrent reading
# doesn't duplicate or forget contents.
N = 1000
l = range(256) * N
random.shuffle(l)
s = bytes(bytearray(l))
with io.open(test_support.TESTFN, "wb") as f:
f.write(s)
with io.open(test_support.TESTFN, "rb", buffering=0) as raw:
bufio = io.BufferedReader(raw, 8)
errors = []
results = []
def f():
try:
# Intra-buffer read then buffer-flushing read
for n in cycle([1, 19]):
s = bufio.read(n)
if not s:
break
# list.append() is atomic
results.append(s)
except Exception as e:
errors.append(e)
raise
threads = [threading.Thread(target=f) for x in range(20)]
for t in threads:
t.start()
time.sleep(0.02) # yield
for t in threads:
t.join()
self.assertFalse(errors,
"the following exceptions were caught: %r" % errors)
s = b''.join(results)
for i in range(256):
c = bytes(bytearray([i]))
self.assertEqual(s.count(c), N)
finally:
test_support.unlink(test_support.TESTFN)
class BufferedWriterTest(unittest.TestCase): class BufferedWriterTest(unittest.TestCase):
@@ -446,6 +491,38 @@ class BufferedWriterTest(unittest.TestCase):
self.assertEquals(b"abc", writer._write_stack[0]) self.assertEquals(b"abc", writer._write_stack[0])
def testThreads(self):
# BufferedWriter should not raise exceptions or crash
# when called from multiple threads.
try:
# We use a real file object because it allows us to
# exercise situations where the GIL is released before
# writing the buffer to the raw streams. This is in addition
# to concurrency issues due to switching threads in the middle
# of Python code.
with io.open(test_support.TESTFN, "wb", buffering=0) as raw:
bufio = io.BufferedWriter(raw, 8)
errors = []
def f():
try:
# Write enough bytes to flush the buffer
s = b"a" * 19
for i in range(50):
bufio.write(s)
except Exception as e:
errors.append(e)
raise
threads = [threading.Thread(target=f) for x in range(20)]
for t in threads:
t.start()
time.sleep(0.02) # yield
for t in threads:
t.join()
self.assertFalse(errors,
"the following exceptions were caught: %r" % errors)
finally:
test_support.unlink(test_support.TESTFN)
class BufferedRWPairTest(unittest.TestCase): class BufferedRWPairTest(unittest.TestCase):

View File

@@ -48,6 +48,9 @@ Core and Builtins
Library Library
------- -------
- Issue #3476: binary buffered reading and writing through the new "io"
library are now thread-safe.
- Silence the DeprecationWarning of rfc822 when it is imported by mimetools - Silence the DeprecationWarning of rfc822 when it is imported by mimetools
since mimetools itself is deprecated. Because modules are cached, all since mimetools itself is deprecated. Because modules are cached, all
subsequent imports of rfc822 will not raise a visible DeprecationWarning. subsequent imports of rfc822 will not raise a visible DeprecationWarning.