gh-90213: Speed up right shifts of negative integers (GH-30277)

This commit is contained in:
Mark Dickinson 2022-05-02 18:19:03 +01:00 committed by GitHub
parent 4b297a9ffd
commit 0ed91a26fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 76 additions and 30 deletions

View File

@ -985,6 +985,10 @@ class LongTest(unittest.TestCase):
self.assertEqual((-1122) >> 9, -3) self.assertEqual((-1122) >> 9, -3)
self.assertEqual(2**128 >> 9, 2**119) self.assertEqual(2**128 >> 9, 2**119)
self.assertEqual(-2**128 >> 9, -2**119) self.assertEqual(-2**128 >> 9, -2**119)
# Exercise corner case of the current algorithm, where the result of
# shifting a two-limb int by the limb size still has two limbs.
self.assertEqual((1 - BASE*BASE) >> SHIFT, -BASE)
self.assertEqual((BASE - 1 - BASE*BASE) >> SHIFT, -BASE)
def test_big_rshift(self): def test_big_rshift(self):
self.assertEqual(42 >> 32, 0) self.assertEqual(42 >> 32, 0)

View File

@ -0,0 +1,2 @@
Speed up right shifts of negative integers by removing the unnecessary
creation of temporaries. Original patch by Xinhang Xu, reworked by Mark Dickinson.

View File

@ -4688,13 +4688,23 @@ divmod_shift(PyObject *shiftby, Py_ssize_t *wordshift, digit *remshift)
return 0; return 0;
} }
/* Inner function for both long_rshift and _PyLong_Rshift, shifting an
integer right by PyLong_SHIFT*wordshift + remshift bits.
wordshift should be nonnegative. */
static PyObject * static PyObject *
long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift) long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
{ {
PyLongObject *z = NULL; PyLongObject *z = NULL;
Py_ssize_t newsize, hishift, i, j; Py_ssize_t newsize, hishift, size_a;
twodigits accum; twodigits accum;
int a_negative;
/* Total number of bits shifted must be nonnegative. */
assert(wordshift >= 0);
assert(remshift < PyLong_SHIFT);
/* Fast path for small a. */
if (IS_MEDIUM_VALUE(a)) { if (IS_MEDIUM_VALUE(a)) {
stwodigits m, x; stwodigits m, x;
digit shift; digit shift;
@ -4704,37 +4714,67 @@ long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
return _PyLong_FromSTwoDigits(x); return _PyLong_FromSTwoDigits(x);
} }
if (Py_SIZE(a) < 0) { a_negative = Py_SIZE(a) < 0;
/* Right shifting negative numbers is harder */ size_a = Py_ABS(Py_SIZE(a));
PyLongObject *a1, *a2;
a1 = (PyLongObject *) long_invert(a); if (a_negative) {
if (a1 == NULL) /* For negative 'a', adjust so that 0 < remshift <= PyLong_SHIFT,
return NULL; while keeping PyLong_SHIFT*wordshift + remshift the same. This
a2 = (PyLongObject *) long_rshift1(a1, wordshift, remshift); ensures that 'newsize' is computed correctly below. */
Py_DECREF(a1); if (remshift == 0) {
if (a2 == NULL) if (wordshift == 0) {
return NULL; /* Can only happen if the original shift was 0. */
z = (PyLongObject *) long_invert(a2); return long_long((PyObject *)a);
Py_DECREF(a2); }
remshift = PyLong_SHIFT;
--wordshift;
}
}
assert(wordshift >= 0);
newsize = size_a - wordshift;
if (newsize <= 0) {
/* Shifting all the bits of 'a' out gives either -1 or 0. */
return PyLong_FromLong(-a_negative);
} }
else {
newsize = Py_SIZE(a) - wordshift;
if (newsize <= 0)
return PyLong_FromLong(0);
hishift = PyLong_SHIFT - remshift;
z = _PyLong_New(newsize); z = _PyLong_New(newsize);
if (z == NULL) if (z == NULL) {
return NULL; return NULL;
j = wordshift; }
accum = a->ob_digit[j++] >> remshift; hishift = PyLong_SHIFT - remshift;
for (i = 0; j < Py_SIZE(a); i++, j++) {
accum |= (twodigits)a->ob_digit[j] << hishift; accum = a->ob_digit[wordshift];
if (a_negative) {
/*
For a positive integer a and nonnegative shift, we have:
(-a) >> shift == -((a + 2**shift - 1) >> shift).
In the addition `a + (2**shift - 1)`, the low `wordshift` digits of
`2**shift - 1` all have value `PyLong_MASK`, so we get a carry out
from the bottom `wordshift` digits when at least one of the least
significant `wordshift` digits of `a` is nonzero. Digit `wordshift`
of `2**shift - 1` has value `PyLong_MASK >> hishift`.
*/
Py_SET_SIZE(z, -newsize);
digit sticky = 0;
for (Py_ssize_t j = 0; j < wordshift; j++) {
sticky |= a->ob_digit[j];
}
accum += (PyLong_MASK >> hishift) + (digit)(sticky != 0);
}
accum >>= remshift;
for (Py_ssize_t i = 0, j = wordshift + 1; j < size_a; i++, j++) {
accum += (twodigits)a->ob_digit[j] << hishift;
z->ob_digit[i] = (digit)(accum & PyLong_MASK); z->ob_digit[i] = (digit)(accum & PyLong_MASK);
accum >>= PyLong_SHIFT; accum >>= PyLong_SHIFT;
} }
z->ob_digit[i] = (digit)accum; assert(accum <= PyLong_MASK);
z->ob_digit[newsize - 1] = (digit)accum;
z = maybe_small_long(long_normalize(z)); z = maybe_small_long(long_normalize(z));
}
return (PyObject *)z; return (PyObject *)z;
} }