# cpython/Lib/test/test_type_params.py (Python; original listing: 1397 lines, 44 KiB)

import asyncio
import textwrap
import types
import unittest
import pickle
import weakref
from test.support import requires_working_socket, check_syntax_error, run_code
from typing import Generic, NoDefault, Sequence, TypeVar, TypeVarTuple, ParamSpec, get_args
class TypeParamsInvalidTest(unittest.TestCase):
    """Invalid type-parameter syntax, plus reuses of a name that must be accepted."""

    def test_name_collisions(self):
        check_syntax_error(self, 'def func[**A, A](): ...', "duplicate type parameter 'A'")
        check_syntax_error(self, 'def func[A, *A](): ...', "duplicate type parameter 'A'")
        check_syntax_error(self, 'def func[*A, **A](): ...', "duplicate type parameter 'A'")

        check_syntax_error(self, 'class C[**A, A](): ...', "duplicate type parameter 'A'")
        check_syntax_error(self, 'class C[A, *A](): ...', "duplicate type parameter 'A'")
        check_syntax_error(self, 'class C[*A, **A](): ...', "duplicate type parameter 'A'")

    def test_name_non_collision_02(self):
        # A function parameter may reuse the name of a type parameter.
        ns = run_code("""def func[A](A): return A""")
        func = ns["func"]
        self.assertEqual(func(1), 1)
        A, = func.__type_params__
        self.assertEqual(A.__name__, "A")

    def test_name_non_collision_03(self):
        # ... and so may a *args parameter.
        ns = run_code("""def func[A](*A): return A""")
        func = ns["func"]
        self.assertEqual(func(1), (1,))
        A, = func.__type_params__
        self.assertEqual(A.__name__, "A")

    def test_name_non_collision_04(self):
        # Mangled names should not cause a conflict.
        ns = run_code("""
            class ClassA:
                def func[__A](self, __A): return __A
            """
        )
        cls = ns["ClassA"]
        self.assertEqual(cls().func(1), 1)
        A, = cls.func.__type_params__
        self.assertEqual(A.__name__, "__A")

    def test_name_non_collision_05(self):
        ns = run_code("""
            class ClassA:
                def func[_ClassA__A](self, __A): return __A
            """
        )
        cls = ns["ClassA"]
        self.assertEqual(cls().func(1), 1)
        A, = cls.func.__type_params__
        self.assertEqual(A.__name__, "_ClassA__A")

    def test_name_non_collision_06(self):
        # A method parameter may reuse the name of a class type parameter.
        ns = run_code("""
            class ClassA[X]:
                def func(self, X): return X
            """
        )
        cls = ns["ClassA"]
        self.assertEqual(cls().func(1), 1)
        X, = cls.__type_params__
        self.assertEqual(X.__name__, "X")

    def test_name_non_collision_07(self):
        # A method-local assignment may reuse the name.
        ns = run_code("""
            class ClassA[X]:
                def func(self):
                    X = 1
                    return X
            """
        )
        cls = ns["ClassA"]
        self.assertEqual(cls().func(), 1)
        X, = cls.__type_params__
        self.assertEqual(X.__name__, "X")

    def test_name_non_collision_08(self):
        # A comprehension variable may reuse the name.
        ns = run_code("""
            class ClassA[X]:
                def func(self):
                    return [X for X in [1, 2]]
            """
        )
        cls = ns["ClassA"]
        self.assertEqual(cls().func(), [1, 2])
        X, = cls.__type_params__
        self.assertEqual(X.__name__, "X")

    def test_name_non_collision_9(self):
        # A method's type parameter may shadow the class's; they are
        # distinct objects.
        ns = run_code("""
            class ClassA[X]:
                def func[X](self):
                    ...
            """
        )
        cls = ns["ClassA"]
        outer_X, = cls.__type_params__
        inner_X, = cls.func.__type_params__
        self.assertEqual(outer_X.__name__, "X")
        self.assertEqual(inner_X.__name__, "X")
        self.assertIsNot(outer_X, inner_X)

    def test_name_non_collision_10(self):
        # A class-body annotation may reuse the name.
        ns = run_code("""
            class ClassA[X]:
                X: int
            """
        )
        cls = ns["ClassA"]
        X, = cls.__type_params__
        self.assertEqual(X.__name__, "X")
        self.assertIs(cls.__annotations__["X"], int)

    def test_name_non_collision_13(self):
        # ``global X`` inside the generic function rebinds the module-level
        # X, not the type parameter.
        ns = run_code("""
            X = 1
            def outer():
                def inner[X]():
                    global X
                    X = 2
                return inner
            """
        )
        self.assertEqual(ns["X"], 1)
        outer = ns["outer"]
        outer()()
        self.assertEqual(ns["X"], 2)

    def test_disallowed_expressions(self):
        # yield, await and walrus (also inside comprehensions) are rejected
        # in alias values, bounds, class bases and annotations.
        check_syntax_error(self, "type X = (yield)")
        check_syntax_error(self, "type X = (yield from x)")
        check_syntax_error(self, "type X = (await 42)")
        check_syntax_error(self, "async def f(): type X = (yield)")
        check_syntax_error(self, "type X = (y := 3)")
        check_syntax_error(self, "class X[T: (yield)]: pass")
        check_syntax_error(self, "class X[T: (yield from x)]: pass")
        check_syntax_error(self, "class X[T: (await 42)]: pass")
        check_syntax_error(self, "class X[T: (y := 3)]: pass")
        check_syntax_error(self, "class X[T](y := Sequence[T]): pass")
        check_syntax_error(self, "def f[T](y: (x := Sequence[T])): pass")
        check_syntax_error(self, "class X[T]([(x := 3) for _ in range(2)] and B): pass")
        check_syntax_error(self, "def f[T: [(x := 3) for _ in range(2)]](): pass")
        check_syntax_error(self, "type T = [(x := 3) for _ in range(2)]")

    def test_incorrect_mro_explicit_object(self):
        # The implicitly added Generic base must follow the explicit
        # ``object`` base, which gives no consistent MRO.
        with self.assertRaisesRegex(TypeError, r"\(MRO\) for bases object, Generic"):
            class My[X](object): ...
class TypeParamsNonlocalTest(unittest.TestCase):
def test_nonlocal_disallowed_01(self):
code = """
def outer():
X = 1
def inner[X]():
nonlocal X
return X
"""
check_syntax_error(self, code)
def test_nonlocal_disallowed_02(self):
code = """
def outer2[T]():
def inner1():
nonlocal T
"""
check_syntax_error(self, textwrap.dedent(code))
def test_nonlocal_disallowed_03(self):
code = """
class Cls[T]:
nonlocal T
"""
check_syntax_error(self, textwrap.dedent(code))
def test_nonlocal_allowed(self):
code = """
def func[T]():
T = "func"
def inner():
nonlocal T
T = "inner"
inner()
assert T == "inner"
"""
ns = run_code(code)
func = ns["func"]
T, = func.__type_params__
self.assertEqual(T.__name__, "T")
class TypeParamsAccessTest(unittest.TestCase):
    """Where type parameters are (and are not) visible.

    Covers bases, class keywords, annotations, defaults, decorators, nested
    scopes, comprehensions and lambdas.
    """

    def test_class_access_01(self):
        # Type parameters are visible in the class's bases.
        ns = run_code("""
            class ClassA[A, B](dict[A, B]):
                ...
            """
        )
        cls = ns["ClassA"]
        A, B = cls.__type_params__
        self.assertEqual(types.get_original_bases(cls), (dict[A, B], Generic[A, B]))

    def test_class_access_02(self):
        # ... and in class keyword arguments such as metaclass=.
        ns = run_code("""
            class MyMeta[A, B](type): ...
            class ClassA[A, B](metaclass=MyMeta[A, B]):
                ...
            """
        )
        meta = ns["MyMeta"]
        cls = ns["ClassA"]
        A1, B1 = meta.__type_params__
        A2, B2 = cls.__type_params__
        self.assertIsNot(A1, A2)
        self.assertIsNot(B1, B2)
        self.assertIs(type(cls), meta)

    def test_class_access_03(self):
        # Decorators are evaluated outside the type-parameter scope.
        code = """
            def my_decorator(a):
                ...
            @my_decorator(A)
            class ClassA[A, B]():
                ...
            """

        with self.assertRaisesRegex(NameError, "name 'A' is not defined"):
            run_code(code)

    def test_function_access_01(self):
        # Type parameters are visible in parameter annotations.
        ns = run_code("""
            def func[A, B](a: dict[A, B]):
                ...
            """
        )
        func = ns["func"]
        A, B = func.__type_params__
        self.assertEqual(func.__annotations__["a"], dict[A, B])

    def test_function_access_02(self):
        # Parameter default values are evaluated outside the type-parameter scope.
        code = """
            def func[A](a = list[A]()):
                ...
            """

        with self.assertRaisesRegex(NameError, "name 'A' is not defined"):
            run_code(code)

    def test_function_access_03(self):
        # Function decorators are evaluated outside the type-parameter scope.
        code = """
            def my_decorator(a):
                ...
            @my_decorator(A)
            def func[A]():
                ...
            """

        with self.assertRaisesRegex(NameError, "name 'A' is not defined"):
            run_code(code)

    def test_method_access_01(self):
        # Annotations of a generic method see both class-body names (x)
        # and the method's own type parameters (T).
        ns = run_code("""
            class ClassA:
                x = int
                def func[T](self, a: x, b: T):
                    ...
            """
        )
        cls = ns["ClassA"]
        self.assertIs(cls.func.__annotations__["a"], int)
        T, = cls.func.__type_params__
        self.assertIs(cls.func.__annotations__["b"], T)

    def test_nested_access_01(self):
        # A lambda nested inside alternating generic classes and methods
        # sees every enclosing type parameter.
        ns = run_code("""
            class ClassA[A]:
                def funcB[B](self):
                    class ClassC[C]:
                        def funcD[D](self):
                            return lambda: (A, B, C, D)
                    return ClassC
            """
        )
        cls = ns["ClassA"]
        A, = cls.__type_params__
        B, = cls.funcB.__type_params__
        classC = cls().funcB()
        C, = classC.__type_params__
        D, = classC.funcD.__type_params__
        self.assertEqual(classC().funcD()(), (A, B, C, D))

    def test_out_of_scope_01(self):
        # A class's type parameters are not visible after the class statement.
        code = """
            class ClassA[T]: ...
            x = T
            """

        with self.assertRaisesRegex(NameError, "name 'T' is not defined"):
            run_code(code)

    def test_out_of_scope_02(self):
        # A method's type parameters are not visible in the class body.
        code = """
            class ClassA[A]:
                def funcB[B](self): ...
                x = B
            """

        with self.assertRaisesRegex(NameError, "name 'B' is not defined"):
            run_code(code)

    def test_class_scope_interaction_01(self):
        ns = run_code("""
            class C:
                x = 1
                def method[T](self, arg: x): pass
        """)
        cls = ns["C"]
        self.assertEqual(cls.method.__annotations__["arg"], 1)

    def test_class_scope_interaction_02(self):
        # A generic class's bases can name a sibling class from the
        # enclosing class body.
        ns = run_code("""
            class C:
                class Base: pass
                class Child[T](Base): pass
        """)
        cls = ns["C"]
        self.assertEqual(cls.Child.__bases__, (cls.Base, Generic))
        T, = cls.Child.__type_params__
        self.assertEqual(types.get_original_bases(cls.Child), (cls.Base, Generic[T]))

    def test_class_deref(self):
        # A class-body binding of T takes precedence over the class's own
        # type parameter when the alias value is evaluated.
        ns = run_code("""
            class C[T]:
                T = "class"
                type Alias = T
        """)
        cls = ns["C"]
        self.assertEqual(cls.Alias.__value__, "class")

    def test_shadowing_nonlocal(self):
        # inner()'s nonlocal targets outer's local T (which shadows the type
        # parameter), so the lambda observes the update.
        ns = run_code("""
            def outer[T]():
                T = "outer"
                def inner():
                    nonlocal T
                    T = "inner"
                    return T
                return lambda: T, inner
        """)
        outer = ns["outer"]
        T, = outer.__type_params__
        self.assertEqual(T.__name__, "T")
        getter, inner = outer()
        self.assertEqual(getter(), "outer")
        self.assertEqual(inner(), "inner")
        self.assertEqual(getter(), "inner")

    def test_reference_previous_typevar(self):
        # A bound may refer to an earlier type parameter in the same list.
        def func[S, T: Sequence[S]]():
            pass

        S, T = func.__type_params__
        self.assertEqual(T.__bound__, Sequence[S])

    def test_super(self):
        class Base:
            def meth(self):
                return "base"

        class Child(Base):
            # Having int in the annotation ensures the class gets cells for both
            # __class__ and __classdict__
            def meth[T](self, arg: int) -> T:
                return super().meth() + "child"

        c = Child()
        self.assertEqual(c.meth(1), "basechild")

    def test_type_alias_containing_lambda(self):
        type Alias[T] = lambda: T
        T, = Alias.__type_params__
        self.assertIs(Alias.__value__(), T)

    def test_class_base_containing_lambda(self):
        # Test that scopes nested inside hidden functions work correctly
        outer_var = "outer"
        class Base[T]: ...
        class Child[T](Base[lambda: (int, outer_var, T)]): ...
        base, _ = types.get_original_bases(Child)
        func, = get_args(base)
        T, = Child.__type_params__
        self.assertEqual(func(), (int, "outer", T))

    def test_comprehension_01(self):
        # Comprehension variables named T shadow the type parameter only
        # inside their own comprehension.
        type Alias[T: ([T for T in (T, [1])[1]], T)] = [T for T in T.__name__]
        self.assertEqual(Alias.__value__, ["T"])
        T, = Alias.__type_params__
        self.assertEqual(T.__constraints__, ([1], T))

    def test_comprehension_02(self):
        type Alias[T: [lambda: T for T in (T, [1])[1]]] = [lambda: T for T in T.__name__]
        func, = Alias.__value__
        self.assertEqual(func(), "T")
        T, = Alias.__type_params__
        func, = T.__bound__
        self.assertEqual(func(), 1)

    def test_comprehension_03(self):
        def F[T: [lambda: T for T in (T, [1])[1]]](): return [lambda: T for T in T.__name__]
        func, = F()
        self.assertEqual(func(), "T")
        T, = F.__type_params__
        func, = T.__bound__
        self.assertEqual(func(), 1)

    def test_gen_exp_in_nested_class(self):
        # Inside the genexp, T resolves to the outer type parameter; a bare
        # T in the bases resolves to the class attribute.
        code = """
            from test.test_type_params import make_base
            class C[T]:
                T = "class"
                class Inner(make_base(T for _ in (1,)), make_base(T)):
                    pass
        """
        C = run_code(code)["C"]
        T, = C.__type_params__
        base1, base2 = C.Inner.__bases__
        self.assertEqual(list(base1.__arg__), [T])
        self.assertEqual(base2.__arg__, "class")

    def test_gen_exp_in_nested_generic_class(self):
        code = """
            from test.test_type_params import make_base
            class C[T]:
                T = "class"
                class Inner[U](make_base(T for _ in (1,)), make_base(T)):
                    pass
        """
        ns = run_code(code)
        inner = ns["C"].Inner
        base1, base2, _ = inner.__bases__
        self.assertEqual(list(base1.__arg__), [ns["C"].__type_params__[0]])
        self.assertEqual(base2.__arg__, "class")

    def test_listcomp_in_nested_class(self):
        code = """
            from test.test_type_params import make_base
            class C[T]:
                T = "class"
                class Inner(make_base([T for _ in (1,)]), make_base(T)):
                    pass
        """
        C = run_code(code)["C"]
        T, = C.__type_params__
        base1, base2 = C.Inner.__bases__
        self.assertEqual(base1.__arg__, [T])
        self.assertEqual(base2.__arg__, "class")

    def test_listcomp_in_nested_generic_class(self):
        code = """
            from test.test_type_params import make_base
            class C[T]:
                T = "class"
                class Inner[U](make_base([T for _ in (1,)]), make_base(T)):
                    pass
        """
        ns = run_code(code)
        inner = ns["C"].Inner
        base1, base2, _ = inner.__bases__
        self.assertEqual(base1.__arg__, [ns["C"].__type_params__[0]])
        self.assertEqual(base2.__arg__, "class")

    def test_gen_exp_in_generic_method(self):
        code = """
            class C[T]:
                T = "class"
                def meth[U](x: (T for _ in (1,)), y: T):
                    pass
        """
        ns = run_code(code)
        meth = ns["C"].meth
        self.assertEqual(list(meth.__annotations__["x"]), [ns["C"].__type_params__[0]])
        self.assertEqual(meth.__annotations__["y"], "class")

    def test_nested_scope_in_generic_alias(self):
        # Each case is substituted for {} inside the class body.  A nested
        # scope in the alias value sees the alias's own T when it has one,
        # and otherwise the global T rather than the class attribute.
        code = """
            T = "global"
            class C:
                T = "class"
                {}
        """
        cases = [
            "type Alias[T] = (T for _ in (1,))",
            "type Alias = (T for _ in (1,))",
            "type Alias[T] = [T for _ in (1,)]",
            "type Alias = [T for _ in (1,)]",
        ]
        for case in cases:
            with self.subTest(case=case):
                ns = run_code(code.format(case))
                alias = ns["C"].Alias
                value = list(alias.__value__)[0]
                if alias.__type_params__:
                    self.assertIs(value, alias.__type_params__[0])
                else:
                    self.assertEqual(value, "global")

    def test_lambda_in_alias_in_class(self):
        code = """
            T = "global"
            class C:
                T = "class"
                type Alias = lambda: T
        """
        C = run_code(code)["C"]
        self.assertEqual(C.Alias.__value__(), "global")

    def test_lambda_in_alias_in_generic_class(self):
        code = """
            class C[T]:
                T = "class"
                type Alias = lambda: T
        """
        C = run_code(code)["C"]
        self.assertIs(C.Alias.__value__(), C.__type_params__[0])

    def test_lambda_in_generic_alias_in_class(self):
        # A lambda nested in the alias cannot see the class scope, but can see
        # a surrounding annotation scope.
        code = """
            T = U = "global"
            class C:
                T = "class"
                U = "class"
                type Alias[T] = lambda: (T, U)
        """
        C = run_code(code)["C"]
        T, U = C.Alias.__value__()
        self.assertIs(T, C.Alias.__type_params__[0])
        self.assertEqual(U, "global")

    def test_lambda_in_generic_alias_in_generic_class(self):
        # A lambda nested in the alias cannot see the class scope, but can see
        # a surrounding annotation scope.
        code = """
            class C[T, U]:
                T = "class"
                U = "class"
                type Alias[T] = lambda: (T, U)
        """
        C = run_code(code)["C"]
        T, U = C.Alias.__value__()
        self.assertIs(T, C.Alias.__type_params__[0])
        self.assertIs(U, C.__type_params__[1])

    def test_type_special_case(self):
        # https://github.com/python/cpython/issues/119011
        self.assertEqual(type.__type_params__, ())
        self.assertEqual(object.__type_params__, ())
def make_base(arg):
    """Return a brand-new class whose ``__arg__`` attribute is *arg*.

    The tests use it inside class headers (``class Inner(make_base(...))``)
    to record what a base-class expression evaluated to.
    """
    captured = arg

    class Base:
        __arg__ = captured

    return Base
def global_generic_func[T]():
    """Module-level generic function whose ``__qualname__`` is checked by test_qualname."""
    pass
class GlobalGenericClass[T]:
    """Module-level generic class whose ``__qualname__`` is checked by test_qualname."""
    pass
class TypeParamsLazyEvaluationTest(unittest.TestCase):
    """Qualified names of generic objects, and lazy evaluation of bounds/constraints."""

    def test_qualname(self):
        # The extra scope holding the type parameters does not show up in
        # __qualname__.
        class Foo[T]:
            pass

        def func[T]():
            pass

        self.assertEqual(Foo.__qualname__, "TypeParamsLazyEvaluationTest.test_qualname.<locals>.Foo")
        self.assertEqual(func.__qualname__, "TypeParamsLazyEvaluationTest.test_qualname.<locals>.func")
        self.assertEqual(global_generic_func.__qualname__, "global_generic_func")
        self.assertEqual(GlobalGenericClass.__qualname__, "GlobalGenericClass")

    def test_recursive_class(self):
        # Bounds and constraints may refer to the class being defined,
        # because they are only evaluated when accessed.
        class Foo[T: Foo, U: (Foo, Foo)]:
            pass

        type_params = Foo.__type_params__
        self.assertEqual(len(type_params), 2)
        self.assertEqual(type_params[0].__name__, "T")
        self.assertIs(type_params[0].__bound__, Foo)
        self.assertEqual(type_params[0].__constraints__, ())
        self.assertIs(type_params[0].__default__, NoDefault)

        self.assertEqual(type_params[1].__name__, "U")
        self.assertIs(type_params[1].__bound__, None)
        self.assertEqual(type_params[1].__constraints__, (Foo, Foo))
        self.assertIs(type_params[1].__default__, NoDefault)

    def test_evaluation_error(self):
        # Undefined is looked up only when __bound__/__constraints__ is
        # read: the NameError surfaces on access, and after the local
        # assignment below the same access succeeds.
        class Foo[T: Undefined, U: (Undefined,)]:
            pass

        type_params = Foo.__type_params__
        with self.assertRaises(NameError):
            type_params[0].__bound__
        self.assertEqual(type_params[0].__constraints__, ())
        self.assertIs(type_params[1].__bound__, None)
        self.assertIs(type_params[0].__default__, NoDefault)
        self.assertIs(type_params[1].__default__, NoDefault)
        with self.assertRaises(NameError):
            type_params[1].__constraints__

        Undefined = "defined"
        self.assertEqual(type_params[0].__bound__, "defined")
        self.assertEqual(type_params[0].__constraints__, ())

        self.assertIs(type_params[1].__bound__, None)
        self.assertEqual(type_params[1].__constraints__, ("defined",))
class TypeParamsClassScopeTest(unittest.TestCase):
    """Name resolution from annotation scopes (aliases, bounds, annotations) inside a class body."""

    def test_alias(self):
        # An alias value defined in a class body can see class attributes.
        class X:
            T = int
            type U = T
        self.assertIs(X.U.__value__, int)

        ns = run_code("""
            glb = "global"
            class X:
                cls = "class"
                type U = (glb, cls)
        """)
        cls = ns["X"]
        self.assertEqual(cls.U.__value__, ("global", "class"))

    def test_bound(self):
        # ... and so can a type-parameter bound.
        class X:
            T = int
            def foo[U: T](self): ...
        self.assertIs(X.foo.__type_params__[0].__bound__, int)

        ns = run_code("""
            glb = "global"
            class X:
                cls = "class"
                def foo[T: glb, U: cls](self): ...
        """)
        cls = ns["X"]
        T, U = cls.foo.__type_params__
        self.assertEqual(T.__bound__, "global")
        self.assertEqual(U.__bound__, "class")

    def test_modified_later(self):
        # Bounds and alias values are evaluated on first access, so they
        # see the class attribute as it is at that time, not at definition.
        class X:
            T = int
            def foo[U: T](self): ...
            type Alias = T
        X.T = float
        self.assertIs(X.foo.__type_params__[0].__bound__, float)
        self.assertIs(X.Alias.__value__, float)

    def test_binding_uses_global(self):
        # x is assigned later in the class body.  When the alias value,
        # bound and annotation are evaluated before that assignment, they
        # see the global x, not the enclosing function's x.
        ns = run_code("""
            x = "global"
            def outer():
                x = "nonlocal"
                class Cls:
                    type Alias = x
                    val = Alias.__value__
                    def meth[T: x](self, arg: x): ...
                    bound = meth.__type_params__[0].__bound__
                    annotation = meth.__annotations__["arg"]
                    x = "class"
                return Cls
        """)
        cls = ns["outer"]()
        self.assertEqual(cls.val, "global")
        self.assertEqual(cls.bound, "global")
        self.assertEqual(cls.annotation, "global")

    def test_no_binding_uses_nonlocal(self):
        # With no binding of x in the class body, the enclosing function's
        # x is used.
        ns = run_code("""
            x = "global"
            def outer():
                x = "nonlocal"
                class Cls:
                    type Alias = x
                    val = Alias.__value__
                    def meth[T: x](self, arg: x): ...
                    bound = meth.__type_params__[0].__bound__
                return Cls
        """)
        cls = ns["outer"]()
        self.assertEqual(cls.val, "nonlocal")
        self.assertEqual(cls.bound, "nonlocal")
        self.assertEqual(cls.meth.__annotations__["arg"], "nonlocal")

    def test_explicit_global(self):
        # ``global x`` in the class body makes the alias read the module
        # global, even though Cls.x is set before the value is accessed.
        ns = run_code("""
            x = "global"
            def outer():
                x = "nonlocal"
                class Cls:
                    global x
                    type Alias = x
                Cls.x = "class"
                return Cls
        """)
        cls = ns["outer"]()
        self.assertEqual(cls.Alias.__value__, "global")

    def test_explicit_global_with_no_static_bound(self):
        # Here the module source never assigns x; it is injected into the
        # namespace before outer() is called.
        ns = run_code("""
            def outer():
                class Cls:
                    global x
                    type Alias = x
                Cls.x = "class"
                return Cls
        """)
        ns["x"] = "global"
        cls = ns["outer"]()
        self.assertEqual(cls.Alias.__value__, "global")

    def test_explicit_global_with_assignment(self):
        # The class-body assignment rebinds the module global (because of
        # ``global x``), and the alias value reflects it.
        ns = run_code("""
            x = "global"
            def outer():
                x = "nonlocal"
                class Cls:
                    global x
                    type Alias = x
                    x = "global from class"
                Cls.x = "class"
                return Cls
        """)
        cls = ns["outer"]()
        self.assertEqual(cls.Alias.__value__, "global from class")

    def test_explicit_nonlocal(self):
        # ``nonlocal x`` in the class body makes the assignment rebind
        # outer's x, which the alias value then reads.
        ns = run_code("""
            x = "global"
            def outer():
                x = "nonlocal"
                class Cls:
                    nonlocal x
                    type Alias = x
                    x = "class"
                return Cls
        """)
        cls = ns["outer"]()
        self.assertEqual(cls.Alias.__value__, "class")

    def test_nested_free(self):
        # D's bases are evaluated where class C's T (int) is visible, but
        # D's body skips C's class scope and sees f's T (str).
        ns = run_code("""
            def f():
                T = str
                class C:
                    T = int
                    class D[U](T):
                        x = T
                return C
        """)
        C = ns["f"]()
        self.assertIn(int, C.D.__bases__)
        self.assertIs(C.D.x, str)
class DynamicClassTest(unittest.TestCase):
def _set_type_params(self, ns, params):
ns['__type_params__'] = params
def test_types_new_class_with_callback(self):
T = TypeVar('T', infer_variance=True)
Klass = types.new_class('Klass', (Generic[T],), {},
lambda ns: self._set_type_params(ns, (T,)))
self.assertEqual(Klass.__bases__, (Generic,))
self.assertEqual(Klass.__orig_bases__, (Generic[T],))
self.assertEqual(Klass.__type_params__, (T,))
self.assertEqual(Klass.__parameters__, (T,))
def test_types_new_class_no_callback(self):
T = TypeVar('T', infer_variance=True)
Klass = types.new_class('Klass', (Generic[T],), {})
self.assertEqual(Klass.__bases__, (Generic,))
self.assertEqual(Klass.__orig_bases__, (Generic[T],))
self.assertEqual(Klass.__type_params__, ()) # must be explicitly set
self.assertEqual(Klass.__parameters__, (T,))
class TypeParamsManglingTest(unittest.TestCase):
    """Private-name mangling of type parameters and of names in generic class headers."""

    def test_mangling(self):
        # __T/__U/__V keep their source spelling as __name__ and resolve
        # consistently in the body, annotations and alias value.
        class Foo[__T]:
            param = __T
            def meth[__U](self, arg: __T, arg2: __U):
                return (__T, __U)
            type Alias[__V] = (__T, __V)

        T = Foo.__type_params__[0]
        self.assertEqual(T.__name__, "__T")
        U = Foo.meth.__type_params__[0]
        self.assertEqual(U.__name__, "__U")
        V = Foo.Alias.__type_params__[0]
        self.assertEqual(V.__name__, "__V")

        anno = Foo.meth.__annotations__
        self.assertIs(anno["arg"], T)
        self.assertIs(anno["arg2"], U)
        self.assertEqual(Foo().meth(1, 2), (T, U))

        self.assertEqual(Foo.Alias.__value__, (T, V))

    def test_no_leaky_mangling_in_module(self):
        # Defining a generic class must not switch on mangling for the
        # statements around it.
        ns = run_code("""
            __before = "before"
            class X[T]: pass
            __after = "after"
        """)
        self.assertEqual(ns["__before"], "before")
        self.assertEqual(ns["__after"], "after")

    def test_no_leaky_mangling_in_function(self):
        ns = run_code("""
            def f():
                class X[T]: pass
                _X_foo = 2
                __foo = 1
                assert locals()['__foo'] == 1
                return __foo
        """)
        self.assertEqual(ns["f"](), 1)

    def test_no_leaky_mangling_in_class(self):
        # Names after the nested generic class are mangled with Outer's
        # name again, not Inner's.
        ns = run_code("""
            class Outer:
                __before = "before"
                class Inner[T]:
                    __x = "inner"
                __after = "after"
        """)
        Outer = ns["Outer"]
        self.assertEqual(Outer._Outer__before, "before")
        self.assertEqual(Outer.Inner._Inner__x, "inner")
        self.assertEqual(Outer._Outer__after, "after")

    def test_no_mangling_in_bases(self):
        # Base expressions and class keywords are not mangled with the
        # class's own name.
        ns = run_code("""
            class __Base:
                def __init_subclass__(self, **kwargs):
                    self.kwargs = kwargs

            class Derived[T](__Base, __kwarg=1):
                pass
        """)
        Derived = ns["Derived"]
        self.assertEqual(Derived.__bases__, (ns["__Base"], Generic))
        self.assertEqual(Derived.kwargs, {"__kwarg": 1})

    def test_no_mangling_in_nested_scopes(self):
        # Nor are names in scopes nested inside the class header.
        ns = run_code("""
            from test.test_type_params import make_base

            class __X:
                pass

            class Y[T: __X](
                make_base(lambda: __X),
                # doubly nested scope
                make_base(lambda: (lambda: __X)),
                # list comprehension
                make_base([__X for _ in (1,)]),
                # genexp
                make_base(__X for _ in (1,)),
            ):
                pass
        """)
        Y = ns["Y"]
        T, = Y.__type_params__
        self.assertIs(T.__bound__, ns["__X"])
        base0 = Y.__bases__[0]
        self.assertIs(base0.__arg__(), ns["__X"])
        base1 = Y.__bases__[1]
        self.assertIs(base1.__arg__()(), ns["__X"])
        base2 = Y.__bases__[2]
        self.assertEqual(base2.__arg__, [ns["__X"]])
        base3 = Y.__bases__[3]
        self.assertEqual(list(base3.__arg__), [ns["__X"]])

    def test_type_params_are_mangled(self):
        # References to the class's own private type parameters resolve
        # consistently in bounds, bases (including lambdas) and the body.
        ns = run_code("""
            from test.test_type_params import make_base

            class Foo[__T, __U: __T](make_base(__T), make_base(lambda: __T)):
                param = __T
        """)
        Foo = ns["Foo"]
        T, U = Foo.__type_params__
        self.assertEqual(T.__name__, "__T")
        self.assertEqual(U.__name__, "__U")
        self.assertIs(U.__bound__, T)
        self.assertIs(Foo.param, T)

        base1, base2, *_ = Foo.__bases__
        self.assertIs(base1.__arg__, T)
        self.assertIs(base2.__arg__(), T)
class TypeParamsComplexCallsTest(unittest.TestCase):
    """Generic definitions combined with defaults, kwdefaults and ``*``/``**`` class arguments."""

    def test_defaults(self):
        # Generic functions with both defaults and kwdefaults trigger a specific code path
        # in the compiler.
        def func[T](a: T = "a", *, b: T = "b"):
            return (a, b)

        T, = func.__type_params__
        self.assertIs(func.__annotations__["a"], T)
        self.assertIs(func.__annotations__["b"], T)
        self.assertEqual(func(), ("a", "b"))
        self.assertEqual(func(1), (1, "b"))
        self.assertEqual(func(b=2), ("a", 2))

    def test_complex_base(self):
        class Base:
            def __init_subclass__(cls, **kwargs) -> None:
                cls.kwargs = kwargs

        kwargs = {"c": 3}

        # Base classes with **kwargs trigger a different code path in the compiler.
        class C[T](Base, a=1, b=2, **kwargs):
            pass

        T, = C.__type_params__
        self.assertEqual(T.__name__, "T")
        self.assertEqual(C.kwargs, {"a": 1, "b": 2, "c": 3})

        bases = (Base,)

        # Same, with the bases themselves unpacked from a tuple.
        class C2[T](*bases, **kwargs):
            pass

        T, = C2.__type_params__
        self.assertEqual(T.__name__, "T")
        self.assertEqual(C2.kwargs, {"c": 3})
class TypeParamsTraditionalTypeVarsTest(unittest.TestCase):
    """Mixing the type-parameter syntax with explicitly created TypeVars and ``Generic``."""

    def test_traditional_01(self):
        # The syntax already adds Generic[T]; listing it explicitly as well
        # is rejected.
        code = """
            from typing import Generic
            class ClassA[T](Generic[T]): ...
        """

        with self.assertRaisesRegex(TypeError, r"Cannot inherit from Generic\[...\] multiple times."):
            run_code(code)

    def test_traditional_02(self):
        # A traditional TypeVar (S) in the bases of a class that uses the
        # type-parameter syntax is rejected.
        from typing import TypeVar
        S = TypeVar("S")

        with self.assertRaises(TypeError):
            class ClassA[T](dict[T, S]): ...

    def test_traditional_03(self):
        # This does not generate a runtime error, but it should be
        # flagged as an error by type checkers.
        from typing import TypeVar
        S = TypeVar("S")

        def func[T](a: T, b: S) -> T | S:
            return a
class TypeParamsTypeVarTest(unittest.TestCase):
    """Runtime properties of TypeVars created by the type-parameter syntax."""

    def test_typevar_01(self):
        # Syntax-created TypeVars use inferred variance, with or without a
        # bound or constraints.
        def func1[A: str, B: str | int, C: (int, str)]():
            return (A, B, C)

        a, b, c = func1()

        self.assertIsInstance(a, TypeVar)
        self.assertEqual(a.__bound__, str)
        self.assertTrue(a.__infer_variance__)
        self.assertFalse(a.__covariant__)
        self.assertFalse(a.__contravariant__)

        self.assertIsInstance(b, TypeVar)
        self.assertEqual(b.__bound__, str | int)
        self.assertTrue(b.__infer_variance__)
        self.assertFalse(b.__covariant__)
        self.assertFalse(b.__contravariant__)

        self.assertIsInstance(c, TypeVar)
        self.assertEqual(c.__bound__, None)
        self.assertEqual(c.__constraints__, (int, str))
        self.assertTrue(c.__infer_variance__)
        self.assertFalse(c.__covariant__)
        self.assertFalse(c.__contravariant__)

    def test_typevar_generator(self):
        # Generic generator functions can yield their own and enclosing
        # type parameters.
        def get_generator[A]():
            def generator1[C]():
                yield C

            def generator2[B]():
                yield A
                yield B
                yield from generator1()
            return generator2

        gen = get_generator()

        a, b, c = [x for x in gen()]

        self.assertIsInstance(a, TypeVar)
        self.assertEqual(a.__name__, "A")
        self.assertIsInstance(b, TypeVar)
        self.assertEqual(b.__name__, "B")
        self.assertIsInstance(c, TypeVar)
        self.assertEqual(c.__name__, "C")

    @requires_working_socket()
    def test_typevar_coroutine(self):
        # Same for a generic coroutine function.
        def get_coroutine[A]():
            async def coroutine[B]():
                return (A, B)
            return coroutine

        co = get_coroutine()

        # Reset the event loop policy after asyncio.run().
        self.addCleanup(asyncio.set_event_loop_policy, None)
        a, b = asyncio.run(co())

        self.assertIsInstance(a, TypeVar)
        self.assertEqual(a.__name__, "A")
        self.assertIsInstance(b, TypeVar)
        self.assertEqual(b.__name__, "B")
class TypeParamsTypeVarTupleTest(unittest.TestCase):
    """``*Ts`` parameters: bounds/constraints are syntax errors; the runtime object is a TypeVarTuple."""

    def test_typevartuple_01(self):
        code = """def func1[*A: str](): pass"""
        check_syntax_error(self, code, "cannot use bound with TypeVarTuple")
        code = """def func1[*A: (int, str)](): pass"""
        check_syntax_error(self, code, "cannot use constraints with TypeVarTuple")
        code = """class X[*A: str]: pass"""
        check_syntax_error(self, code, "cannot use bound with TypeVarTuple")
        code = """class X[*A: (int, str)]: pass"""
        check_syntax_error(self, code, "cannot use constraints with TypeVarTuple")
        code = """type X[*A: str] = int"""
        check_syntax_error(self, code, "cannot use bound with TypeVarTuple")
        code = """type X[*A: (int, str)] = int"""
        check_syntax_error(self, code, "cannot use constraints with TypeVarTuple")

    def test_typevartuple_02(self):
        def func1[*A]():
            return A

        a = func1()
        self.assertIsInstance(a, TypeVarTuple)
class TypeParamsTypeVarParamSpecTest(unittest.TestCase):
    """``**P`` parameters: bounds/constraints are syntax errors; the runtime object is a ParamSpec."""

    def test_paramspec_01(self):
        code = """def func1[**A: str](): pass"""
        check_syntax_error(self, code, "cannot use bound with ParamSpec")
        code = """def func1[**A: (int, str)](): pass"""
        check_syntax_error(self, code, "cannot use constraints with ParamSpec")
        code = """class X[**A: str]: pass"""
        check_syntax_error(self, code, "cannot use bound with ParamSpec")
        code = """class X[**A: (int, str)]: pass"""
        check_syntax_error(self, code, "cannot use constraints with ParamSpec")
        code = """type X[**A: str] = int"""
        check_syntax_error(self, code, "cannot use bound with ParamSpec")
        code = """type X[**A: (int, str)] = int"""
        check_syntax_error(self, code, "cannot use constraints with ParamSpec")

    def test_paramspec_02(self):
        # A syntax-created ParamSpec uses inferred variance.
        def func1[**A]():
            return A

        a = func1()
        self.assertIsInstance(a, ParamSpec)
        self.assertTrue(a.__infer_variance__)
        self.assertFalse(a.__covariant__)
        self.assertFalse(a.__contravariant__)
class TypeParamsTypeParamsDunder(unittest.TestCase):
    """The ``__type_params__`` attribute on classes and functions, generic or not."""

    def test_typeparams_dunder_class_01(self):
        # A generic class exposes exactly its own type parameters, in order,
        # via both __type_params__ and __parameters__.
        class Outer[A, B]:
            class Inner[C, D]:
                @staticmethod
                def get_typeparams():
                    return A, B, C, D

        a, b, c, d = Outer.Inner.get_typeparams()
        self.assertEqual(Outer.__type_params__, (a, b))
        self.assertEqual(Outer.Inner.__type_params__, (c, d))

        self.assertEqual(Outer.__parameters__, (a, b))
        self.assertEqual(Outer.Inner.__parameters__, (c, d))

    def test_typeparams_dunder_class_02(self):
        # Non-generic classes report an empty tuple.
        class ClassA:
            pass

        self.assertEqual(ClassA.__type_params__, ())

    def test_typeparams_dunder_class_03(self):
        # The attribute is writable on classes.
        code = """
            class ClassA[A]():
                pass
            ClassA.__type_params__ = ()
            params = ClassA.__type_params__
        """

        ns = run_code(code)
        self.assertEqual(ns["params"], ())

    def test_typeparams_dunder_function_01(self):
        def outer[A, B]():
            def inner[C, D]():
                return A, B, C, D

            return inner

        inner = outer()
        a, b, c, d = inner()
        self.assertEqual(outer.__type_params__, (a, b))
        self.assertEqual(inner.__type_params__, (c, d))

    def test_typeparams_dunder_function_02(self):
        # Non-generic functions report an empty tuple.
        def func1():
            pass

        self.assertEqual(func1.__type_params__, ())

    def test_typeparams_dunder_function_03(self):
        # The attribute is writable on functions.
        code = """
            def func[A]():
                pass
            func.__type_params__ = ()
        """

        ns = run_code(code)
        self.assertEqual(ns["func"].__type_params__, ())
# Fixtures for TypeParamsPickleTest: a TypeVar plus generic functions and
# classes.  They are defined at module level so that pickle can locate them
# by qualified name.
T = TypeVar('T')
def func1[X](x: X) -> X: ...
def func2[X, Y](x: X | Y) -> X | Y: ...
def func3[X, *Y, **Z](x: X, y: tuple[*Y], z: Z) -> X: ...
def func4[X: int, Y: (bytes, str)](x: X, y: Y) -> X | Y: ...

class Class1[X]: ...
class Class2[X, Y]: ...
class Class3[X, *Y, **Z]: ...
class Class4[X: int, Y: (bytes, str)]: ...
class TypeParamsPickleTest(unittest.TestCase):
    """Generic functions, classes, parameterized aliases and instances round-trip through pickle."""

    def test_pickling_functions(self):
        # Checked with every protocol from 0 to pickle.HIGHEST_PROTOCOL.
        things_to_test = [
            func1,
            func2,
            func3,
            func4,
        ]
        for thing in things_to_test:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(thing=thing, proto=proto):
                    pickled = pickle.dumps(thing, protocol=proto)
                    self.assertEqual(pickle.loads(pickled), thing)

    def test_pickling_classes(self):
        things_to_test = [
            Class1,
            Class1[int],
            Class1[T],

            Class2,
            Class2[int, T],
            Class2[T, int],
            Class2[int, str],

            Class3,
            Class3[int, T, str, bytes, [float, object, T]],

            Class4,
            Class4[int, bytes],
            Class4[T, bytes],
            Class4[int, T],
            Class4[T, T],
        ]
        # First the classes and aliases themselves ...
        for thing in things_to_test:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(thing=thing, proto=proto):
                    pickled = pickle.dumps(thing, protocol=proto)
                    self.assertEqual(pickle.loads(pickled), thing)

        # ... then instances of them; an alias instance is an instance of
        # its __origin__ class.
        for klass in things_to_test:
            real_class = getattr(klass, '__origin__', klass)
            thing = klass()
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(thing=thing, proto=proto):
                    pickled = pickle.dumps(thing, protocol=proto)
                    # These instances are not equal,
                    # but class check is good enough:
                    self.assertIsInstance(pickle.loads(pickled), real_class)
class TypeParamsWeakRefTest(unittest.TestCase):
    """Typing objects and old/new-style generic classes, aliases and instances are weak-referenceable."""

    def test_weakrefs(self):
        T = TypeVar('T')
        P = ParamSpec('P')
        class OldStyle(Generic[T]):
            pass

        class NewStyle[T]:
            pass

        cases = [
            T,
            TypeVar('T', bound=int),
            P,
            P.args,
            P.kwargs,
            TypeVarTuple('Ts'),
            OldStyle,
            OldStyle[int],
            OldStyle(),
            NewStyle,
            NewStyle[int],
            NewStyle(),
            Generic[T],
        ]
        for case in cases:
            with self.subTest(case=case):
                # Fails with TypeError if the object does not support weakrefs.
                weakref.ref(case)
class TypeParamsRuntimeTest(unittest.TestCase):
    """Regression tests for crashes and for error propagation when generic classes are created."""

    def test_name_error(self):
        # gh-109118: This crashed the interpreter due to a refcounting bug
        code = """
            class name_2[name_5]:
                class name_4[name_5](name_0):
                    pass
        """
        with self.assertRaises(NameError):
            run_code(code)

        # Crashed with a slightly different stack trace
        code = """
            class name_2[name_5]:
                class name_4[name_5: name_5](name_0):
                    pass
        """
        with self.assertRaises(NameError):
            run_code(code)

    def test_broken_class_namespace(self):
        # The class namespace mapping raises RuntimeError (not KeyError)
        # when T is looked up; that exception must propagate.
        code = """
            class WeirdMapping(dict):
                def __missing__(self, key):
                    if key == "T":
                        raise RuntimeError
                    raise KeyError(key)

            class Meta(type):
                def __prepare__(name, bases):
                    return WeirdMapping()

            class MyClass[V](metaclass=Meta):
                class Inner[U](T):
                    pass
        """
        with self.assertRaises(RuntimeError):
            run_code(code)
class DefaultsTest(unittest.TestCase):
    """Type parameter defaults (``T = default``) on functions, classes and type aliases."""

    def test_defaults_on_func(self):
        ns = run_code("""
            def func[T=int, **U=float, *V=None]():
                pass
        """)

        T, U, V = ns["func"].__type_params__
        self.assertIs(T.__default__, int)
        self.assertIs(U.__default__, float)
        self.assertIs(V.__default__, None)

    def test_defaults_on_class(self):
        ns = run_code("""
            class C[T=int, **U=float, *V=None]:
                pass
        """)

        T, U, V = ns["C"].__type_params__
        self.assertIs(T.__default__, int)
        self.assertIs(U.__default__, float)
        self.assertIs(V.__default__, None)

    def test_defaults_on_type_alias(self):
        ns = run_code("""
            type Alias[T = int, **U = float, *V = None] = int
        """)

        T, U, V = ns["Alias"].__type_params__
        self.assertIs(T.__default__, int)
        self.assertIs(U.__default__, float)
        self.assertIs(V.__default__, None)

    def test_starred_invalid(self):
        # A starred default is only allowed for a TypeVarTuple.
        check_syntax_error(self, "type Alias[T = *int] = int")
        check_syntax_error(self, "type Alias[**P = *int] = int")

    def test_starred_typevartuple(self):
        # The TypeVarTuple default *default is stored as the result of
        # unpacking the tuple type, i.e. next(iter(default)).
        ns = run_code("""
            default = tuple[int, str]
            type Alias[*Ts = *default] = Ts
        """)

        Ts, = ns["Alias"].__type_params__
        self.assertEqual(Ts.__default__, next(iter(ns["default"])))

    def test_nondefault_after_default(self):
        check_syntax_error(self, "def func[T=int, U](): pass", "non-default type parameter 'U' follows default type parameter")
        check_syntax_error(self, "class C[T=int, U]: pass", "non-default type parameter 'U' follows default type parameter")
        check_syntax_error(self, "type A[T=int, U] = int", "non-default type parameter 'U' follows default type parameter")

    def test_lazy_evaluation(self):
        # Defaults are evaluated on first access and then cached.
        ns = run_code("""
            type Alias[T = Undefined, *U = Undefined, **V = Undefined] = int
        """)

        T, U, V = ns["Alias"].__type_params__

        with self.assertRaises(NameError):
            T.__default__
        with self.assertRaises(NameError):
            U.__default__
        with self.assertRaises(NameError):
            V.__default__

        ns["Undefined"] = "defined"
        self.assertEqual(T.__default__, "defined")
        self.assertEqual(U.__default__, "defined")
        self.assertEqual(V.__default__, "defined")

        # Now it is cached
        ns["Undefined"] = "redefined"
        self.assertEqual(T.__default__, "defined")
        self.assertEqual(U.__default__, "defined")
        self.assertEqual(V.__default__, "defined")

    def test_symtable_key_regression_default(self):
        # Test against the bugs that would happen if we used .default_
        # as the key in the symtable.
        ns = run_code("""
            type X[T = [T for T in [T]]] = T
        """)

        T, = ns["X"].__type_params__
        self.assertEqual(T.__default__, [T])

    def test_symtable_key_regression_name(self):
        # Test against the bugs that would happen if we used .name
        # as the key in the symtable.
        ns = run_code("""
            type X1[T = A] = T
            type X2[T = B] = T
            A = "A"
            B = "B"
        """)

        self.assertEqual(ns["X1"].__type_params__[0].__default__, "A")
        self.assertEqual(ns["X2"].__type_params__[0].__default__, "B")