mirror of https://github.com/python/cpython
gh-121210: handle nodes with missing attributes/fields in `ast.compare` (#121211)
This commit is contained in:
parent
7a807c3efa
commit
15232a0819
19
Lib/ast.py
19
Lib/ast.py
|
@ -422,6 +422,8 @@ def compare(
|
||||||
might differ in whitespace or similar details.
|
might differ in whitespace or similar details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
sentinel = object() # handle the possibility of a missing attribute/field
|
||||||
|
|
||||||
def _compare(a, b):
|
def _compare(a, b):
|
||||||
# Compare two fields on an AST object, which may themselves be
|
# Compare two fields on an AST object, which may themselves be
|
||||||
# AST objects, lists of AST objects, or primitive ASDL types
|
# AST objects, lists of AST objects, or primitive ASDL types
|
||||||
|
@ -449,8 +451,14 @@ def compare(
|
||||||
if a._fields != b._fields:
|
if a._fields != b._fields:
|
||||||
return False
|
return False
|
||||||
for field in a._fields:
|
for field in a._fields:
|
||||||
a_field = getattr(a, field)
|
a_field = getattr(a, field, sentinel)
|
||||||
b_field = getattr(b, field)
|
b_field = getattr(b, field, sentinel)
|
||||||
|
if a_field is sentinel and b_field is sentinel:
|
||||||
|
# both nodes are missing a field at runtime
|
||||||
|
continue
|
||||||
|
if a_field is sentinel or b_field is sentinel:
|
||||||
|
# one of the node is missing a field
|
||||||
|
return False
|
||||||
if not _compare(a_field, b_field):
|
if not _compare(a_field, b_field):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
|
@ -461,8 +469,11 @@ def compare(
|
||||||
return False
|
return False
|
||||||
# Attributes are always ints.
|
# Attributes are always ints.
|
||||||
for attr in a._attributes:
|
for attr in a._attributes:
|
||||||
a_attr = getattr(a, attr)
|
a_attr = getattr(a, attr, sentinel)
|
||||||
b_attr = getattr(b, attr)
|
b_attr = getattr(b, attr, sentinel)
|
||||||
|
if a_attr is sentinel and b_attr is sentinel:
|
||||||
|
# both nodes are missing an attribute at runtime
|
||||||
|
continue
|
||||||
if a_attr != b_attr:
|
if a_attr != b_attr:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -948,6 +948,15 @@ class AST_Tests(unittest.TestCase):
|
||||||
self.assertTrue(ast.compare(ast.Add(), ast.Add()))
|
self.assertTrue(ast.compare(ast.Add(), ast.Add()))
|
||||||
self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
|
self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
|
||||||
|
|
||||||
|
# test that missing runtime fields is handled in ast.compare()
|
||||||
|
a1, a2 = ast.Name('a'), ast.Name('a')
|
||||||
|
self.assertTrue(ast.compare(a1, a2))
|
||||||
|
self.assertTrue(ast.compare(a1, a2))
|
||||||
|
del a1.id
|
||||||
|
self.assertFalse(ast.compare(a1, a2))
|
||||||
|
del a2.id
|
||||||
|
self.assertTrue(ast.compare(a1, a2))
|
||||||
|
|
||||||
def test_compare_modes(self):
|
def test_compare_modes(self):
|
||||||
for mode, sources in (
|
for mode, sources in (
|
||||||
("exec", exec_tests),
|
("exec", exec_tests),
|
||||||
|
@ -970,6 +979,16 @@ class AST_Tests(unittest.TestCase):
|
||||||
self.assertTrue(ast.compare(a, b, compare_attributes=False))
|
self.assertTrue(ast.compare(a, b, compare_attributes=False))
|
||||||
self.assertFalse(ast.compare(a, b, compare_attributes=True))
|
self.assertFalse(ast.compare(a, b, compare_attributes=True))
|
||||||
|
|
||||||
|
def test_compare_attributes_option_missing_attribute(self):
|
||||||
|
# test that missing runtime attributes is handled in ast.compare()
|
||||||
|
a1, a2 = ast.Name('a', lineno=1), ast.Name('a', lineno=1)
|
||||||
|
self.assertTrue(ast.compare(a1, a2))
|
||||||
|
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
|
||||||
|
del a1.lineno
|
||||||
|
self.assertFalse(ast.compare(a1, a2, compare_attributes=True))
|
||||||
|
del a2.lineno
|
||||||
|
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
|
||||||
|
|
||||||
def test_positional_only_feature_version(self):
|
def test_positional_only_feature_version(self):
|
||||||
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
|
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
|
||||||
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
|
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
Handle AST nodes with missing runtime fields or attributes in
|
||||||
|
:func:`ast.compare`. Patch by Bénédikt Tran.
|
Loading…
Reference in New Issue