import os import unittest class NamedExpressionInvalidTest(unittest.TestCase): def test_named_expression_invalid_01(self): code = """x := 0""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_02(self): code = """x = y := 0""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_03(self): code = """y := f(x)""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_04(self): code = """y0 = y1 := f(x)""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_06(self): code = """((a, b) := (1, 2))""" with self.assertRaisesRegex(SyntaxError, "cannot use named assignment with tuple"): exec(code, {}, {}) def test_named_expression_invalid_07(self): code = """def spam(a = b := 42): pass""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_08(self): code = """def spam(a: b := 42 = 5): pass""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_09(self): code = """spam(a=b := 'c')""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_10(self): code = """spam(x = y := f(x))""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_11(self): code = """spam(a=1, b := 2)""" with self.assertRaisesRegex(SyntaxError, "positional argument follows keyword argument"): exec(code, {}, {}) def test_named_expression_invalid_12(self): code = """spam(a=1, (b := 2))""" with self.assertRaisesRegex(SyntaxError, "positional argument follows keyword argument"): exec(code, {}, {}) def test_named_expression_invalid_13(self): code = """spam(a=1, (b := 2))""" with self.assertRaisesRegex(SyntaxError, "positional argument follows keyword argument"): exec(code, {}, {}) def test_named_expression_invalid_14(self): code = """(x := lambda: y := 1)""" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_15(self): code = """(lambda: x := 1)""" with self.assertRaisesRegex(SyntaxError, "cannot use named assignment with lambda"): exec(code, {}, {}) def test_named_expression_invalid_16(self): code = "[i + 1 for i in i := [1,2]]" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_17(self): code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]" with self.assertRaisesRegex(SyntaxError, "invalid syntax"): exec(code, {}, {}) def test_named_expression_invalid_18(self): code = """class Foo(): [(42, 1 + ((( j := i )))) for i in range(5)] """ with self.assertRaisesRegex(TargetScopeError, "named expression within a comprehension cannot be used in a class body"): exec(code, {}, {}) class NamedExpressionAssignmentTest(unittest.TestCase): def test_named_expression_assignment_01(self): (a := 10) self.assertEqual(a, 10) def test_named_expression_assignment_02(self): a = 20 (a := a) self.assertEqual(a, 20) def test_named_expression_assignment_03(self): (total := 1 + 2) self.assertEqual(total, 3) def test_named_expression_assignment_04(self): (info := (1, 2, 3)) self.assertEqual(info, (1, 2, 3)) def test_named_expression_assignment_05(self): (x := 1, 2) self.assertEqual(x, 1) def test_named_expression_assignment_06(self): (z := (y := (x := 0))) self.assertEqual(x, 0) self.assertEqual(y, 0) self.assertEqual(z, 0) def test_named_expression_assignment_07(self): (loc := (1, 2)) self.assertEqual(loc, (1, 2)) def test_named_expression_assignment_08(self): if spam := "eggs": self.assertEqual(spam, "eggs") else: self.fail("variable was not assigned using named expression") def test_named_expression_assignment_09(self): if True and (spam := True): self.assertTrue(spam) else: self.fail("variable was not assigned using named expression") def test_named_expression_assignment_10(self): if (match := 10) == 10: pass else: self.fail("variable was not assigned using named expression") def test_named_expression_assignment_11(self): def spam(a): return a input_data = [1, 2, 3] res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0] self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)]) def test_named_expression_assignment_12(self): def spam(a): return a res = [[y := spam(x), x/y] for x in range(1, 5)] self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]]) def test_named_expression_assignment_13(self): length = len(lines := [1, 2]) self.assertEqual(length, 2) self.assertEqual(lines, [1,2]) def test_named_expression_assignment_14(self): """ Where all variables are positive integers, and a is at least as large as the n'th root of x, this algorithm returns the floor of the n'th root of x (and roughly doubling the number of accurate bits per iteration): """ a = 9 n = 2 x = 3 while a > (d := x // a**(n-1)): a = ((n-1)*a + d) // n self.assertEqual(a, 1) def test_named_expression_assignment_15(self): while a := False: pass # This will not run self.assertEqual(a, False) def test_named_expression_assignment_16(self): a, b = 1, 2 fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)} self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21}) class NamedExpressionScopeTest(unittest.TestCase): def test_named_expression_scope_01(self): code = """def spam(): (a := 5) print(a)""" with self.assertRaisesRegex(NameError, "name 'a' is not defined"): exec(code, {}, {}) def test_named_expression_scope_02(self): total = 0 partial_sums = [total := total + v for v in range(5)] self.assertEqual(partial_sums, [0, 1, 3, 6, 10]) self.assertEqual(total, 10) def test_named_expression_scope_03(self): containsOne = any((lastNum := num) == 1 for num in [1, 2, 3]) self.assertTrue(containsOne) self.assertEqual(lastNum, 1) def test_named_expression_scope_04(self): def spam(a): return a res = [[y := spam(x), x/y] for x in range(1, 5)] self.assertEqual(y, 4) def test_named_expression_scope_05(self): def spam(a): return a input_data = [1, 2, 3] res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0] self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)]) self.assertEqual(y, 3) def test_named_expression_scope_06(self): res = [[spam := i for i in range(3)] for j in range(2)] self.assertEqual(res, [[0, 1, 2], [0, 1, 2]]) self.assertEqual(spam, 2) def test_named_expression_scope_07(self): len(lines := [1, 2]) self.assertEqual(lines, [1, 2]) def test_named_expression_scope_08(self): def spam(a): return a def eggs(b): return b * 2 res = [spam(a := eggs(b := h)) for h in range(2)] self.assertEqual(res, [0, 2]) self.assertEqual(a, 2) self.assertEqual(b, 1) def test_named_expression_scope_09(self): def spam(a): return a def eggs(b): return b * 2 res = [spam(a := eggs(a := h)) for h in range(2)] self.assertEqual(res, [0, 2]) self.assertEqual(a, 2) def test_named_expression_scope_10(self): res = [b := [a := 1 for i in range(2)] for j in range(2)] self.assertEqual(res, [[1, 1], [1, 1]]) self.assertEqual(a, 1) self.assertEqual(b, [1, 1]) def test_named_expression_scope_11(self): res = [j := i for i in range(5)] self.assertEqual(res, [0, 1, 2, 3, 4]) self.assertEqual(j, 4) def test_named_expression_scope_12(self): res = [i := i for i in range(5)] self.assertEqual(res, [0, 1, 2, 3, 4]) self.assertEqual(i, 4) def test_named_expression_scope_13(self): res = [i := 0 for i, j in [(1, 2), (3, 4)]] self.assertEqual(res, [0, 0]) self.assertEqual(i, 0) def test_named_expression_scope_14(self): res = [(i := 0, j := 1) for i, j in [(1, 2), (3, 4)]] self.assertEqual(res, [(0, 1), (0, 1)]) self.assertEqual(i, 0) self.assertEqual(j, 1) def test_named_expression_scope_15(self): res = [(i := i, j := j) for i, j in [(1, 2), (3, 4)]] self.assertEqual(res, [(1, 2), (3, 4)]) self.assertEqual(i, 3) self.assertEqual(j, 4) def test_named_expression_scope_16(self): res = [(i := j, j := i) for i, j in [(1, 2), (3, 4)]] self.assertEqual(res, [(2, 2), (4, 4)]) self.assertEqual(i, 4) self.assertEqual(j, 4) def test_named_expression_scope_17(self): b = 0 res = [b := i + b for i in range(5)] self.assertEqual(res, [0, 1, 3, 6, 10]) self.assertEqual(b, 10) def test_named_expression_scope_18(self): def spam(a): return a res = spam(b := 2) self.assertEqual(res, 2) self.assertEqual(b, 2) def test_named_expression_scope_19(self): def spam(a): return a res = spam((b := 2)) self.assertEqual(res, 2) self.assertEqual(b, 2) def test_named_expression_scope_20(self): def spam(a): return a res = spam(a=(b := 2)) self.assertEqual(res, 2) self.assertEqual(b, 2) def test_named_expression_scope_21(self): def spam(a, b): return a + b res = spam(c := 2, b=1) self.assertEqual(res, 3) self.assertEqual(c, 2) def test_named_expression_scope_22(self): def spam(a, b): return a + b res = spam((c := 2), b=1) self.assertEqual(res, 3) self.assertEqual(c, 2) def test_named_expression_scope_23(self): def spam(a, b): return a + b res = spam(b=(c := 2), a=1) self.assertEqual(res, 3) self.assertEqual(c, 2) def test_named_expression_scope_24(self): a = 10 def spam(): nonlocal a (a := 20) spam() self.assertEqual(a, 20) def test_named_expression_scope_25(self): ns = {} code = """a = 10 def spam(): global a (a := 20) spam()""" exec(code, ns, {}) self.assertEqual(ns["a"], 20) if __name__ == "__main__": unittest.main()