diff --git a/Doc/library/exceptions.rst b/Doc/library/exceptions.rst index 4c84e5f8554..8e574b8334e 100644 --- a/Doc/library/exceptions.rst +++ b/Doc/library/exceptions.rst @@ -912,10 +912,11 @@ their subgroups based on the types of the contained exceptions. Returns an exception group that contains only the exceptions from the current group that match *condition*, or ``None`` if the result is empty. - The condition can be either a function that accepts an exception and returns - true for those that should be in the subgroup, or it can be an exception type - or a tuple of exception types, which is used to check for a match using the - same check that is used in an ``except`` clause. + The condition can be an exception type or tuple of exception types, in which + case each exception is checked for a match using the same check that is used + in an ``except`` clause. The condition can also be a callable (other than + a type object) that accepts an exception as its single argument and returns + true for the exceptions that should be in the subgroup. The nesting structure of the current exception is preserved in the result, as are the values of its :attr:`message`, :attr:`__traceback__`, @@ -926,6 +927,9 @@ their subgroups based on the types of the contained exceptions. including the top-level and any nested exception groups. If the condition is true for such an exception group, it is included in the result in full. + .. versionadded:: 3.13 + ``condition`` can be any callable which is not a type object. + .. method:: split(condition) Like :meth:`subgroup`, but returns the pair ``(match, rest)`` where ``match`` diff --git a/Lib/test/test_exception_group.py b/Lib/test/test_exception_group.py index fa159a76ec1..2658e027ff3 100644 --- a/Lib/test/test_exception_group.py +++ b/Lib/test/test_exception_group.py @@ -294,6 +294,15 @@ class ExceptionGroupTestBase(unittest.TestCase): self.assertEqual(type(exc), type(template)) self.assertEqual(exc.args, template.args) +class Predicate: + def __init__(self, func): + self.func = func + + def __call__(self, e): + return self.func(e) + + def method(self, e): + return self.func(e) class ExceptionGroupSubgroupTests(ExceptionGroupTestBase): def setUp(self): @@ -301,10 +310,15 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase): self.eg_template = [ValueError(1), TypeError(int), ValueError(2)] def test_basics_subgroup_split__bad_arg_type(self): + class C: + pass + bad_args = ["bad arg", + C, OSError('instance not type'), [OSError, TypeError], - (OSError, 42)] + (OSError, 42), + ] for arg in bad_args: with self.assertRaises(TypeError): self.eg.subgroup(arg) @@ -336,10 +350,14 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase): self.assertMatchesTemplate(subeg, ExceptionGroup, template) def test_basics_subgroup_by_predicate__passthrough(self): - self.assertIs(self.eg, self.eg.subgroup(lambda e: True)) + f = lambda e: True + for callable in [f, Predicate(f), Predicate(f).method]: + self.assertIs(self.eg, self.eg.subgroup(callable)) def test_basics_subgroup_by_predicate__no_match(self): - self.assertIsNone(self.eg.subgroup(lambda e: False)) + f = lambda e: False + for callable in [f, Predicate(f), Predicate(f).method]: + self.assertIsNone(self.eg.subgroup(callable)) def test_basics_subgroup_by_predicate__match(self): eg = self.eg @@ -350,9 +368,12 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase): ((ValueError, TypeError), self.eg_template)] for match_type, template in testcases: - subeg = eg.subgroup(lambda e: isinstance(e, match_type)) - self.assertEqual(subeg.message, eg.message) - self.assertMatchesTemplate(subeg, ExceptionGroup, template) + f = lambda e: isinstance(e, match_type) + for callable in [f, Predicate(f), Predicate(f).method]: + with self.subTest(callable=callable): + subeg = eg.subgroup(f) + self.assertEqual(subeg.message, eg.message) + self.assertMatchesTemplate(subeg, ExceptionGroup, template) class ExceptionGroupSplitTests(ExceptionGroupTestBase): @@ -399,14 +420,18 @@ class ExceptionGroupSplitTests(ExceptionGroupTestBase): self.assertIsNone(rest) def test_basics_split_by_predicate__passthrough(self): - match, rest = self.eg.split(lambda e: True) - self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template) - self.assertIsNone(rest) + f = lambda e: True + for callable in [f, Predicate(f), Predicate(f).method]: + match, rest = self.eg.split(callable) + self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template) + self.assertIsNone(rest) def test_basics_split_by_predicate__no_match(self): - match, rest = self.eg.split(lambda e: False) - self.assertIsNone(match) - self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template) + f = lambda e: False + for callable in [f, Predicate(f), Predicate(f).method]: + match, rest = self.eg.split(callable) + self.assertIsNone(match) + self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template) def test_basics_split_by_predicate__match(self): eg = self.eg @@ -420,14 +445,16 @@ class ExceptionGroupSplitTests(ExceptionGroupTestBase): ] for match_type, match_template, rest_template in testcases: - match, rest = eg.split(lambda e: isinstance(e, match_type)) - self.assertEqual(match.message, eg.message) - self.assertMatchesTemplate( - match, ExceptionGroup, match_template) - if rest_template is not None: - self.assertEqual(rest.message, eg.message) + f = lambda e: isinstance(e, match_type) + for callable in [f, Predicate(f), Predicate(f).method]: + match, rest = eg.split(callable) + self.assertEqual(match.message, eg.message) self.assertMatchesTemplate( - rest, ExceptionGroup, rest_template) + match, ExceptionGroup, match_template) + if rest_template is not None: + self.assertEqual(rest.message, eg.message) + self.assertMatchesTemplate( + rest, ExceptionGroup, rest_template) class DeepRecursionInSplitAndSubgroup(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Core and Builtins/2023-06-23-16-51-02.gh-issue-105730.16haMe.rst b/Misc/NEWS.d/next/Core and Builtins/2023-06-23-16-51-02.gh-issue-105730.16haMe.rst new file mode 100644 index 00000000000..fa70ee09ce2 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2023-06-23-16-51-02.gh-issue-105730.16haMe.rst @@ -0,0 +1,2 @@ +Allow any callable other than type objects as the condition predicate in +:meth:`BaseExceptionGroup.split` and :meth:`BaseExceptionGroup.subgroup`. diff --git a/Objects/exceptions.c b/Objects/exceptions.c index 04ea22c2902..f27e6f6c143 100644 --- a/Objects/exceptions.c +++ b/Objects/exceptions.c @@ -992,7 +992,7 @@ get_matcher_type(PyObject *value, { assert(value); - if (PyFunction_Check(value)) { + if (PyCallable_Check(value) && !PyType_Check(value)) { *type = EXCEPTION_GROUP_MATCH_BY_PREDICATE; return 0; } @@ -1016,7 +1016,7 @@ get_matcher_type(PyObject *value, error: PyErr_SetString( PyExc_TypeError, - "expected a function, exception type or tuple of exception types"); + "expected an exception type, a tuple of exception types, or a callable (other than a class)"); return -1; } @@ -1032,7 +1032,7 @@ exceptiongroup_split_check_match(PyObject *exc, return PyErr_GivenExceptionMatches(exc, matcher_value); } case EXCEPTION_GROUP_MATCH_BY_PREDICATE: { - assert(PyFunction_Check(matcher_value)); + assert(PyCallable_Check(matcher_value) && !PyType_Check(matcher_value)); PyObject *exc_matches = PyObject_CallOneArg(matcher_value, exc); if (exc_matches == NULL) { return -1;