From 605c29303133987a5a97f3ceb4a369e9771676b0 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 23 Sep 2010 20:15:14 +0000 Subject: [PATCH] Further tarfile / test_tarfile cleanup --- Lib/tarfile.py | 57 +++++++++++++++++++++++----------------- Lib/test/test_tarfile.py | 3 ++- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/Lib/tarfile.py b/Lib/tarfile.py index bfdba58efe9..40109cd4a19 100644 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -410,28 +410,34 @@ class _Stream: self.pos = 0 self.closed = False - if comptype == "gz": - try: - import zlib - except ImportError: - raise CompressionError("zlib module is not available") - self.zlib = zlib - self.crc = zlib.crc32(b"") - if mode == "r": - self._init_read_gz() - else: - self._init_write_gz() + try: + if comptype == "gz": + try: + import zlib + except ImportError: + raise CompressionError("zlib module is not available") + self.zlib = zlib + self.crc = zlib.crc32(b"") + if mode == "r": + self._init_read_gz() + else: + self._init_write_gz() - if comptype == "bz2": - try: - import bz2 - except ImportError: - raise CompressionError("bz2 module is not available") - if mode == "r": - self.dbuf = b"" - self.cmp = bz2.BZ2Decompressor() - else: - self.cmp = bz2.BZ2Compressor() + if comptype == "bz2": + try: + import bz2 + except ImportError: + raise CompressionError("bz2 module is not available") + if mode == "r": + self.dbuf = b"" + self.cmp = bz2.BZ2Decompressor() + else: + self.cmp = bz2.BZ2Compressor() + except: + if not self._extfileobj: + self.fileobj.close() + self.closed = True + raise def __del__(self): if hasattr(self, "closed") and not self.closed: @@ -1729,9 +1735,12 @@ class TarFile(object): if filemode not in "rw": raise ValueError("mode must be 'r' or 'w'") - t = cls(name, filemode, - _Stream(name, filemode, comptype, fileobj, bufsize), - **kwargs) + stream = _Stream(name, filemode, comptype, fileobj, bufsize) + try: + t = cls(name, filemode, stream, **kwargs) + except: + stream.close() + raise t._extfileobj = False return t diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index a51b51253cd..3a217dc8150 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -423,7 +423,8 @@ class DetectReadTest(unittest.TestCase): def _testfunc_fileobj(self, name, mode): try: - tar = tarfile.open(name, mode, fileobj=open(name, "rb")) + with open(name, "rb") as f: + tar = tarfile.open(name, mode, fileobj=f) except tarfile.ReadError as e: self.fail() else: