diff --git a/Doc/whatsnew/3.4.rst b/Doc/whatsnew/3.4.rst index c626abd3cd3..94fd8b4f85d 100644 --- a/Doc/whatsnew/3.4.rst +++ b/Doc/whatsnew/3.4.rst @@ -374,6 +374,8 @@ sunau The :meth:`~sunau.getparams` method now returns a namedtuple rather than a plain tuple. (Contributed by Claudiu Popa in :issue:`18901`.) +:meth:`sunau.open` now supports the context manager protocol (:issue:`18878`). + urllib ------ diff --git a/Lib/sunau.py b/Lib/sunau.py index 010ce2313b3..efdc146095a 100644 --- a/Lib/sunau.py +++ b/Lib/sunau.py @@ -168,6 +168,12 @@ class Au_read: if self._file: self.close() + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + def initfp(self, file): self._file = file self._soundpos = 0 @@ -303,6 +309,12 @@ class Au_write: self.close() self._file = None + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + def initfp(self, file): self._file = file self._framerate = 0 @@ -410,14 +422,17 @@ class Au_write: self._patchheader() def close(self): - self._ensure_header_written() - if self._nframeswritten != self._nframes or \ - self._datalength != self._datawritten: - self._patchheader() - self._file.flush() - if self._opened and self._file: - self._file.close() - self._file = None + if self._file: + try: + self._ensure_header_written() + if self._nframeswritten != self._nframes or \ + self._datalength != self._datawritten: + self._patchheader() + self._file.flush() + finally: + if self._opened and self._file: + self._file.close() + self._file = None # # private methods diff --git a/Lib/test/test_sunau.py b/Lib/test/test_sunau.py index c381d07a31b..1b0baeeb647 100644 --- a/Lib/test/test_sunau.py +++ b/Lib/test/test_sunau.py @@ -1,4 +1,4 @@ -from test.support import run_unittest, TESTFN +from test.support import TESTFN, unlink import unittest import pickle import os @@ -18,10 +18,7 @@ class SunAUTest(unittest.TestCase): def tearDown(self): if self.f is not None: self.f.close() - try: - os.remove(TESTFN) - except OSError: - pass + unlink(TESTFN) def test_lin(self): self.f = sunau.open(TESTFN, 'w') @@ -84,9 +81,49 @@ class SunAUTest(unittest.TestCase): dump = pickle.dumps(params) self.assertEqual(pickle.loads(dump), params) + def test_write_context_manager_calls_close(self): + # Close checks for a minimum header and will raise an error + # if it is not set, so this proves that close is called. + with self.assertRaises(sunau.Error): + with sunau.open(TESTFN, 'wb') as f: + pass + with self.assertRaises(sunau.Error): + with open(TESTFN, 'wb') as testfile: + with sunau.open(testfile): + pass + + def test_context_manager_with_open_file(self): + with open(TESTFN, 'wb') as testfile: + with sunau.open(testfile) as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + self.assertFalse(testfile.closed) + with open(TESTFN, 'rb') as testfile: + with sunau.open(testfile) as f: + self.assertFalse(f.getfp().closed) + params = f.getparams() + self.assertEqual(params[0], nchannels) + self.assertEqual(params[1], sampwidth) + self.assertEqual(params[2], framerate) + self.assertIsNone(f.getfp()) + self.assertFalse(testfile.closed) + + def test_context_manager_with_filename(self): + # If the file doesn't get closed, this test won't fail, but it will + # produce a resource leak warning. + with sunau.open(TESTFN, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + with sunau.open(TESTFN) as f: + self.assertFalse(f.getfp().closed) + params = f.getparams() + self.assertEqual(params[0], nchannels) + self.assertEqual(params[1], sampwidth) + self.assertEqual(params[2], framerate) + self.assertIsNone(f.getfp()) -def test_main(): - run_unittest(SunAUTest) if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS b/Misc/NEWS index d421cbd56a0..5d6489f8871 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -54,6 +54,9 @@ Core and Builtins Library ------- +- Issue #18878: sunau.open now supports the context manager protocol. Based on + patches by Claudiu Popa and R. David Murray. + - Issue #18909: Fix _tkinter.tkapp.interpaddr() on Windows 64-bit, don't cast 64-bit pointer to long (32 bits).