GH-73991: Use same signature for `shutil._rmtree_[un]safe()`. (#120517)

Preparatory work for moving `_rmtree_unsafe()` and `_rmtree_safe_fd()` to
`pathlib._os` so that they can be used from both `shutil` and `pathlib`.

Move implementation-specific setup from `rmtree()` into the safe/unsafe
functions, and give them the same signature `(path, dir_fd, onexc)`.

In the tests, mock `os.open` rather than `_rmtree_safe_fd()` to ensure the
FD-based walk is used, and replace a couple references to
`shutil._use_fd_functions` with `shutil.rmtree.avoids_symlink_attacks`
(which has the same value).

No change of behaviour.
This commit is contained in:
Barney Gale 2024-06-18 22:15:18 +01:00 committed by GitHub
parent 49f51deeef
commit 69058e20e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 45 deletions

View File

@ -605,7 +605,22 @@ else:
return stat.S_ISLNK(st.st_mode)
# version vulnerable to race conditions
def _rmtree_unsafe(path, onexc):
def _rmtree_unsafe(path, dir_fd, onexc):
if dir_fd is not None:
raise NotImplementedError("dir_fd unavailable on this platform")
try:
st = os.lstat(path)
except OSError as err:
onexc(os.lstat, path, err)
return
try:
if _rmtree_islink(st):
# symlinks to directories are forbidden, see bug #1669
raise OSError("Cannot call rmtree on a symbolic link")
except OSError as err:
onexc(os.path.islink, path, err)
# can't continue even if onexc hook returns
return
def onerror(err):
if not isinstance(err, FileNotFoundError):
onexc(os.scandir, err.filename, err)
@ -635,7 +650,26 @@ def _rmtree_unsafe(path, onexc):
onexc(os.rmdir, path, err)
# Version using fd-based APIs to protect against races
def _rmtree_safe_fd(stack, onexc):
def _rmtree_safe_fd(path, dir_fd, onexc):
# While the unsafe rmtree works fine on bytes, the fd based does not.
if isinstance(path, bytes):
path = os.fsdecode(path)
stack = [(os.lstat, dir_fd, path, None)]
try:
while stack:
_rmtree_safe_fd_step(stack, onexc)
finally:
# Close any file descriptors still on the stack.
while stack:
func, fd, path, entry = stack.pop()
if func is not os.close:
continue
try:
os.close(fd)
except OSError as err:
onexc(os.close, path, err)
def _rmtree_safe_fd_step(stack, onexc):
# Each stack item has four elements:
# * func: The first operation to perform: os.lstat, os.close or os.rmdir.
# Walking a directory starts with an os.lstat() to detect symlinks; in
@ -710,6 +744,7 @@ _use_fd_functions = ({os.open, os.stat, os.unlink, os.rmdir} <=
os.supports_dir_fd and
os.scandir in os.supports_fd and
os.stat in os.supports_follow_symlinks)
_rmtree_impl = _rmtree_safe_fd if _use_fd_functions else _rmtree_unsafe
def rmtree(path, ignore_errors=False, onerror=None, *, onexc=None, dir_fd=None):
"""Recursively delete a directory tree.
@ -753,41 +788,7 @@ def rmtree(path, ignore_errors=False, onerror=None, *, onexc=None, dir_fd=None):
exc_info = type(exc), exc, exc.__traceback__
return onerror(func, path, exc_info)
if _use_fd_functions:
# While the unsafe rmtree works fine on bytes, the fd based does not.
if isinstance(path, bytes):
path = os.fsdecode(path)
stack = [(os.lstat, dir_fd, path, None)]
try:
while stack:
_rmtree_safe_fd(stack, onexc)
finally:
# Close any file descriptors still on the stack.
while stack:
func, fd, path, entry = stack.pop()
if func is not os.close:
continue
try:
os.close(fd)
except OSError as err:
onexc(os.close, path, err)
else:
if dir_fd is not None:
raise NotImplementedError("dir_fd unavailable on this platform")
try:
st = os.lstat(path)
except OSError as err:
onexc(os.lstat, path, err)
return
try:
if _rmtree_islink(st):
# symlinks to directories are forbidden, see bug #1669
raise OSError("Cannot call rmtree on a symbolic link")
except OSError as err:
onexc(os.path.islink, path, err)
# can't continue even if onexc hook returns
return
return _rmtree_unsafe(path, onexc)
_rmtree_impl(path, dir_fd, onexc)
# Allow introspection of whether or not the hardening against symlink
# attacks is supported on the current platform

View File

@ -558,25 +558,23 @@ class TestRmTree(BaseTest, unittest.TestCase):
os.listdir in os.supports_fd and
os.stat in os.supports_follow_symlinks)
if _use_fd_functions:
self.assertTrue(shutil._use_fd_functions)
self.assertTrue(shutil.rmtree.avoids_symlink_attacks)
tmp_dir = self.mkdtemp()
d = os.path.join(tmp_dir, 'a')
os.mkdir(d)
try:
real_rmtree = shutil._rmtree_safe_fd
real_open = os.open
class Called(Exception): pass
def _raiser(*args, **kwargs):
raise Called
shutil._rmtree_safe_fd = _raiser
os.open = _raiser
self.assertRaises(Called, shutil.rmtree, d)
finally:
shutil._rmtree_safe_fd = real_rmtree
os.open = real_open
else:
self.assertFalse(shutil._use_fd_functions)
self.assertFalse(shutil.rmtree.avoids_symlink_attacks)
@unittest.skipUnless(shutil._use_fd_functions, "requires safe rmtree")
@unittest.skipUnless(shutil.rmtree.avoids_symlink_attacks, "requires safe rmtree")
def test_rmtree_fails_on_close(self):
# Test that the error handler is called for failed os.close() and that
# os.close() is only called once for a file descriptor.
@ -611,7 +609,7 @@ class TestRmTree(BaseTest, unittest.TestCase):
self.assertEqual(errors[1][1], dir1)
self.assertEqual(close_count, 2)
@unittest.skipUnless(shutil._use_fd_functions, "dir_fd is not supported")
@unittest.skipUnless(shutil.rmtree.avoids_symlink_attacks, "dir_fd is not supported")
def test_rmtree_with_dir_fd(self):
tmp_dir = self.mkdtemp()
victim = 'killme'
@ -625,7 +623,7 @@ class TestRmTree(BaseTest, unittest.TestCase):
shutil.rmtree(victim, dir_fd=dir_fd)
self.assertFalse(os.path.exists(fullname))
@unittest.skipIf(shutil._use_fd_functions, "dir_fd is supported")
@unittest.skipIf(shutil.rmtree.avoids_symlink_attacks, "dir_fd is supported")
def test_rmtree_with_dir_fd_unsupported(self):
tmp_dir = self.mkdtemp()
with self.assertRaises(NotImplementedError):