diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index df35abfcea0..da71fa83bcd 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -74,31 +74,51 @@ class TestShutil(unittest.TestCase): except: pass - def test_copytree_simple(self): + def write_data(path, data): + f = open(path, "w") + f.write(data) + f.close() + + def read_data(path): + f = open(path) + data = f.read() + f.close() + return data + src_dir = tempfile.mkdtemp() dst_dir = os.path.join(tempfile.mkdtemp(), 'destination') - open(os.path.join(src_dir, 'test.txt'), 'w').write('123') + + write_data(os.path.join(src_dir, 'test.txt'), '123') + os.mkdir(os.path.join(src_dir, 'test_dir')) - open(os.path.join(src_dir, 'test_dir', 'test.txt'), 'w').write('456') - # + write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456') + try: shutil.copytree(src_dir, dst_dir) self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test.txt'))) self.assertTrue(os.path.isdir(os.path.join(dst_dir, 'test_dir'))) - self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test_dir', 'test.txt'))) - self.assertEqual(open(os.path.join(dst_dir, 'test.txt')).read(), '123') - self.assertEqual(open(os.path.join(dst_dir, 'test_dir', 'test.txt')).read(), '456') + self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test_dir', + 'test.txt'))) + actual = read_data(os.path.join(dst_dir, 'test.txt')) + self.assertEqual(actual, '123') + actual = read_data(os.path.join(dst_dir, 'test_dir', 'test.txt')) + self.assertEqual(actual, '456') finally: - try: - os.remove(os.path.join(src_dir, 'test.txt')) - os.remove(os.path.join(dst_dir, 'test.txt')) - os.remove(os.path.join(src_dir, 'test_dir', 'test.txt')) - os.remove(os.path.join(dst_dir, 'test_dir', 'test.txt')) - os.removedirs(src_dir) - os.removedirs(dst_dir) - except: - pass + for path in ( + os.path.join(src_dir, 'test.txt'), + os.path.join(dst_dir, 'test.txt'), + os.path.join(src_dir, 'test_dir', 'test.txt'), + os.path.join(dst_dir, 'test_dir', 'test.txt'), + ): + if os.path.exists(path): + os.remove(path) + for path in ( + os.path.join(src_dir, 'test_dir'), + os.path.join(dst_dir, 'test_dir'), + ): + if os.path.exists(path): + os.removedirs(path) if hasattr(os, "symlink"):