gh-105730: support more callables in ExceptionGroup.split() and subgroup() (#106035)

This commit is contained in:
Irit Katriel 2023-06-23 19:47:47 +01:00 committed by GitHub
parent 1d33d53780
commit d8ca5a11bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 26 deletions

View File

@ -912,10 +912,11 @@ their subgroups based on the types of the contained exceptions.
Returns an exception group that contains only the exceptions from the
current group that match *condition*, or ``None`` if the result is empty.
The condition can be either a function that accepts an exception and returns
true for those that should be in the subgroup, or it can be an exception type
or a tuple of exception types, which is used to check for a match using the
same check that is used in an ``except`` clause.
The condition can be an exception type or tuple of exception types, in which
case each exception is checked for a match using the same check that is used
in an ``except`` clause. The condition can also be a callable (other than
a type object) that accepts an exception as its single argument and returns
true for the exceptions that should be in the subgroup.
The nesting structure of the current exception is preserved in the result,
as are the values of its :attr:`message`, :attr:`__traceback__`,
@ -926,6 +927,9 @@ their subgroups based on the types of the contained exceptions.
including the top-level and any nested exception groups. If the condition is
true for such an exception group, it is included in the result in full.
.. versionadded:: 3.13
``condition`` can be any callable which is not a type object.
.. method:: split(condition)
Like :meth:`subgroup`, but returns the pair ``(match, rest)`` where ``match``

View File

@ -294,6 +294,15 @@ class ExceptionGroupTestBase(unittest.TestCase):
self.assertEqual(type(exc), type(template))
self.assertEqual(exc.args, template.args)
class Predicate:
def __init__(self, func):
self.func = func
def __call__(self, e):
return self.func(e)
def method(self, e):
return self.func(e)
class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
def setUp(self):
@ -301,10 +310,15 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
self.eg_template = [ValueError(1), TypeError(int), ValueError(2)]
def test_basics_subgroup_split__bad_arg_type(self):
class C:
pass
bad_args = ["bad arg",
C,
OSError('instance not type'),
[OSError, TypeError],
(OSError, 42)]
(OSError, 42),
]
for arg in bad_args:
with self.assertRaises(TypeError):
self.eg.subgroup(arg)
@ -336,10 +350,14 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
self.assertMatchesTemplate(subeg, ExceptionGroup, template)
def test_basics_subgroup_by_predicate__passthrough(self):
self.assertIs(self.eg, self.eg.subgroup(lambda e: True))
f = lambda e: True
for callable in [f, Predicate(f), Predicate(f).method]:
self.assertIs(self.eg, self.eg.subgroup(callable))
def test_basics_subgroup_by_predicate__no_match(self):
self.assertIsNone(self.eg.subgroup(lambda e: False))
f = lambda e: False
for callable in [f, Predicate(f), Predicate(f).method]:
self.assertIsNone(self.eg.subgroup(callable))
def test_basics_subgroup_by_predicate__match(self):
eg = self.eg
@ -350,9 +368,12 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
((ValueError, TypeError), self.eg_template)]
for match_type, template in testcases:
subeg = eg.subgroup(lambda e: isinstance(e, match_type))
self.assertEqual(subeg.message, eg.message)
self.assertMatchesTemplate(subeg, ExceptionGroup, template)
f = lambda e: isinstance(e, match_type)
for callable in [f, Predicate(f), Predicate(f).method]:
with self.subTest(callable=callable):
subeg = eg.subgroup(f)
self.assertEqual(subeg.message, eg.message)
self.assertMatchesTemplate(subeg, ExceptionGroup, template)
class ExceptionGroupSplitTests(ExceptionGroupTestBase):
@ -399,14 +420,18 @@ class ExceptionGroupSplitTests(ExceptionGroupTestBase):
self.assertIsNone(rest)
def test_basics_split_by_predicate__passthrough(self):
match, rest = self.eg.split(lambda e: True)
self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template)
self.assertIsNone(rest)
f = lambda e: True
for callable in [f, Predicate(f), Predicate(f).method]:
match, rest = self.eg.split(callable)
self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template)
self.assertIsNone(rest)
def test_basics_split_by_predicate__no_match(self):
match, rest = self.eg.split(lambda e: False)
self.assertIsNone(match)
self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template)
f = lambda e: False
for callable in [f, Predicate(f), Predicate(f).method]:
match, rest = self.eg.split(callable)
self.assertIsNone(match)
self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template)
def test_basics_split_by_predicate__match(self):
eg = self.eg
@ -420,14 +445,16 @@ class ExceptionGroupSplitTests(ExceptionGroupTestBase):
]
for match_type, match_template, rest_template in testcases:
match, rest = eg.split(lambda e: isinstance(e, match_type))
self.assertEqual(match.message, eg.message)
self.assertMatchesTemplate(
match, ExceptionGroup, match_template)
if rest_template is not None:
self.assertEqual(rest.message, eg.message)
f = lambda e: isinstance(e, match_type)
for callable in [f, Predicate(f), Predicate(f).method]:
match, rest = eg.split(callable)
self.assertEqual(match.message, eg.message)
self.assertMatchesTemplate(
rest, ExceptionGroup, rest_template)
match, ExceptionGroup, match_template)
if rest_template is not None:
self.assertEqual(rest.message, eg.message)
self.assertMatchesTemplate(
rest, ExceptionGroup, rest_template)
class DeepRecursionInSplitAndSubgroup(unittest.TestCase):

View File

@ -0,0 +1,2 @@
Allow any callable other than type objects as the condition predicate in
:meth:`BaseExceptionGroup.split` and :meth:`BaseExceptionGroup.subgroup`.

View File

@ -992,7 +992,7 @@ get_matcher_type(PyObject *value,
{
assert(value);
if (PyFunction_Check(value)) {
if (PyCallable_Check(value) && !PyType_Check(value)) {
*type = EXCEPTION_GROUP_MATCH_BY_PREDICATE;
return 0;
}
@ -1016,7 +1016,7 @@ get_matcher_type(PyObject *value,
error:
PyErr_SetString(
PyExc_TypeError,
"expected a function, exception type or tuple of exception types");
"expected an exception type, a tuple of exception types, or a callable (other than a class)");
return -1;
}
@ -1032,7 +1032,7 @@ exceptiongroup_split_check_match(PyObject *exc,
return PyErr_GivenExceptionMatches(exc, matcher_value);
}
case EXCEPTION_GROUP_MATCH_BY_PREDICATE: {
assert(PyFunction_Check(matcher_value));
assert(PyCallable_Check(matcher_value) && !PyType_Check(matcher_value));
PyObject *exc_matches = PyObject_CallOneArg(matcher_value, exc);
if (exc_matches == NULL) {
return -1;