# cpython/Lib/test/test_ast/test_ast.py
#
# 3337 lines
# 128 KiB
# Python

import ast
import builtins
import copy
import dis
import enum
import os
import re
import sys
import textwrap
import types
import unittest
import weakref
from textwrap import dedent
try:
import _testinternalcapi
except ImportError:
_testinternalcapi = None
from test import support
from test.support import os_helper, script_helper
from test.support.ast_helper import ASTTestMixin
from test.test_ast.utils import to_tuple
from test.test_ast.snippets import (
eval_tests, eval_results, exec_tests, exec_results, single_tests, single_results
)
# Directory that holds the standard library, located through the ast module.
STDLIB = os.path.dirname(ast.__file__)
# Every top-level ``.py`` file found in STDLIB, followed by two test modules
# given as paths relative to STDLIB.
STDLIB_FILES = [
    *(entry for entry in os.listdir(STDLIB) if entry.endswith(".py")),
    "test/test_grammar.py",
    "test/test_unpack_ex.py",
]
# Tests for ast node classes, parsing options, position information and
# the validation that compile() performs on AST input.
class AST_Tests(unittest.TestCase):
# Do not truncate the diff unittest prints when an equality assertion fails.
maxDiff = None
def _is_ast_node(self, name, node):
if not isinstance(node, type):
return False
if "ast" not in node.__module__:
return False
return name != 'AST' and name[0].isupper()
def _assertTrueorder(self, ast_node, parent_pos):
if not isinstance(ast_node, ast.AST) or ast_node._fields is None:
return
if isinstance(ast_node, (ast.expr, ast.stmt, ast.excepthandler)):
node_pos = (ast_node.lineno, ast_node.col_offset)
self.assertGreaterEqual(node_pos, parent_pos)
parent_pos = (ast_node.lineno, ast_node.col_offset)
for name in ast_node._fields:
value = getattr(ast_node, name)
if isinstance(value, list):
first_pos = parent_pos
if value and name == 'decorator_list':
first_pos = (value[0].lineno, value[0].col_offset)
for child in value:
self._assertTrueorder(child, first_pos)
elif value is not None:
self._assertTrueorder(value, parent_pos)
self.assertEqual(ast_node._fields, ast_node.__match_args__)
def test_AST_objects(self):
x = ast.AST()
self.assertEqual(x._fields, ())
x.foobar = 42
self.assertEqual(x.foobar, 42)
self.assertEqual(x.__dict__["foobar"], 42)
with self.assertRaises(AttributeError):
x.vararg
with self.assertRaises(TypeError):
# "ast.AST constructor takes 0 positional arguments"
ast.AST(2)
def test_AST_garbage_collection(self):
    """A reference cycle that runs through an AST node can be collected."""
    class Payload:
        pass

    node = ast.AST()
    payload = Payload()
    node.x = payload
    payload.a = node  # close the cycle: node -> payload -> node
    probe = weakref.ref(payload)
    del node, payload
    support.gc_collect()
    self.assertIsNone(probe())
def test_snippets(self):
    """Parse each snippet, compare it with its recorded tuple form, then compile it."""
    suites = (
        (exec_tests, exec_results, "exec"),
        (single_tests, single_results, "single"),
        (eval_tests, eval_results, "eval"),
    )
    for sources, expected_trees, kind in suites:
        for source, expected in zip(sources, expected_trees):
            with self.subTest(action="parsing", input=source):
                ast_tree = compile(source, "?", kind, ast.PyCF_ONLY_AST)
                self.assertEqual(to_tuple(ast_tree), expected)
                self._assertTrueorder(ast_tree, (0, 0))
            with self.subTest(action="compiling", input=source, kind=kind):
                compile(ast_tree, "?", kind)
def test_ast_validation(self):
    """Every snippet, parsed as a module, passes compile()'s AST validation."""
    # compile() is the only function that calls PyAST_Validate
    for snippet in [*exec_tests, *single_tests, *eval_tests]:
        compile(ast.parse(snippet), '<string>', 'exec')
def test_optimization_levels__debug__(self):
cases = [(-1, '__debug__'), (0, '__debug__'), (1, False), (2, False)]
for (optval, expected) in cases:
with self.subTest(optval=optval, expected=expected):
res1 = ast.parse("__debug__", optimize=optval)
res2 = ast.parse(ast.parse("__debug__"), optimize=optval)
for res in [res1, res2]:
self.assertIsInstance(res.body[0], ast.Expr)
if isinstance(expected, bool):
self.assertIsInstance(res.body[0].value, ast.Constant)
self.assertEqual(res.body[0].value.value, expected)
else:
self.assertIsInstance(res.body[0].value, ast.Name)
self.assertEqual(res.body[0].value.id, expected)
def test_optimization_levels_const_folding(self):
    """``1 + 2`` is kept as a BinOp for optimize -1/0 and folded to 3 for 1/2."""
    span = (1, 0, 1, 5)
    folded = ('Expr', span, ('Constant', span, 3, None))
    not_folded = (
        'Expr', span,
        ('BinOp', span,
         ('Constant', (1, 0, 1, 1), 1, None),
         ('Add',),
         ('Constant', (1, 4, 1, 5), 2, None)),
    )

    expectations = {-1: not_folded, 0: not_folded, 1: folded, 2: folded}
    for optval, expected in expectations.items():
        with self.subTest(optval=optval):
            from_source = ast.parse("1 + 2", optimize=optval)
            from_tree = ast.parse(ast.parse("1 + 2"), optimize=optval)
            for tree in (from_source, from_tree):
                self.assertEqual(to_tuple(tree.body[0]), expected)
def test_invalid_position_information(self):
invalid_linenos = [
(10, 1), (-10, -11), (10, -11), (-5, -2), (-5, 1)
]
for lineno, end_lineno in invalid_linenos:
with self.subTest(f"Check invalid linenos {lineno}:{end_lineno}"):
snippet = "a = 1"
tree = ast.parse(snippet)
tree.body[0].lineno = lineno
tree.body[0].end_lineno = end_lineno
with self.assertRaises(ValueError):
compile(tree, '<string>', 'exec')
invalid_col_offsets = [
(10, 1), (-10, -11), (10, -11), (-5, -2), (-5, 1)
]
for col_offset, end_col_offset in invalid_col_offsets:
with self.subTest(f"Check invalid col_offset {col_offset}:{end_col_offset}"):
snippet = "a = 1"
tree = ast.parse(snippet)
tree.body[0].col_offset = col_offset
tree.body[0].end_col_offset = end_col_offset
with self.assertRaises(ValueError):
compile(tree, '<string>', 'exec')
def test_compilation_of_ast_nodes_with_default_end_position_values(self):
tree = ast.Module(body=[
ast.Import(names=[ast.alias(name='builtins', lineno=1, col_offset=0)], lineno=1, col_offset=0),
ast.Import(names=[ast.alias(name='traceback', lineno=0, col_offset=0)], lineno=0, col_offset=1)
], type_ignores=[])
# Check that compilation doesn't crash. Note: this may crash explicitly only on debug mode.
compile(tree, "<string>", "exec")
def test_slice(self):
slc = ast.parse("x[::]").body[0].value.slice
self.assertIsNone(slc.upper)
self.assertIsNone(slc.lower)
self.assertIsNone(slc.step)
def test_from_import(self):
im = ast.parse("from . import y").body[0]
self.assertIsNone(im.module)
def test_non_interned_future_from_ast(self):
mod = ast.parse("from __future__ import division")
self.assertIsInstance(mod.body[0], ast.ImportFrom)
mod.body[0].module = " __future__ ".strip()
compile(mod, "<test>", "exec")
def test_alias(self):
im = ast.parse("from bar import y").body[0]
self.assertEqual(len(im.names), 1)
alias = im.names[0]
self.assertEqual(alias.name, 'y')
self.assertIsNone(alias.asname)
self.assertEqual(alias.lineno, 1)
self.assertEqual(alias.end_lineno, 1)
self.assertEqual(alias.col_offset, 16)
self.assertEqual(alias.end_col_offset, 17)
im = ast.parse("from bar import *").body[0]
alias = im.names[0]
self.assertEqual(alias.name, '*')
self.assertIsNone(alias.asname)
self.assertEqual(alias.lineno, 1)
self.assertEqual(alias.end_lineno, 1)
self.assertEqual(alias.col_offset, 16)
self.assertEqual(alias.end_col_offset, 17)
im = ast.parse("from bar import y as z").body[0]
alias = im.names[0]
self.assertEqual(alias.name, "y")
self.assertEqual(alias.asname, "z")
self.assertEqual(alias.lineno, 1)
self.assertEqual(alias.end_lineno, 1)
self.assertEqual(alias.col_offset, 16)
self.assertEqual(alias.end_col_offset, 22)
im = ast.parse("import bar as foo").body[0]
alias = im.names[0]
self.assertEqual(alias.name, "bar")
self.assertEqual(alias.asname, "foo")
self.assertEqual(alias.lineno, 1)
self.assertEqual(alias.end_lineno, 1)
self.assertEqual(alias.col_offset, 7)
self.assertEqual(alias.end_col_offset, 17)
def test_base_classes(self):
self.assertTrue(issubclass(ast.For, ast.stmt))
self.assertTrue(issubclass(ast.Name, ast.expr))
self.assertTrue(issubclass(ast.stmt, ast.AST))
self.assertTrue(issubclass(ast.expr, ast.AST))
self.assertTrue(issubclass(ast.comprehension, ast.AST))
self.assertTrue(issubclass(ast.Gt, ast.AST))
def test_field_attr_existence(self):
    """Every node class in the ast namespace yields instances with a tuple ``_fields``."""
    for name, item in ast.__dict__.items():
        # constructor has a different signature
        if name == 'Index' or not self._is_ast_node(name, item):
            continue
        instance = self._construct_ast_class(item)
        if isinstance(instance, ast.AST):
            self.assertIs(type(instance._fields), tuple)
def _construct_ast_class(self, cls):
kwargs = {}
for name, typ in cls.__annotations__.items():
if typ is str:
kwargs[name] = 'capybara'
elif typ is int:
kwargs[name] = 42
elif typ is object:
kwargs[name] = b'capybara'
elif isinstance(typ, type) and issubclass(typ, ast.AST):
kwargs[name] = self._construct_ast_class(typ)
return cls(**kwargs)
def test_arguments(self):
x = ast.arguments()
self.assertEqual(x._fields, ('posonlyargs', 'args', 'vararg', 'kwonlyargs',
'kw_defaults', 'kwarg', 'defaults'))
self.assertEqual(x.__annotations__, {
'posonlyargs': list[ast.arg],
'args': list[ast.arg],
'vararg': ast.arg | None,
'kwonlyargs': list[ast.arg],
'kw_defaults': list[ast.expr],
'kwarg': ast.arg | None,
'defaults': list[ast.expr],
})
self.assertEqual(x.args, [])
self.assertIsNone(x.vararg)
x = ast.arguments(*range(1, 8))
self.assertEqual(x.args, 2)
self.assertEqual(x.vararg, 3)
def test_field_attr_writable(self):
x = ast.Constant(1)
# We can assign to _fields
x._fields = 666
self.assertEqual(x._fields, 666)
def test_classattrs(self):
with self.assertWarns(DeprecationWarning):
x = ast.Constant()
self.assertEqual(x._fields, ('value', 'kind'))
with self.assertRaises(AttributeError):
x.value
x = ast.Constant(42)
self.assertEqual(x.value, 42)
with self.assertRaises(AttributeError):
x.lineno
with self.assertRaises(AttributeError):
x.foobar
x = ast.Constant(lineno=2, value=3)
self.assertEqual(x.lineno, 2)
x = ast.Constant(42, lineno=0)
self.assertEqual(x.lineno, 0)
self.assertEqual(x._fields, ('value', 'kind'))
self.assertEqual(x.value, 42)
self.assertRaises(TypeError, ast.Constant, 1, None, 2)
self.assertRaises(TypeError, ast.Constant, 1, None, 2, lineno=0)
# Arbitrary keyword arguments are supported (but deprecated)
with self.assertWarns(DeprecationWarning):
self.assertEqual(ast.Constant(1, foo='bar').foo, 'bar')
with self.assertRaisesRegex(TypeError, "Constant got multiple values for argument 'value'"):
ast.Constant(1, value=2)
self.assertEqual(ast.Constant(42).value, 42)
self.assertEqual(ast.Constant(4.25).value, 4.25)
self.assertEqual(ast.Constant(4.25j).value, 4.25j)
self.assertEqual(ast.Constant('42').value, '42')
self.assertEqual(ast.Constant(b'42').value, b'42')
self.assertIs(ast.Constant(True).value, True)
self.assertIs(ast.Constant(False).value, False)
self.assertIs(ast.Constant(None).value, None)
self.assertIs(ast.Constant(...).value, ...)
def test_constant_subclasses(self):
class N(ast.Constant):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.z = 'spam'
class N2(ast.Constant):
pass
n = N(42)
self.assertEqual(n.value, 42)
self.assertEqual(n.z, 'spam')
self.assertEqual(type(n), N)
self.assertTrue(isinstance(n, N))
self.assertTrue(isinstance(n, ast.Constant))
self.assertFalse(isinstance(n, N2))
self.assertFalse(isinstance(ast.Constant(42), N))
n = N(value=42)
self.assertEqual(n.value, 42)
self.assertEqual(type(n), N)
def test_module(self):
body = [ast.Constant(42)]
x = ast.Module(body, [])
self.assertEqual(x.body, body)
def test_nodeclasses(self):
# Zero arguments constructor explicitly allowed (but deprecated)
with self.assertWarns(DeprecationWarning):
x = ast.BinOp()
self.assertEqual(x._fields, ('left', 'op', 'right'))
# Random attribute allowed too
x.foobarbaz = 5
self.assertEqual(x.foobarbaz, 5)
n1 = ast.Constant(1)
n3 = ast.Constant(3)
addop = ast.Add()
x = ast.BinOp(n1, addop, n3)
self.assertEqual(x.left, n1)
self.assertEqual(x.op, addop)
self.assertEqual(x.right, n3)
x = ast.BinOp(1, 2, 3)
self.assertEqual(x.left, 1)
self.assertEqual(x.op, 2)
self.assertEqual(x.right, 3)
x = ast.BinOp(1, 2, 3, lineno=0)
self.assertEqual(x.left, 1)
self.assertEqual(x.op, 2)
self.assertEqual(x.right, 3)
self.assertEqual(x.lineno, 0)
# node raises exception when given too many arguments
self.assertRaises(TypeError, ast.BinOp, 1, 2, 3, 4)
# node raises exception when given too many arguments
self.assertRaises(TypeError, ast.BinOp, 1, 2, 3, 4, lineno=0)
# can set attributes through kwargs too
x = ast.BinOp(left=1, op=2, right=3, lineno=0)
self.assertEqual(x.left, 1)
self.assertEqual(x.op, 2)
self.assertEqual(x.right, 3)
self.assertEqual(x.lineno, 0)
# Random kwargs also allowed (but deprecated)
with self.assertWarns(DeprecationWarning):
x = ast.BinOp(1, 2, 3, foobarbaz=42)
self.assertEqual(x.foobarbaz, 42)
def test_no_fields(self):
# this used to fail because Sub._fields was None
x = ast.Sub()
self.assertEqual(x._fields, ())
def test_invalid_sum(self):
pos = dict(lineno=2, col_offset=3)
m = ast.Module([ast.Expr(ast.expr(**pos), **pos)], [])
with self.assertRaises(TypeError) as cm:
compile(m, "<test>", "exec")
self.assertIn("but got <ast.expr", str(cm.exception))
def test_invalid_identifier(self):
m = ast.Module([ast.Expr(ast.Name(42, ast.Load()))], [])
ast.fix_missing_locations(m)
with self.assertRaises(TypeError) as cm:
compile(m, "<test>", "exec")
self.assertIn("identifier must be of type str", str(cm.exception))
def test_invalid_constant(self):
for invalid_constant in int, (1, 2, int), frozenset((1, 2, int)):
e = ast.Expression(body=ast.Constant(invalid_constant))
ast.fix_missing_locations(e)
with self.assertRaisesRegex(
TypeError, "invalid type in Constant: type"
):
compile(e, "<test>", "eval")
def test_empty_yield_from(self):
# Issue 16546: yield from value is not optional.
empty_yield_from = ast.parse("def f():\n yield from g()")
empty_yield_from.body[0].body[0].value.value = None
with self.assertRaises(ValueError) as cm:
compile(empty_yield_from, "<test>", "exec")
self.assertIn("field 'value' is required", str(cm.exception))
@support.cpython_only
def test_issue31592(self):
    """Parsing raises TypeError when ``unicodedata.normalize`` returns None."""
    # There shouldn't be an assertion failure in case of a bad
    # unicodedata.normalize().
    import unicodedata

    with support.swap_attr(unicodedata, 'normalize', lambda *args: None):
        with self.assertRaises(TypeError):
            ast.parse('\u03D5')
def test_issue18374_binop_col_offset(self):
tree = ast.parse('4+5+6+7')
parent_binop = tree.body[0].value
child_binop = parent_binop.left
grandchild_binop = child_binop.left
self.assertEqual(parent_binop.col_offset, 0)
self.assertEqual(parent_binop.end_col_offset, 7)
self.assertEqual(child_binop.col_offset, 0)
self.assertEqual(child_binop.end_col_offset, 5)
self.assertEqual(grandchild_binop.col_offset, 0)
self.assertEqual(grandchild_binop.end_col_offset, 3)
tree = ast.parse('4+5-\\\n 6-7')
parent_binop = tree.body[0].value
child_binop = parent_binop.left
grandchild_binop = child_binop.left
self.assertEqual(parent_binop.col_offset, 0)
self.assertEqual(parent_binop.lineno, 1)
self.assertEqual(parent_binop.end_col_offset, 4)
self.assertEqual(parent_binop.end_lineno, 2)
self.assertEqual(child_binop.col_offset, 0)
self.assertEqual(child_binop.lineno, 1)
self.assertEqual(child_binop.end_col_offset, 2)
self.assertEqual(child_binop.end_lineno, 2)
self.assertEqual(grandchild_binop.col_offset, 0)
self.assertEqual(grandchild_binop.lineno, 1)
self.assertEqual(grandchild_binop.end_col_offset, 3)
self.assertEqual(grandchild_binop.end_lineno, 1)
def test_issue39579_dotted_name_end_col_offset(self):
tree = ast.parse('@a.b.c\ndef f(): pass')
attr_b = tree.body[0].decorator_list[0].value
self.assertEqual(attr_b.end_col_offset, 4)
def test_ast_asdl_signature(self):
self.assertEqual(ast.withitem.__doc__, "withitem(expr context_expr, expr? optional_vars)")
self.assertEqual(ast.GtE.__doc__, "GtE")
self.assertEqual(ast.Name.__doc__, "Name(identifier id, expr_context ctx)")
self.assertEqual(ast.cmpop.__doc__, "cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn")
expressions = [f" | {node.__doc__}" for node in ast.expr.__subclasses__()]
expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)
def test_compare_basics(self):
self.assertTrue(ast.compare(ast.parse("x = 10"), ast.parse("x = 10")))
self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("")))
self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("x")))
self.assertFalse(
ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))
)
def test_compare_modified_ast(self):
# The ast API is a bit underspecified. The objects are mutable,
# and even _fields and _attributes are mutable. The compare() does
# some simple things to accommodate mutability.
a = ast.parse("m * x + b", mode="eval")
b = ast.parse("m * x + b", mode="eval")
self.assertTrue(ast.compare(a, b))
a._fields = a._fields + ("spam",)
a.spam = "Spam"
self.assertNotEqual(a._fields, b._fields)
self.assertFalse(ast.compare(a, b))
self.assertFalse(ast.compare(b, a))
b._fields = a._fields
b.spam = a.spam
self.assertTrue(ast.compare(a, b))
self.assertTrue(ast.compare(b, a))
b._attributes = b._attributes + ("eggs",)
b.eggs = "eggs"
self.assertNotEqual(a._attributes, b._attributes)
self.assertFalse(ast.compare(a, b, compare_attributes=True))
self.assertFalse(ast.compare(b, a, compare_attributes=True))
a._attributes = b._attributes
a.eggs = b.eggs
self.assertTrue(ast.compare(a, b, compare_attributes=True))
self.assertTrue(ast.compare(b, a, compare_attributes=True))
def test_compare_literals(self):
constants = (
-20,
20,
20.0,
1,
1.0,
True,
0,
False,
frozenset(),
tuple(),
"ABCD",
"abcd",
"中文字",
1e1000,
-1e1000,
)
for next_index, constant in enumerate(constants[:-1], 1):
next_constant = constants[next_index]
with self.subTest(literal=constant, next_literal=next_constant):
self.assertTrue(
ast.compare(ast.Constant(constant), ast.Constant(constant))
)
self.assertFalse(
ast.compare(
ast.Constant(constant), ast.Constant(next_constant)
)
)
same_looking_literal_cases = [
{1, 1.0, True, 1 + 0j},
{0, 0.0, False, 0 + 0j},
]
for same_looking_literals in same_looking_literal_cases:
for literal in same_looking_literals:
for same_looking_literal in same_looking_literals - {literal}:
self.assertFalse(
ast.compare(
ast.Constant(literal),
ast.Constant(same_looking_literal),
)
)
def test_compare_fieldless(self):
self.assertTrue(ast.compare(ast.Add(), ast.Add()))
self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
# test that missing runtime fields is handled in ast.compare()
a1, a2 = ast.Name('a'), ast.Name('a')
self.assertTrue(ast.compare(a1, a2))
self.assertTrue(ast.compare(a1, a2))
del a1.id
self.assertFalse(ast.compare(a1, a2))
del a2.id
self.assertTrue(ast.compare(a1, a2))
def test_compare_modes(self):
    """Parsing the same snippet twice gives equal trees in every parse mode."""
    snippets_by_mode = (
        ("exec", exec_tests),
        ("eval", eval_tests),
        ("single", single_tests),
    )
    for mode, sources in snippets_by_mode:
        for source in sources:
            first = ast.parse(source, mode=mode)
            second = ast.parse(source, mode=mode)
            self.assertTrue(
                ast.compare(first, second),
                f"{ast.dump(first)} != {ast.dump(second)}",
            )
def test_compare_attributes_option(self):
def parse(a, b):
return ast.parse(a), ast.parse(b)
a, b = parse("2 + 2", "2+2")
self.assertTrue(ast.compare(a, b))
self.assertTrue(ast.compare(a, b, compare_attributes=False))
self.assertFalse(ast.compare(a, b, compare_attributes=True))
def test_compare_attributes_option_missing_attribute(self):
# test that missing runtime attributes is handled in ast.compare()
a1, a2 = ast.Name('a', lineno=1), ast.Name('a', lineno=1)
self.assertTrue(ast.compare(a1, a2))
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
del a1.lineno
self.assertFalse(ast.compare(a1, a2, compare_attributes=True))
del a2.lineno
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
def test_positional_only_feature_version(self):
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
with self.assertRaises(SyntaxError):
ast.parse('def foo(x, /): ...', feature_version=(3, 7))
with self.assertRaises(SyntaxError):
ast.parse('def bar(x=1, /): ...', feature_version=(3, 7))
ast.parse('lambda x, /: ...', feature_version=(3, 8))
ast.parse('lambda x=1, /: ...', feature_version=(3, 8))
with self.assertRaises(SyntaxError):
ast.parse('lambda x, /: ...', feature_version=(3, 7))
with self.assertRaises(SyntaxError):
ast.parse('lambda x=1, /: ...', feature_version=(3, 7))
def test_assignment_expression_feature_version(self):
ast.parse('(x := 0)', feature_version=(3, 8))
with self.assertRaises(SyntaxError):
ast.parse('(x := 0)', feature_version=(3, 7))
def test_conditional_context_managers_parse_with_low_feature_version(self):
# regression test for gh-115881
ast.parse('with (x() if y else z()): ...', feature_version=(3, 8))
def test_exception_groups_feature_version(self):
code = dedent('''
try: ...
except* Exception: ...
''')
ast.parse(code)
with self.assertRaises(SyntaxError):
ast.parse(code, feature_version=(3, 10))
def test_type_params_feature_version(self):
samples = [
"type X = int",
"class X[T]: pass",
"def f[T](): pass",
]
for sample in samples:
with self.subTest(sample):
ast.parse(sample)
with self.assertRaises(SyntaxError):
ast.parse(sample, feature_version=(3, 11))
def test_type_params_default_feature_version(self):
samples = [
"type X[*Ts=int] = int",
"class X[T=int]: pass",
"def f[**P=int](): pass",
]
for sample in samples:
with self.subTest(sample):
ast.parse(sample)
with self.assertRaises(SyntaxError):
ast.parse(sample, feature_version=(3, 12))
def test_invalid_major_feature_version(self):
with self.assertRaises(ValueError):
ast.parse('pass', feature_version=(2, 7))
with self.assertRaises(ValueError):
ast.parse('pass', feature_version=(4, 0))
def test_constant_as_name(self):
for constant in "True", "False", "None":
expr = ast.Expression(ast.Name(constant, ast.Load()))
ast.fix_missing_locations(expr)
with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
compile(expr, "<test>", "eval")
# Check ast._Precedence against the reference enum spelled out below.
def test_precedence_enum(self):
# Reference definition; it is passed to enum._test_simple_enum together
# with ast._Precedence at the end of this test.
class _Precedence(enum.IntEnum):
"""Precedence table that originated from python grammar."""
NAMED_EXPR = enum.auto() # <target> := <expr1>
TUPLE = enum.auto() # <expr1>, <expr2>
YIELD = enum.auto() # 'yield', 'yield from'
TEST = enum.auto() # 'if'-'else', 'lambda'
OR = enum.auto() # 'or'
AND = enum.auto() # 'and'
NOT = enum.auto() # 'not'
CMP = enum.auto() # '<', '>', '==', '>=', '<=', '!=',
# 'in', 'not in', 'is', 'is not'
EXPR = enum.auto()
# BOR reuses the value of EXPR, which makes it an alias rather than a new member.
BOR = EXPR # '|'
BXOR = enum.auto() # '^'
BAND = enum.auto() # '&'
SHIFT = enum.auto() # '<<', '>>'
ARITH = enum.auto() # '+', '-'
TERM = enum.auto() # '*', '@', '/', '%', '//'
FACTOR = enum.auto() # unary '+', '-', '~'
POWER = enum.auto() # '**'
AWAIT = enum.auto() # 'await'
ATOM = enum.auto()
# Return the member whose value is one higher, or self when no member has
# that value (the lookup raises ValueError).
def next(self):
try:
return self.__class__(self + 1)
except ValueError:
return self
# NOTE(review): enum._test_simple_enum is a private helper of the enum
# module; presumably it raises when the two enums differ -- confirm.
enum._test_simple_enum(_Precedence, ast._Precedence)
@support.cpython_only
def test_ast_recursion_limit(self):
    """Deeply nested expressions raise RecursionError when parsed.

    For each kind of nesting, a depth below the limit must parse and two
    depths above it must raise RecursionError.
    """
    fail_depth = support.exceeds_recursion_limit()
    crash_depth = 100_000
    success_depth = int(support.get_c_recursion_limit() * 0.8)
    if _testinternalcapi is not None:
        # Do not exceed the recursion budget that is left right now.
        success_depth = min(
            success_depth, _testinternalcapi.get_c_recursion_remaining()
        )

    def check_limit(prefix, repeated):
        ast.parse(prefix + repeated * success_depth)
        for depth in (fail_depth, crash_depth):
            broken = prefix + repeated * depth
            details = "Compiling ({!r} + {!r} * {})".format(
                prefix, repeated, depth)
            with self.assertRaises(RecursionError, msg=details):
                with support.infinite_recursion():
                    ast.parse(broken)

    for repeated in ("()", ".b", "[0]", "*a"):
        check_limit("a", repeated)
def test_null_bytes(self):
with self.assertRaises(SyntaxError,
msg="source code string cannot contain null bytes"):
ast.parse("a\0b")
def assert_none_check(self, node: type[ast.AST], attr: str, source: str) -> None:
with self.subTest(f"{node.__name__}.{attr}"):
tree = ast.parse(source)
found = 0
for child in ast.walk(tree):
if isinstance(child, node):
setattr(child, attr, None)
found += 1
self.assertEqual(found, 1)
e = re.escape(f"field '{attr}' is required for {node.__name__}")
with self.assertRaisesRegex(ValueError, f"^{e}$"):
compile(tree, "<test>", "exec")
def test_none_checks(self) -> None:
    """Run assert_none_check for a set of required fields on non-expression nodes."""
    # (node class, required field, source with exactly one such node)
    cases = (
        (ast.alias, "name", "import spam as SPAM"),
        (ast.arg, "arg", "def spam(SPAM): spam"),
        (ast.comprehension, "target", "[spam for SPAM in spam]"),
        (ast.comprehension, "iter", "[spam for spam in SPAM]"),
        (ast.keyword, "value", "spam(**SPAM)"),
        (ast.match_case, "pattern", "match spam:\n case SPAM: spam"),
        (ast.withitem, "context_expr", "with SPAM: spam"),
    )
    for case in cases:
        self.assert_none_check(*case)
class CopyTests(unittest.TestCase):
"""Test copying and pickling AST nodes, including copy.replace() support."""
@staticmethod
def iter_ast_classes():
"""Iterate over the (native) subclasses of ast.AST recursively.
This excludes the special class ast.Index since its constructor
returns an integer.
"""
def do(cls):
if cls.__module__ != 'ast':
return
if cls is ast.Index:
return
yield cls
for sub in cls.__subclasses__():
yield from do(sub)
yield from do(ast.AST)
def test_pickling(self):
    """AST trees survive a pickle round trip under every protocol."""
    import pickle

    for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
        for code in exec_tests:
            with self.subTest(code=code, protocol=protocol):
                # Use the named flag (value 0x400) instead of a magic number.
                tree = compile(code, "?", "exec", ast.PyCF_ONLY_AST)
                ast2 = pickle.loads(pickle.dumps(tree, protocol))
                self.assertEqual(to_tuple(ast2), to_tuple(tree))
# Deep-copy a tree whose nodes carry a back-reference to their parent
# (a cyclic structure) and check the copy's parent links.
def test_copy_with_parents(self):
# gh-120108
# NOTE(review): the snippet below is passed through textwrap.dedent() and
# ast.parse(); its block indentation appears to have been flattened in this
# copy of the file, so it will not parse as written -- restore the original
# indentation before relying on this test.
code = """
('',)
while i < n:
if ch == '':
ch = format[i]
if ch == '':
if freplace is None:
'' % getattr(object)
elif ch == '':
if zreplace is None:
if hasattr:
offset = object.utcoffset()
if offset is not None:
if offset.days < 0:
offset = -offset
h = divmod(timedelta(hours=0))
if u:
zreplace = '' % (sign,)
elif s:
zreplace = '' % (sign,)
else:
zreplace = '' % (sign,)
elif ch == '':
if Zreplace is None:
Zreplace = ''
if hasattr(object):
s = object.tzname()
if s is not None:
Zreplace = s.replace('')
newformat.append(Zreplace)
else:
push('')
else:
push(ch)
"""
tree = ast.parse(textwrap.dedent(code))
# Give every child node a ``parent`` attribute pointing at its parent.
for node in ast.walk(tree):
for child in ast.iter_child_nodes(node):
child.parent = node
try:
# Run the deep copy inside support.infinite_recursion(200).
with support.infinite_recursion(200):
tree2 = copy.deepcopy(tree)
finally:
# Singletons like ast.Load() are shared; make sure we don't
# leave them mutated after this test.
for node in ast.walk(tree):
if hasattr(node, "parent"):
del node.parent
# In the copy, each child's ``parent`` must be structurally equal to the
# node it hangs under; context and operator nodes are excluded from the check.
for node in ast.walk(tree2):
for child in ast.iter_child_nodes(node):
if hasattr(child, "parent") and not isinstance(child, (
ast.expr_context, ast.boolop, ast.unaryop, ast.cmpop, ast.operator,
)):
self.assertEqual(to_tuple(child.parent), to_tuple(node))
def test_replace_interface(self):
    """Every native node class has ``__replace__`` and rejects positional arguments."""
    for klass in self.iter_ast_classes():
        with self.subTest(klass=klass):
            self.assertTrue(hasattr(klass, '__replace__'))

        fields = set(klass._fields)
        with self.subTest(klass=klass, fields=fields):
            # Build a node with every field set to None.
            node = klass(**dict.fromkeys(fields))
            # forbid positional arguments in replace()
            with self.assertRaises(TypeError):
                copy.replace(node, 1)
            with self.assertRaises(TypeError):
                node.__replace__(1)
def test_replace_native(self):
    """copy.replace() on native node classes changes only the requested entry."""
    def check_single_change(node, repl, originals, changed, new_value):
        for key, original in originals.items():
            # assert that there is no side-effect
            self.assertIs(getattr(node, key), original)
            # check the changes
            expected = new_value if key == changed else original
            self.assertIs(getattr(repl, key), expected)

    for klass in self.iter_ast_classes():
        fields = set(klass._fields)
        attributes = set(klass._attributes)

        with self.subTest(klass=klass, fields=fields, attributes=attributes):
            # use of object() to ensure that '==' and 'is'
            # behave similarly in ast.compare(node, repl)
            old_fields = {field: object() for field in fields}
            old_attrs = {attr: object() for attr in attributes}

            # check shallow copy
            node = klass(**old_fields)
            repl = copy.replace(node)
            self.assertTrue(ast.compare(node, repl, compare_attributes=True))
            # check when passing using attributes (they may be optional!)
            node = klass(**old_fields, **old_attrs)
            repl = copy.replace(node)
            self.assertTrue(ast.compare(node, repl, compare_attributes=True))

            for field in fields:
                # check when we sometimes have attributes and sometimes not
                for init_attrs in ({}, old_attrs):
                    node = klass(**old_fields, **init_attrs)
                    # only change a single field (do not change attributes)
                    new_value = object()
                    repl = copy.replace(node, **{field: new_value})
                    check_single_change(node, repl, old_fields, field, new_value)
                    self.assertFalse(ast.compare(node, repl, compare_attributes=True))

            for attribute in attributes:
                node = klass(**old_fields, **old_attrs)
                # only change a single attribute (do not change fields)
                new_attr = object()
                repl = copy.replace(node, **{attribute: new_attr})
                check_single_change(node, repl, old_attrs, attribute, new_attr)
                self.assertFalse(ast.compare(node, repl, compare_attributes=True))
def test_replace_accept_known_class_fields(self):
nid, ctx = object(), object()
node = ast.Name(id=nid, ctx=ctx)
self.assertIs(node.id, nid)
self.assertIs(node.ctx, ctx)
new_nid = object()
repl = copy.replace(node, id=new_nid)
# assert that there is no side-effect
self.assertIs(node.id, nid)
self.assertIs(node.ctx, ctx)
# check the changes
self.assertIs(repl.id, new_nid)
self.assertIs(repl.ctx, node.ctx) # no changes
def test_replace_accept_known_class_attributes(self):
node = ast.parse('x').body[0].value
self.assertEqual(node.id, 'x')
self.assertEqual(node.lineno, 1)
# constructor allows any type so replace() should do the same
lineno = object()
repl = copy.replace(node, lineno=lineno)
# assert that there is no side-effect
self.assertEqual(node.lineno, 1)
# check the changes
self.assertEqual(repl.id, node.id)
self.assertEqual(repl.ctx, node.ctx)
self.assertEqual(repl.lineno, lineno)
_, _, state = node.__reduce__()
self.assertEqual(state['id'], 'x')
self.assertEqual(state['ctx'], node.ctx)
self.assertEqual(state['lineno'], 1)
_, _, state = repl.__reduce__()
self.assertEqual(state['id'], 'x')
self.assertEqual(state['ctx'], node.ctx)
self.assertEqual(state['lineno'], lineno)
def test_replace_accept_known_custom_class_fields(self):
class MyNode(ast.AST):
_fields = ('name', 'data')
__annotations__ = {'name': str, 'data': object}
__match_args__ = ('name', 'data')
name, data = 'name', object()
node = MyNode(name, data)
self.assertIs(node.name, name)
self.assertIs(node.data, data)
# check shallow copy
repl = copy.replace(node)
# assert that there is no side-effect
self.assertIs(node.name, name)
self.assertIs(node.data, data)
# check the shallow copy
self.assertIs(repl.name, name)
self.assertIs(repl.data, data)
node = MyNode(name, data)
repl_data = object()
# replace custom but known field
repl = copy.replace(node, data=repl_data)
# assert that there is no side-effect
self.assertIs(node.name, name)
self.assertIs(node.data, data)
# check the changes
self.assertIs(repl.name, node.name)
self.assertIs(repl.data, repl_data)
def test_replace_accept_known_custom_class_attributes(self):
class MyNode(ast.AST):
x = 0
y = 1
_attributes = ('x', 'y')
node = MyNode()
self.assertEqual(node.x, 0)
self.assertEqual(node.y, 1)
y = object()
repl = copy.replace(node, y=y)
# assert that there is no side-effect
self.assertEqual(node.x, 0)
self.assertEqual(node.y, 1)
# check the changes
self.assertEqual(repl.x, 0)
self.assertEqual(repl.y, y)
def test_replace_ignore_known_custom_instance_fields(self):
node = ast.parse('x').body[0].value
node.extra = extra = object() # add instance 'extra' field
context = node.ctx
# assert initial values
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)
# shallow copy, but drops extra fields
repl = copy.replace(node)
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)
# verify that the 'extra' field is not kept
self.assertIs(repl.id, 'x')
self.assertIs(repl.ctx, context)
self.assertRaises(AttributeError, getattr, repl, 'extra')
# change known native field
repl = copy.replace(node, id='y')
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)
# verify that the 'extra' field is not kept
self.assertIs(repl.id, 'y')
self.assertIs(repl.ctx, context)
self.assertRaises(AttributeError, getattr, repl, 'extra')
def test_replace_reject_missing_field(self):
# case: warn if deleted field is not replaced
node = ast.parse('x').body[0].value
context = node.ctx
del node.id
self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)
msg = "Name.__replace__ missing 1 keyword argument: 'id'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node)
# assert that there is no side-effect
self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)
# case: do not raise if deleted field is replaced
node = ast.parse('x').body[0].value
context = node.ctx
del node.id
self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)
repl = copy.replace(node, id='y')
# assert that there is no side-effect
self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)
self.assertIs(repl.id, 'y')
self.assertIs(repl.ctx, context)
def test_replace_reject_known_custom_instance_fields_commits(self):
node = ast.parse('x').body[0].value
node.extra = extra = object() # add instance 'extra' field
context = node.ctx
# explicit rejection of known instance fields
self.assertTrue(hasattr(node, 'extra'))
msg = "Name.__replace__ got an unexpected keyword argument 'extra'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node, extra=1)
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)
def test_replace_reject_unknown_instance_fields(self):
node = ast.parse('x').body[0].value
context = node.ctx
# explicit rejection of unknown extra fields
self.assertRaises(AttributeError, getattr, node, 'unknown')
msg = "Name.__replace__ got an unexpected keyword argument 'unknown'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node, unknown=1)
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertRaises(AttributeError, getattr, node, 'unknown')
class ASTHelpers_Test(unittest.TestCase):
maxDiff = None
def test_parse(self):
a = ast.parse('foo(1 + 1)')
b = compile('foo(1 + 1)', '<unknown>', 'exec', ast.PyCF_ONLY_AST)
self.assertEqual(ast.dump(a), ast.dump(b))
def test_parse_in_error(self):
try:
1/0
except Exception:
with self.assertRaises(SyntaxError) as e:
ast.literal_eval(r"'\U'")
self.assertIsNotNone(e.exception.__context__)
def test_dump(self):
node = ast.parse('spam(eggs, "and cheese")')
self.assertEqual(ast.dump(node),
"Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), "
"args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')]))])"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
"Module([Expr(Call(Name('spam', Load()), [Name('eggs', Load()), "
"Constant('and cheese')]))])"
)
self.assertEqual(ast.dump(node, include_attributes=True),
"Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load(), "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=4), "
"args=[Name(id='eggs', ctx=Load(), lineno=1, col_offset=5, "
"end_lineno=1, end_col_offset=9), Constant(value='and cheese', "
"lineno=1, col_offset=11, end_lineno=1, end_col_offset=23)], "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=24), "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=24)])"
)
def test_dump_indent(self):
node = ast.parse('spam(eggs, "and cheese")')
self.assertEqual(ast.dump(node, indent=3), """\
Module(
body=[
Expr(
value=Call(
func=Name(id='spam', ctx=Load()),
args=[
Name(id='eggs', ctx=Load()),
Constant(value='and cheese')]))])""")
self.assertEqual(ast.dump(node, annotate_fields=False, indent='\t'), """\
Module(
\t[
\t\tExpr(
\t\t\tCall(
\t\t\t\tName('spam', Load()),
\t\t\t\t[
\t\t\t\t\tName('eggs', Load()),
\t\t\t\t\tConstant('and cheese')]))])""")
self.assertEqual(ast.dump(node, include_attributes=True, indent=3), """\
Module(
body=[
Expr(
value=Call(
func=Name(
id='spam',
ctx=Load(),
lineno=1,
col_offset=0,
end_lineno=1,
end_col_offset=4),
args=[
Name(
id='eggs',
ctx=Load(),
lineno=1,
col_offset=5,
end_lineno=1,
end_col_offset=9),
Constant(
value='and cheese',
lineno=1,
col_offset=11,
end_lineno=1,
end_col_offset=23)],
lineno=1,
col_offset=0,
end_lineno=1,
end_col_offset=24),
lineno=1,
col_offset=0,
end_lineno=1,
end_col_offset=24)])""")
def test_dump_incomplete(self):
node = ast.Raise(lineno=3, col_offset=4)
self.assertEqual(ast.dump(node),
"Raise()"
)
self.assertEqual(ast.dump(node, include_attributes=True),
"Raise(lineno=3, col_offset=4)"
)
node = ast.Raise(exc=ast.Name(id='e', ctx=ast.Load()), lineno=3, col_offset=4)
self.assertEqual(ast.dump(node),
"Raise(exc=Name(id='e', ctx=Load()))"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
"Raise(Name('e', Load()))"
)
self.assertEqual(ast.dump(node, include_attributes=True),
"Raise(exc=Name(id='e', ctx=Load()), lineno=3, col_offset=4)"
)
self.assertEqual(ast.dump(node, annotate_fields=False, include_attributes=True),
"Raise(Name('e', Load()), lineno=3, col_offset=4)"
)
node = ast.Raise(cause=ast.Name(id='e', ctx=ast.Load()))
self.assertEqual(ast.dump(node),
"Raise(cause=Name(id='e', ctx=Load()))"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
"Raise(cause=Name('e', Load()))"
)
# Arguments:
node = ast.arguments(args=[ast.arg("x")])
self.assertEqual(ast.dump(node, annotate_fields=False),
"arguments([], [arg('x')])",
)
node = ast.arguments(posonlyargs=[ast.arg("x")])
self.assertEqual(ast.dump(node, annotate_fields=False),
"arguments([arg('x')])",
)
node = ast.arguments(posonlyargs=[ast.arg("x")], kwonlyargs=[ast.arg('y')])
self.assertEqual(ast.dump(node, annotate_fields=False),
"arguments([arg('x')], kwonlyargs=[arg('y')])",
)
node = ast.arguments(args=[ast.arg("x")], kwonlyargs=[ast.arg('y')])
self.assertEqual(ast.dump(node, annotate_fields=False),
"arguments([], [arg('x')], kwonlyargs=[arg('y')])",
)
node = ast.arguments()
self.assertEqual(ast.dump(node, annotate_fields=False),
"arguments()",
)
# Classes:
node = ast.ClassDef(
'T',
[],
[ast.keyword('a', ast.Constant(None))],
[],
[ast.Name('dataclass', ctx=ast.Load())],
)
self.assertEqual(ast.dump(node),
"ClassDef(name='T', keywords=[keyword(arg='a', value=Constant(value=None))], decorator_list=[Name(id='dataclass', ctx=Load())])",
)
self.assertEqual(ast.dump(node, annotate_fields=False),
"ClassDef('T', [], [keyword('a', Constant(None))], [], [Name('dataclass', Load())])",
)
def test_dump_show_empty(self):
def check_node(node, empty, full, **kwargs):
with self.subTest(show_empty=False):
self.assertEqual(
ast.dump(node, show_empty=False, **kwargs),
empty,
)
with self.subTest(show_empty=True):
self.assertEqual(
ast.dump(node, show_empty=True, **kwargs),
full,
)
def check_text(code, empty, full, **kwargs):
check_node(ast.parse(code), empty, full, **kwargs)
check_node(
ast.arguments(),
empty="arguments()",
full="arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[])",
)
check_node(
# Corner case: there are no real `Name` instances with `id=''`:
ast.Name(id='', ctx=ast.Load()),
empty="Name(id='', ctx=Load())",
full="Name(id='', ctx=Load())",
)
check_node(
ast.MatchSingleton(value=None),
empty="MatchSingleton(value=None)",
full="MatchSingleton(value=None)",
)
check_node(
ast.Constant(value=None),
empty="Constant(value=None)",
full="Constant(value=None)",
)
check_node(
ast.Constant(value=''),
empty="Constant(value='')",
full="Constant(value='')",
)
check_text(
"def a(b: int = 0, *, c): ...",
empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', ctx=Load()))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))])])",
full="Module(body=[FunctionDef(name='a', args=arguments(posonlyargs=[], args=[arg(arg='b', annotation=Name(id='int', ctx=Load()))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))], decorator_list=[], type_params=[])], type_ignores=[])",
)
check_text(
"def a(b: int = 0, *, c): ...",
empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', ctx=Load(), lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)])",
full="Module(body=[FunctionDef(name='a', args=arguments(posonlyargs=[], args=[arg(arg='b', annotation=Name(id='int', ctx=Load(), lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], decorator_list=[], type_params=[], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)], type_ignores=[])",
include_attributes=True,
)
check_text(
'spam(eggs, "and cheese")',
empty="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')]))])",
full="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')], keywords=[]))], type_ignores=[])",
)
check_text(
'spam(eggs, text="and cheese")',
empty="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load())], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))])",
full="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load())], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))], type_ignores=[])",
)
check_text(
"import _ast as ast; from module import sub",
empty="Module(body=[Import(names=[alias(name='_ast', asname='ast')]), ImportFrom(module='module', names=[alias(name='sub')], level=0)])",
full="Module(body=[Import(names=[alias(name='_ast', asname='ast')]), ImportFrom(module='module', names=[alias(name='sub')], level=0)], type_ignores=[])",
)
def test_copy_location(self):
src = ast.parse('1 + 1', mode='eval')
src.body.right = ast.copy_location(ast.Constant(2), src.body.right)
self.assertEqual(ast.dump(src, include_attributes=True),
'Expression(body=BinOp(left=Constant(value=1, lineno=1, col_offset=0, '
'end_lineno=1, end_col_offset=1), op=Add(), right=Constant(value=2, '
'lineno=1, col_offset=4, end_lineno=1, end_col_offset=5), lineno=1, '
'col_offset=0, end_lineno=1, end_col_offset=5))'
)
func = ast.Name('spam', ast.Load())
src = ast.Call(col_offset=1, lineno=1, end_lineno=1, end_col_offset=1, func=func)
new = ast.copy_location(src, ast.Call(col_offset=None, lineno=None, func=func))
self.assertIsNone(new.end_lineno)
self.assertIsNone(new.end_col_offset)
self.assertEqual(new.lineno, 1)
self.assertEqual(new.col_offset, 1)
def test_fix_missing_locations(self):
src = ast.parse('write("spam")')
src.body.append(ast.Expr(ast.Call(ast.Name('spam', ast.Load()),
[ast.Constant('eggs')], [])))
self.assertEqual(src, ast.fix_missing_locations(src))
self.maxDiff = None
self.assertEqual(ast.dump(src, include_attributes=True),
"Module(body=[Expr(value=Call(func=Name(id='write', ctx=Load(), "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=5), "
"args=[Constant(value='spam', lineno=1, col_offset=6, end_lineno=1, "
"end_col_offset=12)], lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=13), lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=13), Expr(value=Call(func=Name(id='spam', ctx=Load(), "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=0), "
"args=[Constant(value='eggs', lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=0)], lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=0), lineno=1, col_offset=0, end_lineno=1, end_col_offset=0)])"
)
def test_increment_lineno(self):
src = ast.parse('1 + 1', mode='eval')
self.assertEqual(ast.increment_lineno(src, n=3), src)
self.assertEqual(ast.dump(src, include_attributes=True),
'Expression(body=BinOp(left=Constant(value=1, lineno=4, col_offset=0, '
'end_lineno=4, end_col_offset=1), op=Add(), right=Constant(value=1, '
'lineno=4, col_offset=4, end_lineno=4, end_col_offset=5), lineno=4, '
'col_offset=0, end_lineno=4, end_col_offset=5))'
)
# issue10869: do not increment lineno of root twice
src = ast.parse('1 + 1', mode='eval')
self.assertEqual(ast.increment_lineno(src.body, n=3), src.body)
self.assertEqual(ast.dump(src, include_attributes=True),
'Expression(body=BinOp(left=Constant(value=1, lineno=4, col_offset=0, '
'end_lineno=4, end_col_offset=1), op=Add(), right=Constant(value=1, '
'lineno=4, col_offset=4, end_lineno=4, end_col_offset=5), lineno=4, '
'col_offset=0, end_lineno=4, end_col_offset=5))'
)
src = ast.Call(
func=ast.Name("test", ast.Load()), args=[], keywords=[], lineno=1
)
self.assertEqual(ast.increment_lineno(src).lineno, 2)
self.assertIsNone(ast.increment_lineno(src).end_lineno)
def test_increment_lineno_on_module(self):
src = ast.parse(dedent("""\
a = 1
b = 2 # type: ignore
c = 3
d = 4 # type: ignore@tag
"""), type_comments=True)
ast.increment_lineno(src, n=5)
self.assertEqual(src.type_ignores[0].lineno, 7)
self.assertEqual(src.type_ignores[1].lineno, 9)
self.assertEqual(src.type_ignores[1].tag, '@tag')
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'args': []})
def test_iter_child_nodes(self):
node = ast.parse("spam(23, 42, eggs='leek')", mode='eval')
self.assertEqual(len(list(ast.iter_child_nodes(node.body))), 4)
iterator = ast.iter_child_nodes(node.body)
self.assertEqual(next(iterator).id, 'spam')
self.assertEqual(next(iterator).value, 23)
self.assertEqual(next(iterator).value, 42)
self.assertEqual(ast.dump(next(iterator)),
"keyword(arg='eggs', value=Constant(value='leek'))"
)
def test_get_docstring(self):
node = ast.parse('"""line one\n line two"""')
self.assertEqual(ast.get_docstring(node),
'line one\nline two')
node = ast.parse('class foo:\n """line one\n line two"""')
self.assertEqual(ast.get_docstring(node.body[0]),
'line one\nline two')
node = ast.parse('def foo():\n """line one\n line two"""')
self.assertEqual(ast.get_docstring(node.body[0]),
'line one\nline two')
node = ast.parse('async def foo():\n """spam\n ham"""')
self.assertEqual(ast.get_docstring(node.body[0]), 'spam\nham')
node = ast.parse('async def foo():\n """spam\n ham"""')
self.assertEqual(ast.get_docstring(node.body[0], clean=False), 'spam\n ham')
node = ast.parse('x')
self.assertRaises(TypeError, ast.get_docstring, node.body[0])
def test_get_docstring_none(self):
self.assertIsNone(ast.get_docstring(ast.parse('')))
node = ast.parse('x = "not docstring"')
self.assertIsNone(ast.get_docstring(node))
node = ast.parse('def foo():\n pass')
self.assertIsNone(ast.get_docstring(node))
node = ast.parse('class foo:\n pass')
self.assertIsNone(ast.get_docstring(node.body[0]))
node = ast.parse('class foo:\n x = "not docstring"')
self.assertIsNone(ast.get_docstring(node.body[0]))
node = ast.parse('class foo:\n def bar(self): pass')
self.assertIsNone(ast.get_docstring(node.body[0]))
node = ast.parse('def foo():\n pass')
self.assertIsNone(ast.get_docstring(node.body[0]))
node = ast.parse('def foo():\n x = "not docstring"')
self.assertIsNone(ast.get_docstring(node.body[0]))
node = ast.parse('async def foo():\n pass')
self.assertIsNone(ast.get_docstring(node.body[0]))
node = ast.parse('async def foo():\n x = "not docstring"')
self.assertIsNone(ast.get_docstring(node.body[0]))
node = ast.parse('async def foo():\n 42')
self.assertIsNone(ast.get_docstring(node.body[0]))
def test_multi_line_docstring_col_offset_and_lineno_issue16806(self):
node = ast.parse(
'"""line one\nline two"""\n\n'
'def foo():\n """line one\n line two"""\n\n'
' def bar():\n """line one\n line two"""\n'
' """line one\n line two"""\n'
'"""line one\nline two"""\n\n'
)
self.assertEqual(node.body[0].col_offset, 0)
self.assertEqual(node.body[0].lineno, 1)
self.assertEqual(node.body[1].body[0].col_offset, 2)
self.assertEqual(node.body[1].body[0].lineno, 5)
self.assertEqual(node.body[1].body[1].body[0].col_offset, 4)
self.assertEqual(node.body[1].body[1].body[0].lineno, 9)
self.assertEqual(node.body[1].body[2].col_offset, 2)
self.assertEqual(node.body[1].body[2].lineno, 11)
self.assertEqual(node.body[2].col_offset, 0)
self.assertEqual(node.body[2].lineno, 13)
def test_elif_stmt_start_position(self):
node = ast.parse('if a:\n pass\nelif b:\n pass\n')
elif_stmt = node.body[0].orelse[0]
self.assertEqual(elif_stmt.lineno, 3)
self.assertEqual(elif_stmt.col_offset, 0)
def test_elif_stmt_start_position_with_else(self):
node = ast.parse('if a:\n pass\nelif b:\n pass\nelse:\n pass\n')
elif_stmt = node.body[0].orelse[0]
self.assertEqual(elif_stmt.lineno, 3)
self.assertEqual(elif_stmt.col_offset, 0)
def test_starred_expr_end_position_within_call(self):
node = ast.parse('f(*[0, 1])')
starred_expr = node.body[0].value.args[0]
self.assertEqual(starred_expr.end_lineno, 1)
self.assertEqual(starred_expr.end_col_offset, 9)
def test_literal_eval(self):
self.assertEqual(ast.literal_eval('[1, 2, 3]'), [1, 2, 3])
self.assertEqual(ast.literal_eval('{"foo": 42}'), {"foo": 42})
self.assertEqual(ast.literal_eval('(True, False, None)'), (True, False, None))
self.assertEqual(ast.literal_eval('{1, 2, 3}'), {1, 2, 3})
self.assertEqual(ast.literal_eval('b"hi"'), b"hi")
self.assertEqual(ast.literal_eval('set()'), set())
self.assertRaises(ValueError, ast.literal_eval, 'foo()')
self.assertEqual(ast.literal_eval('6'), 6)
self.assertEqual(ast.literal_eval('+6'), 6)
self.assertEqual(ast.literal_eval('-6'), -6)
self.assertEqual(ast.literal_eval('3.25'), 3.25)
self.assertEqual(ast.literal_eval('+3.25'), 3.25)
self.assertEqual(ast.literal_eval('-3.25'), -3.25)
self.assertEqual(repr(ast.literal_eval('-0.0')), '-0.0')
self.assertRaises(ValueError, ast.literal_eval, '++6')
self.assertRaises(ValueError, ast.literal_eval, '+True')
self.assertRaises(ValueError, ast.literal_eval, '2+3')
def test_literal_eval_str_int_limit(self):
with support.adjust_int_max_str_digits(4000):
ast.literal_eval('3'*4000) # no error
with self.assertRaises(SyntaxError) as err_ctx:
ast.literal_eval('3'*4001)
self.assertIn('Exceeds the limit ', str(err_ctx.exception))
self.assertIn(' Consider hexadecimal ', str(err_ctx.exception))
def test_literal_eval_complex(self):
# Issue #4907
self.assertEqual(ast.literal_eval('6j'), 6j)
self.assertEqual(ast.literal_eval('-6j'), -6j)
self.assertEqual(ast.literal_eval('6.75j'), 6.75j)
self.assertEqual(ast.literal_eval('-6.75j'), -6.75j)
self.assertEqual(ast.literal_eval('3+6j'), 3+6j)
self.assertEqual(ast.literal_eval('-3+6j'), -3+6j)
self.assertEqual(ast.literal_eval('3-6j'), 3-6j)
self.assertEqual(ast.literal_eval('-3-6j'), -3-6j)
self.assertEqual(ast.literal_eval('3.25+6.75j'), 3.25+6.75j)
self.assertEqual(ast.literal_eval('-3.25+6.75j'), -3.25+6.75j)
self.assertEqual(ast.literal_eval('3.25-6.75j'), 3.25-6.75j)
self.assertEqual(ast.literal_eval('-3.25-6.75j'), -3.25-6.75j)
self.assertEqual(ast.literal_eval('(3+6j)'), 3+6j)
self.assertRaises(ValueError, ast.literal_eval, '-6j+3')
self.assertRaises(ValueError, ast.literal_eval, '-6j+3j')
self.assertRaises(ValueError, ast.literal_eval, '3+-6j')
self.assertRaises(ValueError, ast.literal_eval, '3+(0+6j)')
self.assertRaises(ValueError, ast.literal_eval, '-(3+6j)')
def test_literal_eval_malformed_dict_nodes(self):
malformed = ast.Dict(keys=[ast.Constant(1), ast.Constant(2)], values=[ast.Constant(3)])
self.assertRaises(ValueError, ast.literal_eval, malformed)
malformed = ast.Dict(keys=[ast.Constant(1)], values=[ast.Constant(2), ast.Constant(3)])
self.assertRaises(ValueError, ast.literal_eval, malformed)
def test_literal_eval_trailing_ws(self):
self.assertEqual(ast.literal_eval(" -1"), -1)
self.assertEqual(ast.literal_eval("\t\t-1"), -1)
self.assertEqual(ast.literal_eval(" \t -1"), -1)
self.assertRaises(IndentationError, ast.literal_eval, "\n -1")
def test_literal_eval_malformed_lineno(self):
msg = r'malformed node or string on line 3:'
with self.assertRaisesRegex(ValueError, msg):
ast.literal_eval("{'a': 1,\n'b':2,\n'c':++3,\n'd':4}")
node = ast.UnaryOp(
ast.UAdd(), ast.UnaryOp(ast.UAdd(), ast.Constant(6)))
self.assertIsNone(getattr(node, 'lineno', None))
msg = r'malformed node or string:'
with self.assertRaisesRegex(ValueError, msg):
ast.literal_eval(node)
def test_literal_eval_syntax_errors(self):
with self.assertRaisesRegex(SyntaxError, "unexpected indent"):
ast.literal_eval(r'''
\
(\
\ ''')
def test_bad_integer(self):
# issue13436: Bad error message with invalid numeric values
body = [ast.ImportFrom(module='time',
names=[ast.alias(name='sleep')],
level=None,
lineno=None, col_offset=None)]
mod = ast.Module(body, [])
with self.assertRaises(ValueError) as cm:
compile(mod, 'test', 'exec')
self.assertIn("invalid integer value: None", str(cm.exception))
def test_level_as_none(self):
body = [ast.ImportFrom(module='time',
names=[ast.alias(name='sleep',
lineno=0, col_offset=0)],
level=None,
lineno=0, col_offset=0)]
mod = ast.Module(body, [])
code = compile(mod, 'test', 'exec')
ns = {}
exec(code, ns)
self.assertIn('sleep', ns)
def test_recursion_direct(self):
e = ast.UnaryOp(op=ast.Not(), lineno=0, col_offset=0, operand=ast.Constant(1))
e.operand = e
with self.assertRaises(RecursionError):
with support.infinite_recursion():
compile(ast.Expression(e), "<test>", "eval")
def test_recursion_indirect(self):
e = ast.UnaryOp(op=ast.Not(), lineno=0, col_offset=0, operand=ast.Constant(1))
f = ast.UnaryOp(op=ast.Not(), lineno=0, col_offset=0, operand=ast.Constant(1))
e.operand = f
f.operand = e
with self.assertRaises(RecursionError):
with support.infinite_recursion():
compile(ast.Expression(e), "<test>", "eval")
class ASTValidatorTests(unittest.TestCase):
def mod(self, mod, msg=None, mode="exec", *, exc=ValueError):
mod.lineno = mod.col_offset = 0
ast.fix_missing_locations(mod)
if msg is None:
compile(mod, "<test>", mode)
else:
with self.assertRaises(exc) as cm:
compile(mod, "<test>", mode)
self.assertIn(msg, str(cm.exception))
def expr(self, node, msg=None, *, exc=ValueError):
mod = ast.Module([ast.Expr(node)], [])
self.mod(mod, msg, exc=exc)
def stmt(self, stmt, msg=None):
mod = ast.Module([stmt], [])
self.mod(mod, msg)
def test_module(self):
m = ast.Interactive([ast.Expr(ast.Name("x", ast.Store()))])
self.mod(m, "must have Load context", "single")
m = ast.Expression(ast.Name("x", ast.Store()))
self.mod(m, "must have Load context", "eval")
def _check_arguments(self, fac, check):
def arguments(args=None, posonlyargs=None, vararg=None,
kwonlyargs=None, kwarg=None,
defaults=None, kw_defaults=None):
if args is None:
args = []
if posonlyargs is None:
posonlyargs = []
if kwonlyargs is None:
kwonlyargs = []
if defaults is None:
defaults = []
if kw_defaults is None:
kw_defaults = []
args = ast.arguments(args, posonlyargs, vararg, kwonlyargs,
kw_defaults, kwarg, defaults)
return fac(args)
args = [ast.arg("x", ast.Name("x", ast.Store()))]
check(arguments(args=args), "must have Load context")
check(arguments(posonlyargs=args), "must have Load context")
check(arguments(kwonlyargs=args), "must have Load context")
check(arguments(defaults=[ast.Constant(3)]),
"more positional defaults than args")
check(arguments(kw_defaults=[ast.Constant(4)]),
"length of kwonlyargs is not the same as kw_defaults")
args = [ast.arg("x", ast.Name("x", ast.Load()))]
check(arguments(args=args, defaults=[ast.Name("x", ast.Store())]),
"must have Load context")
args = [ast.arg("a", ast.Name("x", ast.Load())),
ast.arg("b", ast.Name("y", ast.Load()))]
check(arguments(kwonlyargs=args,
kw_defaults=[None, ast.Name("x", ast.Store())]),
"must have Load context")
def test_funcdef(self):
a = ast.arguments([], [], None, [], [], None, [])
f = ast.FunctionDef("x", a, [], [], None, None, [])
self.stmt(f, "empty body on FunctionDef")
f = ast.FunctionDef("x", a, [ast.Pass()], [ast.Name("x", ast.Store())], None, None, [])
self.stmt(f, "must have Load context")
f = ast.FunctionDef("x", a, [ast.Pass()], [],
ast.Name("x", ast.Store()), None, [])
self.stmt(f, "must have Load context")
f = ast.FunctionDef("x", ast.arguments(), [ast.Pass()])
self.stmt(f)
def fac(args):
return ast.FunctionDef("x", args, [ast.Pass()], [], None, None, [])
self._check_arguments(fac, self.stmt)
def test_funcdef_pattern_matching(self):
# gh-104799: New fields on FunctionDef should be added at the end
def matcher(node):
match node:
case ast.FunctionDef("foo", ast.arguments(args=[ast.arg("bar")]),
[ast.Pass()],
[ast.Name("capybara", ast.Load())],
ast.Name("pacarana", ast.Load())):
return True
case _:
return False
code = """
@capybara
def foo(bar) -> pacarana:
pass
"""
source = ast.parse(textwrap.dedent(code))
funcdef = source.body[0]
self.assertIsInstance(funcdef, ast.FunctionDef)
self.assertTrue(matcher(funcdef))
def test_classdef(self):
def cls(bases=None, keywords=None, body=None, decorator_list=None, type_params=None):
if bases is None:
bases = []
if keywords is None:
keywords = []
if body is None:
body = [ast.Pass()]
if decorator_list is None:
decorator_list = []
if type_params is None:
type_params = []
return ast.ClassDef("myclass", bases, keywords,
body, decorator_list, type_params)
self.stmt(cls(bases=[ast.Name("x", ast.Store())]),
"must have Load context")
self.stmt(cls(keywords=[ast.keyword("x", ast.Name("x", ast.Store()))]),
"must have Load context")
self.stmt(cls(body=[]), "empty body on ClassDef")
self.stmt(cls(body=[None]), "None disallowed")
self.stmt(cls(decorator_list=[ast.Name("x", ast.Store())]),
"must have Load context")
def test_delete(self):
self.stmt(ast.Delete([]), "empty targets on Delete")
self.stmt(ast.Delete([None]), "None disallowed")
self.stmt(ast.Delete([ast.Name("x", ast.Load())]),
"must have Del context")
def test_assign(self):
self.stmt(ast.Assign([], ast.Constant(3)), "empty targets on Assign")
self.stmt(ast.Assign([None], ast.Constant(3)), "None disallowed")
self.stmt(ast.Assign([ast.Name("x", ast.Load())], ast.Constant(3)),
"must have Store context")
self.stmt(ast.Assign([ast.Name("x", ast.Store())],
ast.Name("y", ast.Store())),
"must have Load context")
def test_augassign(self):
aug = ast.AugAssign(ast.Name("x", ast.Load()), ast.Add(),
ast.Name("y", ast.Load()))
self.stmt(aug, "must have Store context")
aug = ast.AugAssign(ast.Name("x", ast.Store()), ast.Add(),
ast.Name("y", ast.Store()))
self.stmt(aug, "must have Load context")
def test_for(self):
x = ast.Name("x", ast.Store())
y = ast.Name("y", ast.Load())
p = ast.Pass()
self.stmt(ast.For(x, y, [], []), "empty body on For")
self.stmt(ast.For(ast.Name("x", ast.Load()), y, [p], []),
"must have Store context")
self.stmt(ast.For(x, ast.Name("y", ast.Store()), [p], []),
"must have Load context")
e = ast.Expr(ast.Name("x", ast.Store()))
self.stmt(ast.For(x, y, [e], []), "must have Load context")
self.stmt(ast.For(x, y, [p], [e]), "must have Load context")
def test_while(self):
self.stmt(ast.While(ast.Constant(3), [], []), "empty body on While")
self.stmt(ast.While(ast.Name("x", ast.Store()), [ast.Pass()], []),
"must have Load context")
self.stmt(ast.While(ast.Constant(3), [ast.Pass()],
[ast.Expr(ast.Name("x", ast.Store()))]),
"must have Load context")
def test_if(self):
self.stmt(ast.If(ast.Constant(3), [], []), "empty body on If")
i = ast.If(ast.Name("x", ast.Store()), [ast.Pass()], [])
self.stmt(i, "must have Load context")
i = ast.If(ast.Constant(3), [ast.Expr(ast.Name("x", ast.Store()))], [])
self.stmt(i, "must have Load context")
i = ast.If(ast.Constant(3), [ast.Pass()],
[ast.Expr(ast.Name("x", ast.Store()))])
self.stmt(i, "must have Load context")
def test_with(self):
p = ast.Pass()
self.stmt(ast.With([], [p]), "empty items on With")
i = ast.withitem(ast.Constant(3), None)
self.stmt(ast.With([i], []), "empty body on With")
i = ast.withitem(ast.Name("x", ast.Store()), None)
self.stmt(ast.With([i], [p]), "must have Load context")
i = ast.withitem(ast.Constant(3), ast.Name("x", ast.Load()))
self.stmt(ast.With([i], [p]), "must have Store context")
def test_raise(self):
r = ast.Raise(None, ast.Constant(3))
self.stmt(r, "Raise with cause but no exception")
r = ast.Raise(ast.Name("x", ast.Store()), None)
self.stmt(r, "must have Load context")
r = ast.Raise(ast.Constant(4), ast.Name("x", ast.Store()))
self.stmt(r, "must have Load context")
def test_try(self):
p = ast.Pass()
t = ast.Try([], [], [], [p])
self.stmt(t, "empty body on Try")
t = ast.Try([ast.Expr(ast.Name("x", ast.Store()))], [], [], [p])
self.stmt(t, "must have Load context")
t = ast.Try([p], [], [], [])
self.stmt(t, "Try has neither except handlers nor finalbody")
t = ast.Try([p], [], [p], [p])
self.stmt(t, "Try has orelse but no except handlers")
t = ast.Try([p], [ast.ExceptHandler(None, "x", [])], [], [])
self.stmt(t, "empty body on ExceptHandler")
e = [ast.ExceptHandler(ast.Name("x", ast.Store()), "y", [p])]
self.stmt(ast.Try([p], e, [], []), "must have Load context")
e = [ast.ExceptHandler(None, "x", [p])]
t = ast.Try([p], e, [ast.Expr(ast.Name("x", ast.Store()))], [p])
self.stmt(t, "must have Load context")
t = ast.Try([p], e, [p], [ast.Expr(ast.Name("x", ast.Store()))])
self.stmt(t, "must have Load context")
def test_try_star(self):
p = ast.Pass()
t = ast.TryStar([], [], [], [p])
self.stmt(t, "empty body on TryStar")
t = ast.TryStar([ast.Expr(ast.Name("x", ast.Store()))], [], [], [p])
self.stmt(t, "must have Load context")
t = ast.TryStar([p], [], [], [])
self.stmt(t, "TryStar has neither except handlers nor finalbody")
t = ast.TryStar([p], [], [p], [p])
self.stmt(t, "TryStar has orelse but no except handlers")
t = ast.TryStar([p], [ast.ExceptHandler(None, "x", [])], [], [])
self.stmt(t, "empty body on ExceptHandler")
e = [ast.ExceptHandler(ast.Name("x", ast.Store()), "y", [p])]
self.stmt(ast.TryStar([p], e, [], []), "must have Load context")
e = [ast.ExceptHandler(None, "x", [p])]
t = ast.TryStar([p], e, [ast.Expr(ast.Name("x", ast.Store()))], [p])
self.stmt(t, "must have Load context")
t = ast.TryStar([p], e, [p], [ast.Expr(ast.Name("x", ast.Store()))])
self.stmt(t, "must have Load context")
def test_assert(self):
self.stmt(ast.Assert(ast.Name("x", ast.Store()), None),
"must have Load context")
assrt = ast.Assert(ast.Name("x", ast.Load()),
ast.Name("y", ast.Store()))
self.stmt(assrt, "must have Load context")
def test_import(self):
self.stmt(ast.Import([]), "empty names on Import")
def test_importfrom(self):
imp = ast.ImportFrom(None, [ast.alias("x", None)], -42)
self.stmt(imp, "Negative ImportFrom level")
self.stmt(ast.ImportFrom(None, [], 0), "empty names on ImportFrom")
def test_global(self):
self.stmt(ast.Global([]), "empty names on Global")
def test_nonlocal(self):
self.stmt(ast.Nonlocal([]), "empty names on Nonlocal")
def test_expr(self):
e = ast.Expr(ast.Name("x", ast.Store()))
self.stmt(e, "must have Load context")
def test_boolop(self):
b = ast.BoolOp(ast.And(), [])
self.expr(b, "less than 2 values")
b = ast.BoolOp(ast.And(), [ast.Constant(3)])
self.expr(b, "less than 2 values")
b = ast.BoolOp(ast.And(), [ast.Constant(4), None])
self.expr(b, "None disallowed")
b = ast.BoolOp(ast.And(), [ast.Constant(4), ast.Name("x", ast.Store())])
self.expr(b, "must have Load context")
def test_unaryop(self):
u = ast.UnaryOp(ast.Not(), ast.Name("x", ast.Store()))
self.expr(u, "must have Load context")
def test_lambda(self):
a = ast.arguments([], [], None, [], [], None, [])
self.expr(ast.Lambda(a, ast.Name("x", ast.Store())),
"must have Load context")
def fac(args):
return ast.Lambda(args, ast.Name("x", ast.Load()))
self._check_arguments(fac, self.expr)
def test_ifexp(self):
    loaded = ast.Name("x", ast.Load())
    stored = ast.Name("y", ast.Store())
    # Put the Store-context name in each of the three IfExp slots in turn.
    for slot in range(3):
        operands = [loaded, loaded, loaded]
        operands[slot] = stored
        self.expr(ast.IfExp(*operands), "must have Load context")
def test_dict(self):
    # One value but no keys.
    unbalanced = ast.Dict([], [ast.Name("x", ast.Load())])
    self.expr(unbalanced, "same number of keys as values")

    # A key paired with a None value.
    none_value = ast.Dict([ast.Name("x", ast.Load())], [None])
    self.expr(none_value, "None disallowed")
def test_set(self):
    # A None element, then a Store-context element.
    none_element = ast.Set([None])
    self.expr(none_element, "None disallowed")

    stored_element = ast.Set([ast.Name("x", ast.Store())])
    self.expr(stored_element, "must have Load context")
def _check_comprehension(self, fac):
    # `fac` is called with a list of ast.comprehension nodes and returns the
    # expression that is then passed to self.expr.
    def load(ident):
        return ast.Name(ident, ast.Load())

    def store(ident):
        return ast.Name(ident, ast.Store())

    self.expr(fac([]), "comprehension with no generators")

    # Target in Load context.
    gen = ast.comprehension(load("x"), load("x"), [], 0)
    self.expr(fac([gen]), "must have Store context")

    # Iterable in Store context.
    gen = ast.comprehension(store("x"), store("x"), [], 0)
    self.expr(fac([gen]), "must have Load context")

    target = store("x")
    iterable = load("y")

    # None among the `if` conditions.
    gen = ast.comprehension(target, iterable, [None], 0)
    self.expr(fac([gen]), "None disallowed")

    # Store-context name among the `if` conditions.
    gen = ast.comprehension(target, iterable, [store("x")], 0)
    self.expr(fac([gen]), "must have Load context")
def _simple_comp(self, fac):
    # `fac` is a comprehension node class taking (element, generators).
    target = ast.Name("x", ast.Store())
    iterable = ast.Name("x", ast.Load())
    generator = ast.comprehension(target, iterable, [], 0)
    stored_element = ast.Name("x", ast.Store())
    self.expr(fac(stored_element, [generator]), "must have Load context")

    def with_generators(gens):
        # Same Store-context element, caller-supplied generators.
        return fac(ast.Name("x", ast.Store()), gens)

    self._check_comprehension(with_generators)
def test_listcomp(self):
    # Run the shared single-element comprehension checks for ListComp.
    node_class = ast.ListComp
    self._simple_comp(node_class)
def test_setcomp(self):
    # Run the shared single-element comprehension checks for SetComp.
    node_class = ast.SetComp
    self._simple_comp(node_class)
def test_generatorexp(self):
    # Run the shared single-element comprehension checks for GeneratorExp.
    node_class = ast.GeneratorExp
    self._simple_comp(node_class)
def test_dictcomp(self):
    generator = ast.comprehension(ast.Name("y", ast.Store()),
                                  ast.Name("p", ast.Load()), [], 0)

    # Key in Store context.
    stored_key = ast.DictComp(ast.Name("x", ast.Store()),
                              ast.Name("y", ast.Load()),
                              [generator])
    self.expr(stored_key, "must have Load context")

    # Value in Store context.
    stored_value = ast.DictComp(ast.Name("x", ast.Load()),
                                ast.Name("y", ast.Store()),
                                [generator])
    self.expr(stored_value, "must have Load context")

    def make_dictcomp(comps):
        # Load-context key and value around caller-supplied generators.
        key = ast.Name("x", ast.Load())
        value = ast.Name("y", ast.Load())
        return ast.DictComp(key, value, comps)

    self._check_comprehension(make_dictcomp)
def test_yield(self):
    # Yield and YieldFrom each wrapping a Store-context name.
    for node_class in (ast.Yield, ast.YieldFrom):
        stored_value = ast.Name("x", ast.Store())
        self.expr(node_class(stored_value), "must have Load")
def test_compare(self):
    name_x = ast.Name("x", ast.Load())

    # One operator, zero comparators.
    no_comparators = ast.Compare(name_x, [ast.In()], [])
    self.expr(no_comparators, "no comparators")

    # One operator, two comparators.
    two_comparators = [ast.Constant(4), ast.Constant(5)]
    mismatched = ast.Compare(name_x, [ast.In()], two_comparators)
    self.expr(mismatched, "different number of comparators and operands")

    # A string constant on either side; self.expr is called without a message.
    blah_left = ast.Compare(ast.Constant("blah"), [ast.In()], [name_x])
    self.expr(blah_left)
    blah_right = ast.Compare(name_x, [ast.In()], [ast.Constant("blah")])
    self.expr(blah_right)
def test_call(self):
    callee = ast.Name("x", ast.Load())
    positional = [ast.Name("y", ast.Load())]
    good_keywords = [ast.keyword("w", ast.Name("z", ast.Load()))]

    # Callee in Store context.
    stored_callee = ast.Call(ast.Name("x", ast.Store()),
                             positional, good_keywords)
    self.expr(stored_callee, "must have Load context")

    # None as a positional argument.
    none_argument = ast.Call(callee, [None], good_keywords)
    self.expr(none_argument, "None disallowed")

    # Keyword value in Store context.
    stored_keywords = [ast.keyword("w", ast.Name("z", ast.Store()))]
    stored_keyword_call = ast.Call(callee, positional, stored_keywords)
    self.expr(stored_keyword_call, "must have Load context")
def test_attribute(self):
    # Attribute access on a Store-context name.
    stored_base = ast.Name("x", ast.Store())
    attribute = ast.Attribute(stored_base, "y", ast.Load())
    self.expr(attribute, "must have Load context")
def test_subscript(self):
    # Subscripted value in Store context.
    stored_value = ast.Subscript(ast.Name("x", ast.Store()),
                                 ast.Constant(3), ast.Load())
    self.expr(stored_value, "must have Load context")

    base = ast.Name("x", ast.Load())

    # Index in Store context.
    stored_index = ast.Subscript(base, ast.Name("y", ast.Store()),
                                 ast.Load())
    self.expr(stored_index, "must have Load context")

    stored = ast.Name("x", ast.Store())

    # Store-context name as lower, upper and step of a Slice in turn.
    for slot in range(3):
        bounds = [None, None, None]
        bounds[slot] = stored
        slice_node = ast.Slice(*bounds)
        self.expr(ast.Subscript(base, slice_node, ast.Load()),
                  "must have Load context")

    # Empty tuple index; self.expr is called without a message.
    empty_tuple = ast.Tuple([], ast.Load())
    self.expr(ast.Subscript(base, empty_tuple, ast.Load()))

    # Tuple index containing the Store-context name.
    stored_tuple = ast.Tuple([stored], ast.Load())
    self.expr(ast.Subscript(base, stored_tuple, ast.Load()),
              "must have Load context")
def test_starred(self):
    # Assignment target list holding a Starred whose inner name is in Load
    # context while the Starred and the List are in Store context.
    starred = ast.Starred(ast.Name("x", ast.Load()), ast.Store())
    target = ast.List([starred], ast.Store())
    assignment = ast.Assign([target], ast.Constant(4))
    self.stmt(assignment, "must have Store context")
def _sequence(self, fac):
    # `fac` is a sequence node class taking (elements, context).
    with_none = fac([None], ast.Load())
    self.expr(with_none, "None disallowed")

    with_stored = fac([ast.Name("x", ast.Store())], ast.Load())
    self.expr(with_stored, "must have Load context")
def test_list(self):
    # Run the shared sequence checks for List.
    node_class = ast.List
    self._sequence(node_class)
def test_tuple(self):
    # Run the shared sequence checks for Tuple.
    node_class = ast.Tuple
    self._sequence(node_class)
@support.requires_resource('cpu')
def test_stdlib_validates(self):
    # For every listed stdlib file: parse it, compile the resulting tree,
    # then parse again and check that compiling left the first tree equal
    # to a freshly parsed one.
    for module in STDLIB_FILES:
        with self.subTest(module):
            path = os.path.join(STDLIB, module)
            with open(path, "r", encoding="utf-8") as stream:
                text = stream.read()
            compiled_tree = ast.parse(text, path)
            compile(compiled_tree, path, "exec")
            fresh_tree = ast.parse(text, path)
            self.assertTrue(ast.compare(compiled_tree, fresh_tree))
# Shared nodes reused by the pattern table below.
constant_1 = ast.Constant(1)
pattern_1 = ast.MatchValue(constant_1)
constant_x = ast.Constant('x')
pattern_x = ast.MatchValue(constant_x)
constant_true = ast.Constant(True)
pattern_true = ast.MatchSingleton(True)
name_carter = ast.Name('carter', ast.Load())

# Patterns consumed one at a time by test_match_validation_pattern, which
# wraps each in a `match` statement and expects compile() to raise ValueError.
_MATCH_PATTERNS = [
    # Attribute chain whose innermost Name is in Store context.
    ast.MatchValue(
        ast.Attribute(
            ast.Attribute(
                ast.Name('x', ast.Store()),
                'y', ast.Load()
            ),
            'z', ast.Load()
        )
    ),
    # Attribute chain whose inner Attribute is in Store context.
    ast.MatchValue(
        ast.Attribute(
            ast.Attribute(
                ast.Name('x', ast.Load()),
                'y', ast.Store()
            ),
            'z', ast.Load()
        )
    ),
    # Value patterns built from Ellipsis, True and a tuple constant.
    ast.MatchValue(
        ast.Constant(...)
    ),
    ast.MatchValue(
        ast.Constant(True)
    ),
    ast.MatchValue(
        ast.Constant((1,2,3))
    ),
    # MatchSingleton holding a string, bare and nested in sequences.
    ast.MatchSingleton('string'),
    ast.MatchSequence([
        ast.MatchSingleton('string')
    ]),
    ast.MatchSequence(
        [
            ast.MatchSequence(
                [
                    ast.MatchSingleton('string')
                ]
            )
        ]
    ),
    # Mapping with two keys but a single pattern.
    ast.MatchMapping(
        [constant_1, constant_true],
        [pattern_x]
    ),
    # Mapping whose rest name is spelled 'True'.
    ast.MatchMapping(
        [constant_true, constant_1],
        [pattern_x, pattern_1],
        rest='True'
    ),
    # Mapping with a Starred expression as one of its keys.
    ast.MatchMapping(
        [constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())],
        [pattern_x, pattern_1],
        rest='legit'
    ),
    # Class pattern whose class expression is an attribute chain rooted at
    # a Constant.
    ast.MatchClass(
        ast.Attribute(
            ast.Attribute(
                constant_x,
                'y', ast.Load()),
            'z', ast.Load()),
        patterns=[], kwd_attrs=[], kwd_patterns=[]
    ),
    # Keyword attribute spelled 'True'.
    ast.MatchClass(
        name_carter,
        patterns=[],
        kwd_attrs=['True'],
        kwd_patterns=[pattern_1]
    ),
    # One keyword pattern but no keyword attribute names.
    ast.MatchClass(
        name_carter,
        patterns=[],
        kwd_attrs=[],
        kwd_patterns=[pattern_1]
    ),
    # String MatchSingleton as a positional sub-pattern.
    ast.MatchClass(
        name_carter,
        patterns=[ast.MatchSingleton('string')],
        kwd_attrs=[],
        kwd_patterns=[]
    ),
    # MatchStar as a positional sub-pattern of a class pattern.
    ast.MatchClass(
        name_carter,
        patterns=[ast.MatchStar()],
        kwd_attrs=[],
        kwd_patterns=[]
    ),
    # MatchStar as a keyword sub-pattern of a class pattern.
    ast.MatchClass(
        name_carter,
        patterns=[],
        kwd_attrs=[],
        kwd_patterns=[ast.MatchStar()]
    ),
    ast.MatchClass(
        constant_true,  # invalid name
        patterns=[],
        kwd_attrs=['True'],
        kwd_patterns=[pattern_1]
    ),
    # Star pattern named 'True'.
    ast.MatchSequence(
        [
            ast.MatchStar("True")
        ]
    ),
    # Capture pattern named 'False'.
    ast.MatchAs(
        name='False'
    ),
    # Or-patterns with zero and with one alternative.
    ast.MatchOr(
        []
    ),
    ast.MatchOr(
        [pattern_1]
    ),
    # Or-pattern whose last alternative is a string MatchSingleton.
    ast.MatchOr(
        [pattern_1, pattern_x, ast.MatchSingleton('xxx')]
    ),
    # Explicit "_" names, and a MatchStar used directly as the case pattern.
    ast.MatchAs(name="_"),
    ast.MatchStar(name="x"),
    ast.MatchSequence([ast.MatchStar("_")]),
    ast.MatchMapping([], [], rest="_"),
]
def test_match_validation_pattern(self):
name_x = ast.Name('x', ast.Load())
for pattern in self._MATCH_PATTERNS:
with self.subTest(ast.dump(pattern, indent=4)):
node = ast.Match(
subject=name_x,
cases = [
ast.match_case(
pattern=pattern,
body = [ast.Pass()]
)
]
)
node = ast.fix_missing_locations(node)
module = ast.Module([node], [])
with self.assertRaises(ValueError):
compile(module, "<test>", "exec")
class ConstantTests(unittest.TestCase):
    """Tests on the ast.Constant node type."""

    def compile_constant(self, value):
        # Parse "x = 123", replace the right-hand side with a Constant
        # holding `value`, compile and execute the tree, and return what
        # ended up bound to x.
        tree = ast.parse("x = 123")

        node = tree.body[0].value
        new_node = ast.Constant(value=value)
        ast.copy_location(new_node, node)
        tree.body[0].value = new_node

        code = compile(tree, "<string>", "exec")

        ns = {}
        exec(code, ns)
        return ns['x']

    def test_validation(self):
        # A list is not accepted as a Constant value.
        with self.assertRaises(TypeError) as cm:
            self.compile_constant([1, 2, 3])
        self.assertEqual(str(cm.exception),
                         "got an invalid type in Constant: list")

    def test_singletons(self):
        # These values must come back as the very same object.
        for const in (None, False, True, Ellipsis, b'', frozenset()):
            with self.subTest(const=const):
                value = self.compile_constant(const)
                self.assertIs(value, const)

    def test_values(self):
        # Build three levels of nesting for tuples and frozensets.
        nested_tuple = (1,)
        nested_frozenset = frozenset({1})
        for _ in range(3):
            nested_tuple = (nested_tuple, 2)
            nested_frozenset = frozenset({nested_frozenset, 2})
        values = (123, 123.0, 123j,
                  "unicode", b'bytes',
                  tuple("tuple"), frozenset("frozenset"),
                  nested_tuple, nested_frozenset)
        for value in values:
            with self.subTest(value=value):
                result = self.compile_constant(value)
                self.assertEqual(result, value)

    def test_assign_to_constant(self):
        # Replace the assignment target with a Constant node.
        tree = ast.parse("x = 1")

        target = tree.body[0].targets[0]
        new_target = ast.Constant(value=1)
        ast.copy_location(new_target, target)
        tree.body[0].targets[0] = new_target

        with self.assertRaises(ValueError) as cm:
            compile(tree, "string", "exec")
        self.assertEqual(str(cm.exception),
                         "expression which can't be assigned "
                         "to in Store context")

    def test_get_docstring(self):
        tree = ast.parse("'docstring'\nx = 1")
        self.assertEqual(ast.get_docstring(tree), 'docstring')

    def get_load_const(self, tree):
        # Compile to bytecode, disassemble and get parameter of LOAD_CONST
        # instructions
        co = compile(tree, '<string>', 'exec')
        consts = []
        for instr in dis.get_instructions(co):
            if instr.opname in ('LOAD_CONST', 'RETURN_CONST'):
                consts.append(instr.argval)
        return consts

    @support.cpython_only
    def test_load_const(self):
        consts = [None,
                  True, False,
                  124,
                  2.0,
                  3j,
                  "unicode",
                  b'bytes',
                  (1, 2, 3)]

        code = '\n'.join(['x={!r}'.format(const) for const in consts])
        code += '\nx = ...'
        consts.extend((Ellipsis, None))

        tree = ast.parse(code)
        self.assertEqual(self.get_load_const(tree),
                         consts)

        # Replace expression nodes with constants
        for assign, const in zip(tree.body, consts):
            # A unittest assertion rather than a bare `assert`, so the
            # check is not stripped when the suite runs under -O.
            self.assertIsInstance(assign, ast.Assign, ast.dump(assign))
            new_node = ast.Constant(value=const)
            ast.copy_location(new_node, assign.value)
            assign.value = new_node

        self.assertEqual(self.get_load_const(tree),
                         consts)

    def test_literal_eval(self):
        # Swap both operands of "1 + 2" for new constants and evaluate.
        tree = ast.parse("1 + 2")
        binop = tree.body[0].value

        new_left = ast.Constant(value=10)
        ast.copy_location(new_left, binop.left)
        binop.left = new_left

        new_right = ast.Constant(value=20j)
        ast.copy_location(new_right, binop.right)
        binop.right = new_right

        self.assertEqual(ast.literal_eval(binop), 10+20j)

    def test_string_kind(self):
        # Only the u"" prefix is reported through Constant.kind.
        c = ast.parse('"x"', mode='eval').body
        self.assertEqual(c.value, "x")
        self.assertEqual(c.kind, None)

        c = ast.parse('u"x"', mode='eval').body
        self.assertEqual(c.value, "x")
        self.assertEqual(c.kind, "u")

        c = ast.parse('r"x"', mode='eval').body
        self.assertEqual(c.value, "x")
        self.assertEqual(c.kind, None)

        c = ast.parse('b"x"', mode='eval').body
        self.assertEqual(c.value, b"x")
        self.assertEqual(c.kind, None)
class EndPositionTests(unittest.TestCase):
"""Tests for end position of AST nodes.
Testing end positions of nodes requires a bit of extra care
because of how LL parsers work.
"""
def _check_end_pos(self, ast_node, end_lineno, end_col_offset):
self.assertEqual(ast_node.end_lineno, end_lineno)
self.assertEqual(ast_node.end_col_offset, end_col_offset)
def _check_content(self, source, ast_node, content):
self.assertEqual(ast.get_source_segment(source, ast_node), content)
def _parse_value(self, s):
# Use duck-typing to support both single expression
# and a right hand side of an assignment statement.
return ast.parse(s).body[0].value
def test_lambda(self):
s = 'lambda x, *y: None'
lam = self._parse_value(s)
self._check_content(s, lam.body, 'None')
self._check_content(s, lam.args.args[0], 'x')
self._check_content(s, lam.args.vararg, 'y')
def test_func_def(self):
s = dedent('''
def func(x: int,
*args: str,
z: float = 0,
**kwargs: Any) -> bool:
return True
''').strip()
fdef = ast.parse(s).body[0]
self._check_end_pos(fdef, 5, 15)
self._check_content(s, fdef.body[0], 'return True')
self._check_content(s, fdef.args.args[0], 'x: int')
self._check_content(s, fdef.args.args[0].annotation, 'int')
self._check_content(s, fdef.args.kwarg, 'kwargs: Any')
self._check_content(s, fdef.args.kwarg.annotation, 'Any')
def test_call(self):
s = 'func(x, y=2, **kw)'
call = self._parse_value(s)
self._check_content(s, call.func, 'func')
self._check_content(s, call.keywords[0].value, '2')
self._check_content(s, call.keywords[1].value, 'kw')
def test_call_noargs(self):
s = 'x[0]()'
call = self._parse_value(s)
self._check_content(s, call.func, 'x[0]')
self._check_end_pos(call, 1, 6)
def test_class_def(self):
s = dedent('''
class C(A, B):
x: int = 0
''').strip()
cdef = ast.parse(s).body[0]
self._check_end_pos(cdef, 2, 14)
self._check_content(s, cdef.bases[1], 'B')
self._check_content(s, cdef.body[0], 'x: int = 0')
def test_class_kw(self):
s = 'class S(metaclass=abc.ABCMeta): pass'
cdef = ast.parse(s).body[0]
self._check_content(s, cdef.keywords[0].value, 'abc.ABCMeta')
def test_multi_line_str(self):
s = dedent('''
x = """Some multi-line text.
It goes on starting from same indent."""
''').strip()
assign = ast.parse(s).body[0]
self._check_end_pos(assign, 3, 40)
self._check_end_pos(assign.value, 3, 40)
def test_continued_str(self):
s = dedent('''
x = "first part" \\
"second part"
''').strip()
assign = ast.parse(s).body[0]
self._check_end_pos(assign, 2, 13)
self._check_end_pos(assign.value, 2, 13)
def test_suites(self):
# We intentionally put these into the same string to check
# that empty lines are not part of the suite.
s = dedent('''
while True:
pass
if one():
x = None
elif other():
y = None
else:
z = None
for x, y in stuff:
assert True
try:
raise RuntimeError
except TypeError as e:
pass
pass
''').strip()
mod = ast.parse(s)
while_loop = mod.body[0]
if_stmt = mod.body[1]
for_loop = mod.body[2]
try_stmt = mod.body[3]
pass_stmt = mod.body[4]
self._check_end_pos(while_loop, 2, 8)
self._check_end_pos(if_stmt, 9, 12)
self._check_end_pos(for_loop, 12, 15)
self._check_end_pos(try_stmt, 17, 8)
self._check_end_pos(pass_stmt, 19, 4)
self._check_content(s, while_loop.test, 'True')
self._check_content(s, if_stmt.body[0], 'x = None')
self._check_content(s, if_stmt.orelse[0].test, 'other()')
self._check_content(s, for_loop.target, 'x, y')
self._check_content(s, try_stmt.body[0], 'raise RuntimeError')
self._check_content(s, try_stmt.handlers[0].type, 'TypeError')
def test_fstring(self):
s = 'x = f"abc {x + y} abc"'
fstr = self._parse_value(s)
binop = fstr.values[1].value
self._check_content(s, binop, 'x + y')
def test_fstring_multi_line(self):
s = dedent('''
f"""Some multi-line text.
{
arg_one
+
arg_two
}
It goes on..."""
''').strip()
fstr = self._parse_value(s)
binop = fstr.values[1].value
self._check_end_pos(binop, 5, 7)
self._check_content(s, binop.left, 'arg_one')
self._check_content(s, binop.right, 'arg_two')
def test_import_from_multi_line(self):
s = dedent('''
from x.y.z import (
a, b, c as c
)
''').strip()
imp = ast.parse(s).body[0]
self._check_end_pos(imp, 3, 1)
self._check_end_pos(imp.names[2], 2, 16)
def test_slices(self):
s1 = 'f()[1, 2] [0]'
s2 = 'x[ a.b: c.d]'
sm = dedent('''
x[ a.b: f () ,
g () : c.d
]
''').strip()
i1, i2, im = map(self._parse_value, (s1, s2, sm))
self._check_content(s1, i1.value, 'f()[1, 2]')
self._check_content(s1, i1.value.slice, '1, 2')
self._check_content(s2, i2.slice.lower, 'a.b')
self._check_content(s2, i2.slice.upper, 'c.d')
self._check_content(sm, im.slice.elts[0].upper, 'f ()')
self._check_content(sm, im.slice.elts[1].lower, 'g ()')
self._check_end_pos(im, 3, 3)
def test_binop(self):
s = dedent('''
(1 * 2 + (3 ) +
4
)
''').strip()
binop = self._parse_value(s)
self._check_end_pos(binop, 2, 6)
self._check_content(s, binop.right, '4')
self._check_content(s, binop.left, '1 * 2 + (3 )')
self._check_content(s, binop.left.right, '3')
def test_boolop(self):
s = dedent('''
if (one_condition and
(other_condition or yet_another_one)):
pass
''').strip()
bop = ast.parse(s).body[0].test
self._check_end_pos(bop, 2, 44)
self._check_content(s, bop.values[1],
'other_condition or yet_another_one')
def test_tuples(self):
s1 = 'x = () ;'
s2 = 'x = 1 , ;'
s3 = 'x = (1 , 2 ) ;'
sm = dedent('''
x = (
a, b,
)
''').strip()
t1, t2, t3, tm = map(self._parse_value, (s1, s2, s3, sm))
self._check_content(s1, t1, '()')
self._check_content(s2, t2, '1 ,')
self._check_content(s3, t3, '(1 , 2 )')
self._check_end_pos(tm, 3, 1)
def test_attribute_spaces(self):
s = 'func(x. y .z)'
call = self._parse_value(s)
self._check_content(s, call, s)
self._check_content(s, call.args[0], 'x. y .z')
def test_redundant_parenthesis(self):
s = '( ( ( a + b ) ) )'
v = ast.parse(s).body[0].value
self.assertEqual(type(v).__name__, 'BinOp')
self._check_content(s, v, 'a + b')
s2 = 'await ' + s
v = ast.parse(s2).body[0].value.value
self.assertEqual(type(v).__name__, 'BinOp')
self._check_content(s2, v, 'a + b')
def test_trailers_with_redundant_parenthesis(self):
tests = (
('( ( ( a ) ) ) ( )', 'Call'),
('( ( ( a ) ) ) ( b )', 'Call'),
('( ( ( a ) ) ) [ b ]', 'Subscript'),
('( ( ( a ) ) ) . b', 'Attribute'),
)
for s, t in tests:
with self.subTest(s):
v = ast.parse(s).body[0].value
self.assertEqual(type(v).__name__, t)
self._check_content(s, v, s)
s2 = 'await ' + s
v = ast.parse(s2).body[0].value.value
self.assertEqual(type(v).__name__, t)
self._check_content(s2, v, s)
def test_displays(self):
s1 = '[{}, {1, }, {1, 2,} ]'
s2 = '{a: b, f (): g () ,}'
c1 = self._parse_value(s1)
c2 = self._parse_value(s2)
self._check_content(s1, c1.elts[0], '{}')
self._check_content(s1, c1.elts[1], '{1, }')
self._check_content(s1, c1.elts[2], '{1, 2,}')
self._check_content(s2, c2.keys[1], 'f ()')
self._check_content(s2, c2.values[1], 'g ()')
def test_comprehensions(self):
s = dedent('''
x = [{x for x, y in stuff
if cond.x} for stuff in things]
''').strip()
cmp = self._parse_value(s)
self._check_end_pos(cmp, 2, 37)
self._check_content(s, cmp.generators[0].iter, 'things')
self._check_content(s, cmp.elt.generators[0].iter, 'stuff')
self._check_content(s, cmp.elt.generators[0].ifs[0], 'cond.x')
self._check_content(s, cmp.elt.generators[0].target, 'x, y')
def test_yield_await(self):
s = dedent('''
async def f():
yield x
await y
''').strip()
fdef = ast.parse(s).body[0]
self._check_content(s, fdef.body[0].value, 'yield x')
self._check_content(s, fdef.body[1].value, 'await y')
def test_source_segment_multi(self):
s_orig = dedent('''
x = (
a, b,
) + ()
''').strip()
s_tuple = dedent('''
(
a, b,
)
''').strip()
binop = self._parse_value(s_orig)
self.assertEqual(ast.get_source_segment(s_orig, binop.left), s_tuple)
def test_source_segment_padded(self):
s_orig = dedent('''
class C:
def fun(self) -> None:
"ЖЖЖЖЖ"
''').strip()
s_method = ' def fun(self) -> None:\n' \
' "ЖЖЖЖЖ"'
cdef = ast.parse(s_orig).body[0]
self.assertEqual(ast.get_source_segment(s_orig, cdef.body[0], padded=True),
s_method)
def test_source_segment_endings(self):
s = 'v = 1\r\nw = 1\nx = 1\n\ry = 1\rz = 1\r\n'
v, w, x, y, z = ast.parse(s).body
self._check_content(s, v, 'v = 1')
self._check_content(s, w, 'w = 1')
self._check_content(s, x, 'x = 1')
self._check_content(s, y, 'y = 1')
self._check_content(s, z, 'z = 1')
def test_source_segment_tabs(self):
s = dedent('''
class C:
\t\f def fun(self) -> None:
\t\f pass
''').strip()
s_method = ' \t\f def fun(self) -> None:\n' \
' \t\f pass'
cdef = ast.parse(s).body[0]
self.assertEqual(ast.get_source_segment(s, cdef.body[0], padded=True), s_method)
def test_source_segment_newlines(self):
s = 'def f():\n pass\ndef g():\r pass\r\ndef h():\r\n pass\r\n'
f, g, h = ast.parse(s).body
self._check_content(s, f, 'def f():\n pass')
self._check_content(s, g, 'def g():\r pass')
self._check_content(s, h, 'def h():\r\n pass')
s = 'def f():\n a = 1\r b = 2\r\n c = 3\n'
f = ast.parse(s).body[0]
self._check_content(s, f, s.rstrip())
def test_source_segment_missing_info(self):
s = 'v = 1\r\nw = 1\nx = 1\n\ry = 1\r\n'
v, w, x, y = ast.parse(s).body
del v.lineno
del w.end_lineno
del x.col_offset
del y.end_col_offset
self.assertIsNone(ast.get_source_segment(s, v))
self.assertIsNone(ast.get_source_segment(s, w))
self.assertIsNone(ast.get_source_segment(s, x))
self.assertIsNone(ast.get_source_segment(s, y))
class NodeTransformerTests(ASTTestMixin, unittest.TestCase):
    def assertASTTransformation(self, transformer_class,
                                initial_code, expected_code):
        # Run a fresh transformer over the parsed initial code and compare
        # the location-fixed result with the parsed expected code.
        tree_before = ast.parse(dedent(initial_code))
        tree_wanted = ast.parse(dedent(expected_code))

        visitor = transformer_class()
        tree_after = ast.fix_missing_locations(visitor.visit(tree_before))

        self.assertASTEqual(tree_after, tree_wanted)

    def test_node_remove_single(self):
        before = 'def func(arg) -> SomeType: ...'
        after = 'def func(arg): ...'

        # Since `FunctionDef.returns` is defined as a single value, we test
        # the `if isinstance(old_value, AST):` branch here.
        class SomeTypeRemover(ast.NodeTransformer):
            def visit_Name(self, node: ast.Name):
                self.generic_visit(node)
                return None if node.id == 'SomeType' else node

        self.assertASTTransformation(SomeTypeRemover, before, after)

    def test_node_remove_from_list(self):
        before = """
        def func(arg):
            print(arg)
            yield arg
        """
        after = """
        def func(arg):
            print(arg)
        """

        # Since `FunctionDef.body` is defined as a list, we test
        # the `if isinstance(old_value, list):` branch here.
        class YieldRemover(ast.NodeTransformer):
            def visit_Expr(self, node: ast.Expr):
                self.generic_visit(node)
                # Remove `yield` from a function
                return None if isinstance(node.value, ast.Yield) else node

        self.assertASTTransformation(YieldRemover, before, after)

    def test_node_return_list(self):
        before = """
        class DSL(Base, kw1=True): ...
        """
        after = """
        class DSL(Base, kw1=True, kw2=True, kw3=False): ...
        """

        class ExtendKeywords(ast.NodeTransformer):
            def visit_keyword(self, node: ast.keyword):
                self.generic_visit(node)
                if node.arg != 'kw1':
                    return node
                # Returning a list splices two extra keywords after kw1.
                extra = [
                    ast.keyword('kw2', ast.Constant(True)),
                    ast.keyword('kw3', ast.Constant(False)),
                ]
                return [node, *extra]

        self.assertASTTransformation(ExtendKeywords, before, after)

    def test_node_mutate(self):
        before = """
        def func(arg):
            print(arg)
        """
        after = """
        def func(arg):
            log(arg)
        """

        class PrintToLog(ast.NodeTransformer):
            def visit_Call(self, node: ast.Call):
                self.generic_visit(node)
                callee = node.func
                # Rename the callee in place; the Call node itself is kept.
                if isinstance(callee, ast.Name) and callee.id == 'print':
                    callee.id = 'log'
                return node

        self.assertASTTransformation(PrintToLog, before, after)

    def test_node_replace(self):
        before = """
        def func(arg):
            print(arg)
        """
        after = """
        def func(arg):
            logger.log(arg, debug=True)
        """

        class PrintToLog(ast.NodeTransformer):
            def visit_Call(self, node: ast.Call):
                self.generic_visit(node)
                callee = node.func
                if not (isinstance(callee, ast.Name) and callee.id == 'print'):
                    return node
                # Build a brand-new call that reuses the original arguments.
                logger_log = ast.Attribute(
                    ast.Name('logger', ctx=ast.Load()),
                    attr='log',
                    ctx=ast.Load(),
                )
                debug_kw = ast.keyword('debug', ast.Constant(True))
                return ast.Call(
                    func=logger_log,
                    args=node.args,
                    keywords=[debug_kw],
                )

        self.assertASTTransformation(PrintToLog, before, after)
class ASTConstructorTests(unittest.TestCase):
"""Test the autogenerated constructors for AST nodes."""
def test_FunctionDef(self):
args = ast.arguments()
self.assertEqual(args.args, [])
self.assertEqual(args.posonlyargs, [])
with self.assertWarnsRegex(DeprecationWarning,
r"FunctionDef\.__init__ missing 1 required positional argument: 'name'"):
node = ast.FunctionDef(args=args)
self.assertFalse(hasattr(node, "name"))
self.assertEqual(node.decorator_list, [])
node = ast.FunctionDef(name='foo', args=args)
self.assertEqual(node.name, 'foo')
self.assertEqual(node.decorator_list, [])
def test_expr_context(self):
name = ast.Name("x")
self.assertEqual(name.id, "x")
self.assertIsInstance(name.ctx, ast.Load)
name2 = ast.Name("x", ast.Store())
self.assertEqual(name2.id, "x")
self.assertIsInstance(name2.ctx, ast.Store)
name3 = ast.Name("x", ctx=ast.Del())
self.assertEqual(name3.id, "x")
self.assertIsInstance(name3.ctx, ast.Del)
with self.assertWarnsRegex(DeprecationWarning,
r"Name\.__init__ missing 1 required positional argument: 'id'"):
name3 = ast.Name()
def test_custom_subclass_with_no_fields(self):
class NoInit(ast.AST):
pass
obj = NoInit()
self.assertIsInstance(obj, NoInit)
self.assertEqual(obj.__dict__, {})
def test_fields_but_no_field_types(self):
class Fields(ast.AST):
_fields = ('a',)
obj = Fields()
with self.assertRaises(AttributeError):
obj.a
obj = Fields(a=1)
self.assertEqual(obj.a, 1)
def test_fields_and_types(self):
class FieldsAndTypes(ast.AST):
_fields = ('a',)
_field_types = {'a': int | None}
a: int | None = None
obj = FieldsAndTypes()
self.assertIs(obj.a, None)
obj = FieldsAndTypes(a=1)
self.assertEqual(obj.a, 1)
def test_custom_attributes(self):
class MyAttrs(ast.AST):
_attributes = ("a", "b")
obj = MyAttrs(a=1, b=2)
self.assertEqual(obj.a, 1)
self.assertEqual(obj.b, 2)
with self.assertWarnsRegex(DeprecationWarning,
r"MyAttrs.__init__ got an unexpected keyword argument 'c'."):
obj = MyAttrs(c=3)
def test_fields_and_types_no_default(self):
class FieldsAndTypesNoDefault(ast.AST):
_fields = ('a',)
_field_types = {'a': int}
with self.assertWarnsRegex(DeprecationWarning,
r"FieldsAndTypesNoDefault\.__init__ missing 1 required positional argument: 'a'\."):
obj = FieldsAndTypesNoDefault()
with self.assertRaises(AttributeError):
obj.a
obj = FieldsAndTypesNoDefault(a=1)
self.assertEqual(obj.a, 1)
def test_incomplete_field_types(self):
class MoreFieldsThanTypes(ast.AST):
_fields = ('a', 'b')
_field_types = {'a': int | None}
a: int | None = None
b: int | None = None
with self.assertWarnsRegex(
DeprecationWarning,
r"Field 'b' is missing from MoreFieldsThanTypes\._field_types"
):
obj = MoreFieldsThanTypes()
self.assertIs(obj.a, None)
self.assertIs(obj.b, None)
obj = MoreFieldsThanTypes(a=1, b=2)
self.assertEqual(obj.a, 1)
self.assertEqual(obj.b, 2)
def test_complete_field_types(self):
class _AllFieldTypes(ast.AST):
_fields = ('a', 'b')
_field_types = {'a': int | None, 'b': list[str]}
# This must be set explicitly
a: int | None = None
# This will add an implicit empty list default
b: list[str]
obj = _AllFieldTypes()
self.assertIs(obj.a, None)
self.assertEqual(obj.b, [])
@support.cpython_only
class ModuleStateTests(unittest.TestCase):
    # bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.

    def check_ast_module(self):
        # Check that the _ast module still works as expected
        code = 'x + 1'
        filename = '<string>'
        mode = 'eval'

        # Create _ast.AST subclasses instances
        ast_tree = compile(code, filename, mode, flags=ast.PyCF_ONLY_AST)

        # Call PyAST_Check()
        code = compile(ast_tree, filename, mode)
        self.assertIsInstance(code, types.CodeType)

    def test_reload_module(self):
        # bpo-41194: Importing the _ast module twice must not crash.
        with support.swap_item(sys.modules, '_ast', None):
            # Drop the sys.modules entry before each import so that two
            # separate imports of _ast are performed.
            del sys.modules['_ast']
            import _ast as ast1

            del sys.modules['_ast']
            import _ast as ast2

            self.check_ast_module()

        # Unloading the two _ast module instances must not crash.
        del ast1
        del ast2
        support.gc_collect()

        self.check_ast_module()

    def test_sys_modules(self):
        # bpo-41631: Test reproducing a Mercurial crash when PyAST_Check()
        # imported the _ast module internally.
        lazy_mod = object()

        def my_import(name, *args, **kw):
            # Stand-in for __import__: register and return the placeholder
            # object instead of importing anything.
            sys.modules[name] = lazy_mod
            return lazy_mod

        with support.swap_item(sys.modules, '_ast', None):
            del sys.modules['_ast']

            with support.swap_attr(builtins, '__import__', my_import):
                # Test that compile() does not import the _ast module
                self.check_ast_module()
                self.assertNotIn('_ast', sys.modules)

                # Sanity check of the test itself
                import _ast
                self.assertIs(_ast, lazy_mod)

    def test_subinterpreter(self):
        # bpo-41631: Importing and using the _ast module in a subinterpreter
        # must not crash.
        code = dedent('''
            import _ast
            import ast
            import gc
            import sys
            import types

            # Create _ast.AST subclasses instances and call PyAST_Check()
            ast_tree = compile('x+1', '<string>', 'eval',
                               flags=ast.PyCF_ONLY_AST)
            code = compile(ast_tree, 'string', 'eval')
            if not isinstance(code, types.CodeType):
                raise AssertionError

            # Unloading the _ast module must not crash.
            del ast, _ast
            del sys.modules['ast'], sys.modules['_ast']
            gc.collect()
        ''')
        # The script's exit status is expected to be 0.
        res = support.run_in_subinterp(code)
        self.assertEqual(res, 0)
class ASTMainTests(unittest.TestCase):
    # Tests `ast.main()` function.

    def test_cli_file_input(self):
        # Write a one-line module to a temporary directory, run
        # `python -m ast` on it and compare its output with ast.dump().
        source = "print(1, 2, 3)"
        wanted = ast.dump(ast.parse(source), indent=3)

        with os_helper.temp_dir() as workdir:
            path = os.path.join(workdir, "test_module.py")
            with open(path, 'w', encoding='utf-8') as stream:
                stream.write(source)
            result, _ = script_helper.run_python_until_end("-m", "ast", path)

        self.assertEqual(result.err, b"")
        actual_lines = result.out.decode("utf8").splitlines()
        self.assertEqual(wanted.splitlines(), actual_lines)
        self.assertEqual(result.rc, 0)
class ASTOptimiziationTests(unittest.TestCase):
# Binary operator spelling -> operator node; create_binop() looks operators
# up here, and test_folding_binop iterates over the keys.
binop = {
    "+": ast.Add(),
    "-": ast.Sub(),
    "*": ast.Mult(),
    "/": ast.Div(),
    "%": ast.Mod(),
    "<<": ast.LShift(),
    ">>": ast.RShift(),
    "|": ast.BitOr(),
    "^": ast.BitXor(),
    "&": ast.BitAnd(),
    "//": ast.FloorDiv(),
    "**": ast.Pow(),
}

# Unary operator spelling -> operator node, used by test_folding_unaryop.
unaryop = {
    "~": ast.Invert(),
    "+": ast.UAdd(),
    "-": ast.USub(),
}
def wrap_expr(self, expr):
    # Return a Module whose only statement is `expr` as an expression
    # statement.
    statement = ast.Expr(value=expr)
    return ast.Module(body=[statement])
def wrap_statement(self, statement):
    # Return a Module whose body is just `statement`.
    body = [statement]
    return ast.Module(body=body)
def assert_ast(self, code, non_optimized_target, optimized_target):
    # Parse `code` twice, with optimize=-1 and with optimize=1, and check:
    #   * the non-optimized tree equals `non_optimized_target`;
    #   * the optimized tree differs from `non_optimized_target`;
    #   * the optimized tree equals `optimized_target`.
    non_optimized_tree = ast.parse(code, optimize=-1)
    optimized_tree = ast.parse(code, optimize=1)

    # Is the non-optimized tree equal to the non-optimized target?
    self.assertTrue(
        ast.compare(non_optimized_tree, non_optimized_target),
        f"{ast.dump(non_optimized_target)} must equal "
        f"{ast.dump(non_optimized_tree)}",
    )

    # Is the optimized tree different from the non-optimized target?
    # The message dumps the optimized tree, since that is the tree this
    # assertion actually compares.
    self.assertFalse(
        ast.compare(optimized_tree, non_optimized_target),
        f"{ast.dump(non_optimized_target)} must not equal "
        f"{ast.dump(optimized_tree)}"
    )

    # Is the optimized tree equal to the optimized target?
    self.assertTrue(
        ast.compare(optimized_tree, optimized_target),
        f"{ast.dump(optimized_target)} must equal "
        f"{ast.dump(optimized_tree)}",
    )
def create_binop(self, operand, left=None, right=None):
    # Build `left <operand> right` as a BinOp, looking the operator node up
    # in self.binop.  A missing side defaults to a new Constant(1) created
    # on every call, so trees built with the defaults never share operand
    # nodes with each other.
    if left is None:
        left = ast.Constant(1)
    if right is None:
        right = ast.Constant(1)
    return ast.BinOp(left=left, op=self.binop[operand], right=right)
def test_folding_binop(self):
    template = "1 %s 1"
    for symbol in self.binop:
        result_code = template % symbol
        non_optimized_target = self.wrap_expr(self.create_binop(symbol))
        # The folded constant is whatever evaluating the snippet yields.
        folded = ast.Constant(value=eval(result_code))
        optimized_target = self.wrap_expr(folded)

        with self.subTest(
            result_code=result_code,
            non_optimized_target=non_optimized_target,
            optimized_target=optimized_target
        ):
            self.assert_ast(result_code, non_optimized_target, optimized_target)

    # Multiplication of constant tuples must be folded
    code = "(1,) * 3"
    one_tuple = ast.Tuple(elts=[ast.Constant(value=1)])
    product = self.create_binop("*", one_tuple, ast.Constant(value=3))
    non_optimized_target = self.wrap_expr(product)
    optimized_target = self.wrap_expr(ast.Constant(eval(code)))

    self.assert_ast(code, non_optimized_target, optimized_target)
def test_folding_unaryop(self):
    template = "%s1"

    def make_unaryop(symbol):
        # `<symbol>1` as an unfolded UnaryOp node.
        return ast.UnaryOp(op=self.unaryop[symbol], operand=ast.Constant(1))

    for symbol in self.unaryop:
        result_code = template % symbol
        non_optimized_target = self.wrap_expr(make_unaryop(symbol))
        folded = ast.Constant(eval(result_code))
        optimized_target = self.wrap_expr(folded)

        with self.subTest(
            result_code=result_code,
            non_optimized_target=non_optimized_target,
            optimized_target=optimized_target
        ):
            self.assert_ast(result_code, non_optimized_target, optimized_target)
def test_folding_not(self):
    template = "not (1 %s (1,))"
    # (keyword, operator in the source, operator expected after folding)
    table = (
        ("in", ast.In(), ast.NotIn()),
        ("is", ast.Is(), ast.IsNot()),
    )

    for keyword, plain_op, negated_op in table:
        result_code = template % keyword

        # Unfolded form: `not` wrapped around a comparison with a Tuple.
        comparison = ast.Compare(
            left=ast.Constant(value=1),
            ops=[plain_op],
            comparators=[ast.Tuple(elts=[ast.Constant(value=1)])]
        )
        negation = ast.UnaryOp(op=ast.Not(), operand=comparison)
        non_optimized_target = self.wrap_expr(negation)

        # Folded form: negated operator against a constant tuple.
        folded = ast.Compare(
            left=ast.Constant(1),
            ops=[negated_op],
            comparators=[ast.Constant(value=(1,))]
        )
        optimized_target = self.wrap_expr(folded)

        with self.subTest(
            result_code=result_code,
            non_optimized_target=non_optimized_target,
            optimized_target=optimized_target
        ):
            self.assert_ast(result_code, non_optimized_target, optimized_target)
def test_folding_format(self):
    code = "'%s' % (a,)"

    # Unfolded: a `%` BinOp between the format string and a one-tuple.
    percent_format = ast.BinOp(
        left=ast.Constant(value="%s"),
        op=ast.Mod(),
        right=ast.Tuple(elts=[ast.Name(id='a')])
    )
    non_optimized_target = self.wrap_expr(percent_format)

    # Folded: a JoinedStr holding one FormattedValue with conversion 115.
    formatted = ast.FormattedValue(value=ast.Name(id='a'), conversion=115)
    optimized_target = self.wrap_expr(ast.JoinedStr(values=[formatted]))

    self.assert_ast(code, non_optimized_target, optimized_target)
def test_folding_tuple(self):
code = "(1,)"
non_optimized_target = self.wrap_expr(ast.Tuple(elts=[ast.Constant(1)]))
optimized_target = self.wrap_expr(ast.Constant(value=(1,)))
self.assert_ast(code, non_optimized_target, optimized_target)
def test_folding_comparator(self):
code = "1 %s %s1%s"
operators = [("in", ast.In()), ("not in", ast.NotIn())]
braces = [
("[", "]", ast.List, (1,)),
("{", "}", ast.Set, frozenset({1})),
]
for left, right, non_optimized_comparator, optimized_comparator in braces:
for op, node in operators:
non_optimized_target = self.wrap_expr(ast.Compare(
left=ast.Constant(1), ops=[node],
comparators=[non_optimized_comparator(elts=[ast.Constant(1)])]
))
optimized_target = self.wrap_expr(ast.Compare(
left=ast.Constant(1), ops=[node],
comparators=[ast.Constant(value=optimized_comparator)]
))
self.assert_ast(code % (op, left, right), non_optimized_target, optimized_target)
def test_folding_iter(self):
code = "for _ in %s1%s: pass"
braces = [
("[", "]", ast.List, (1,)),
("{", "}", ast.Set, frozenset({1})),
]
for left, right, ast_cls, optimized_iter in braces:
non_optimized_target = self.wrap_statement(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast_cls(elts=[ast.Constant(1)]),
body=[ast.Pass()]
))
optimized_target = self.wrap_statement(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast.Constant(value=optimized_iter),
body=[ast.Pass()]
))
self.assert_ast(code % (left, right), non_optimized_target, optimized_target)
def test_folding_subscript(self):
code = "(1,)[0]"
non_optimized_target = self.wrap_expr(
ast.Subscript(value=ast.Tuple(elts=[ast.Constant(value=1)]), slice=ast.Constant(value=0))
)
optimized_target = self.wrap_expr(ast.Constant(value=1))
self.assert_ast(code, non_optimized_target, optimized_target)
def test_folding_type_param_in_function_def(self):
code = "def foo[%s = 1 + 1](): pass"
unoptimized_binop = self.create_binop("+")
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
("*Ts", "Ts", ast.TypeVarTuple),
]
for type, name, type_param in unoptimized_type_params:
result_code = code % type
optimized_target = self.wrap_statement(
ast.FunctionDef(
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
)
)
non_optimized_target = self.wrap_statement(
ast.FunctionDef(
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)
def test_folding_type_param_in_class_def(self):
code = "class foo[%s = 1 + 1]: pass"
unoptimized_binop = self.create_binop("+")
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
("*Ts", "Ts", ast.TypeVarTuple),
]
for type, name, type_param in unoptimized_type_params:
result_code = code % type
optimized_target = self.wrap_statement(
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
)
)
non_optimized_target = self.wrap_statement(
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)
def test_folding_type_param_in_type_alias(self):
code = "type foo[%s = 1 + 1] = 1"
unoptimized_binop = self.create_binop("+")
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
("*Ts", "Ts", ast.TypeVarTuple),
]
for type, name, type_param in unoptimized_type_params:
result_code = code % type
optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=ast.Constant(2))],
value=ast.Constant(value=1),
)
)
non_optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=unoptimized_binop)],
value=ast.Constant(value=1),
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()