gh-121210: handle nodes with missing attributes/fields in `ast.compare` (#121211)

This commit is contained in:
Bénédikt Tran 2024-07-02 12:53:17 +02:00 committed by GitHub
parent 7a807c3efa
commit 15232a0819
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 4 deletions

View File

@ -422,6 +422,8 @@ def compare(
might differ in whitespace or similar details. might differ in whitespace or similar details.
""" """
sentinel = object() # handle the possibility of a missing attribute/field
def _compare(a, b): def _compare(a, b):
# Compare two fields on an AST object, which may themselves be # Compare two fields on an AST object, which may themselves be
# AST objects, lists of AST objects, or primitive ASDL types # AST objects, lists of AST objects, or primitive ASDL types
@ -449,8 +451,14 @@ def compare(
if a._fields != b._fields: if a._fields != b._fields:
return False return False
for field in a._fields: for field in a._fields:
a_field = getattr(a, field) a_field = getattr(a, field, sentinel)
b_field = getattr(b, field) b_field = getattr(b, field, sentinel)
if a_field is sentinel and b_field is sentinel:
# both nodes are missing a field at runtime
continue
if a_field is sentinel or b_field is sentinel:
# one of the node is missing a field
return False
if not _compare(a_field, b_field): if not _compare(a_field, b_field):
return False return False
else: else:
@ -461,8 +469,11 @@ def compare(
return False return False
# Attributes are always ints. # Attributes are always ints.
for attr in a._attributes: for attr in a._attributes:
a_attr = getattr(a, attr) a_attr = getattr(a, attr, sentinel)
b_attr = getattr(b, attr) b_attr = getattr(b, attr, sentinel)
if a_attr is sentinel and b_attr is sentinel:
# both nodes are missing an attribute at runtime
continue
if a_attr != b_attr: if a_attr != b_attr:
return False return False
else: else:

View File

@ -948,6 +948,15 @@ class AST_Tests(unittest.TestCase):
self.assertTrue(ast.compare(ast.Add(), ast.Add())) self.assertTrue(ast.compare(ast.Add(), ast.Add()))
self.assertFalse(ast.compare(ast.Sub(), ast.Add())) self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
# test that missing runtime fields is handled in ast.compare()
a1, a2 = ast.Name('a'), ast.Name('a')
self.assertTrue(ast.compare(a1, a2))
self.assertTrue(ast.compare(a1, a2))
del a1.id
self.assertFalse(ast.compare(a1, a2))
del a2.id
self.assertTrue(ast.compare(a1, a2))
def test_compare_modes(self): def test_compare_modes(self):
for mode, sources in ( for mode, sources in (
("exec", exec_tests), ("exec", exec_tests),
@ -970,6 +979,16 @@ class AST_Tests(unittest.TestCase):
self.assertTrue(ast.compare(a, b, compare_attributes=False)) self.assertTrue(ast.compare(a, b, compare_attributes=False))
self.assertFalse(ast.compare(a, b, compare_attributes=True)) self.assertFalse(ast.compare(a, b, compare_attributes=True))
def test_compare_attributes_option_missing_attribute(self):
# test that missing runtime attributes is handled in ast.compare()
a1, a2 = ast.Name('a', lineno=1), ast.Name('a', lineno=1)
self.assertTrue(ast.compare(a1, a2))
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
del a1.lineno
self.assertFalse(ast.compare(a1, a2, compare_attributes=True))
del a2.lineno
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
def test_positional_only_feature_version(self): def test_positional_only_feature_version(self):
ast.parse('def foo(x, /): ...', feature_version=(3, 8)) ast.parse('def foo(x, /): ...', feature_version=(3, 8))
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8)) ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))

View File

@ -0,0 +1,2 @@
Handle AST nodes with missing runtime fields or attributes in
:func:`ast.compare`. Patch by Bénédikt Tran.