gh-115480: Type propagate _BINARY_OP_ADD_UNICODE (GH-115710)

This commit is contained in:
Ken Jin 2024-03-02 03:40:04 +08:00 committed by GitHub
parent b5949eac62
commit ff96b81d78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 66 additions and 8 deletions

View File

@ -795,11 +795,14 @@ class TestUopsOptimization(unittest.TestCase):
def testfunc(n):
a = 1.0
for _ in range(n):
a = a + 0.1
a = a + 0.25
a = a + 0.25
a = a + 0.25
a = a + 0.25
return a
res, ex = self._run_with_optimizer(testfunc, 32)
self.assertAlmostEqual(res, 4.2)
self.assertAlmostEqual(res, 33.0)
self.assertIsNotNone(ex)
uops = get_opnames(ex)
guard_both_float_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_FLOAT"]
@ -812,11 +815,14 @@ class TestUopsOptimization(unittest.TestCase):
def testfunc(n):
a = 1.0
for _ in range(n):
a = a - 0.1
a = a - 0.25
a = a - 0.25
a = a - 0.25
a = a - 0.25
return a
res, ex = self._run_with_optimizer(testfunc, 32)
self.assertAlmostEqual(res, -2.2)
self.assertAlmostEqual(res, -31.0)
self.assertIsNotNone(ex)
uops = get_opnames(ex)
guard_both_float_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_FLOAT"]
@ -829,11 +835,14 @@ class TestUopsOptimization(unittest.TestCase):
def testfunc(n):
a = 1.0
for _ in range(n):
a = a * 2.0
a = a * 1.0
a = a * 1.0
a = a * 1.0
a = a * 1.0
return a
res, ex = self._run_with_optimizer(testfunc, 32)
self.assertAlmostEqual(res, 2 ** 32)
self.assertAlmostEqual(res, 1.0)
self.assertIsNotNone(ex)
uops = get_opnames(ex)
guard_both_float_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_FLOAT"]
@ -842,6 +851,24 @@ class TestUopsOptimization(unittest.TestCase):
# We'll also need to verify that propagation actually occurs.
self.assertIn("_BINARY_OP_MULTIPLY_FLOAT", uops)
def test_add_unicode_propagation(self):
def testfunc(n):
a = ""
for _ in range(n):
a + a
a + a
a + a
a + a
return a
res, ex = self._run_with_optimizer(testfunc, 32)
self.assertEqual(res, "")
self.assertIsNotNone(ex)
uops = get_opnames(ex)
guard_both_unicode_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_UNICODE"]
self.assertLessEqual(len(guard_both_unicode_count), 1)
self.assertIn("_BINARY_OP_ADD_UNICODE", uops)
def test_compare_op_type_propagation_float(self):
def testfunc(n):
a = 1.0

View File

@ -254,6 +254,22 @@ dummy_func(void) {
}
}
op(_BINARY_OP_ADD_UNICODE, (left, right -- res)) {
if (sym_is_const(left) && sym_is_const(right) &&
sym_matches_type(left, &PyUnicode_Type) && sym_matches_type(right, &PyUnicode_Type)) {
PyObject *temp = PyUnicode_Concat(sym_get_const(left), sym_get_const(right));
if (temp == NULL) {
goto error;
}
res = sym_new_const(ctx, temp);
Py_DECREF(temp);
OUT_OF_SPACE_IF_NULL(res);
}
else {
OUT_OF_SPACE_IF_NULL(res = sym_new_type(ctx, &PyUnicode_Type));
}
}
op(_TO_BOOL, (value -- res)) {
(void)value;
res = sym_new_type(ctx, &PyBool_Type);

View File

@ -446,9 +446,24 @@
}
case _BINARY_OP_ADD_UNICODE: {
_Py_UopsSymbol *right;
_Py_UopsSymbol *left;
_Py_UopsSymbol *res;
res = sym_new_unknown(ctx);
if (res == NULL) goto out_of_space;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
sym_matches_type(left, &PyUnicode_Type) && sym_matches_type(right, &PyUnicode_Type)) {
PyObject *temp = PyUnicode_Concat(sym_get_const(left), sym_get_const(right));
if (temp == NULL) {
goto error;
}
res = sym_new_const(ctx, temp);
Py_DECREF(temp);
OUT_OF_SPACE_IF_NULL(res);
}
else {
OUT_OF_SPACE_IF_NULL(res = sym_new_type(ctx, &PyUnicode_Type));
}
stack_pointer[-2] = res;
stack_pointer += -1;
break;