GH-116554: Relax list.sort()'s notion of "descending" runs (#116578)

* GH-116554: Relax list.sort()'s notion of "descending" run

Rewrote `count_run()` so that sub-runs of equal elements no longer end a descending run. Both ascending and descending runs can have arbitrarily many sub-runs of arbitrarily many equal elements now. This is tricky, because we only use ``<`` comparisons, so checking for equality doesn't come "for free". Surprisingly, it turned out there's a very cheap (one comparison) way to determine whether an ascending run consisted of all-equal elements. That sealed the deal.

In addition, after a descending run is reversed in-place, we now go on to see whether it can be extended by an ascending run that just happens to be adjacent. This succeeds in finding at least one additional element to append about half the time, and so appears to more than repay its cost (the savings come from getting to skip a binary search, when a short run is artificially forced to length MIINRUN later, for each new element `count_run()` can add to the initial run).

While these have been in the back of my mind for years, a question on StackOverflow pushed it to action:

https://stackoverflow.com/questions/78108792/

They were wondering why it took about 4x longer to sort a list like:

[999_999, 999_999, ..., 2, 2, 1, 1, 0, 0]

than "similar" lists. Of course that runs very much faster after this patch.

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
This commit is contained in:
Tim Peters 2024-03-12 19:59:42 -05:00 committed by GitHub
parent 7d1abe9502
commit bf121d6a69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 154 additions and 64 deletions

View File

@ -128,6 +128,27 @@ class TestBase(unittest.TestCase):
x = [e for e, i in augmented] # a stable sort of s x = [e for e, i in augmented] # a stable sort of s
check("stability", x, s) check("stability", x, s)
def test_small_stability(self):
from itertools import product
from operator import itemgetter
# Exhaustively test stability across all lists of small lengths
# and only a few distinct elements.
# This can provoke edge cases that randomization is unlikely to find.
# But it can grow very expensive quickly, so don't overdo it.
NELTS = 3
MAXSIZE = 9
pick0 = itemgetter(0)
for length in range(MAXSIZE + 1):
# There are NELTS ** length distinct lists.
for t in product(range(NELTS), repeat=length):
xs = list(zip(t, range(length)))
# Stability forced by index in each element.
forced = sorted(xs)
# Use key= to hide the index from compares.
native = sorted(xs, key=pick0)
self.assertEqual(forced, native)
#============================================================================== #==============================================================================
class TestBugs(unittest.TestCase): class TestBugs(unittest.TestCase):

View File

@ -0,0 +1 @@
``list.sort()`` now exploits more cases of partial ordering, particularly those with long descending runs with sub-runs of equal values. Those are recognized as single runs now (previously, each block of repeated values caused a new run to be created).

View File

@ -1618,10 +1618,11 @@ struct s_MergeState {
/* binarysort is the best method for sorting small arrays: it does /* binarysort is the best method for sorting small arrays: it does
few compares, but can do data movement quadratic in the number of few compares, but can do data movement quadratic in the number of
elements. elements.
[lo, hi) is a contiguous slice of a list, and is sorted via [lo.keys, hi) is a contiguous slice of a list of keys, and is sorted via
binary insertion. This sort is stable. binary insertion. This sort is stable.
On entry, must have lo <= start <= hi, and that [lo, start) is already On entry, must have lo.keys <= start <= hi, and that
sorted (pass start == lo if you don't know!). [lo.keys, start) is already sorted (pass start == lo.keys if you don't
know!).
If islt() complains return -1, else 0. If islt() complains return -1, else 0.
Even in case of error, the output slice will be some permutation of Even in case of error, the output slice will be some permutation of
the input (nothing is lost or duplicated). the input (nothing is lost or duplicated).
@ -1634,7 +1635,7 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
PyObject *pivot; PyObject *pivot;
assert(lo.keys <= start && start <= hi); assert(lo.keys <= start && start <= hi);
/* assert [lo, start) is sorted */ /* assert [lo.keys, start) is sorted */
if (lo.keys == start) if (lo.keys == start)
++start; ++start;
for (; start < hi; ++start) { for (; start < hi; ++start) {
@ -1643,9 +1644,9 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
r = start; r = start;
pivot = *r; pivot = *r;
/* Invariants: /* Invariants:
* pivot >= all in [lo, l). * pivot >= all in [lo.keys, l).
* pivot < all in [r, start). * pivot < all in [r, start).
* The second is vacuously true at the start. * These are vacuously true at the start.
*/ */
assert(l < r); assert(l < r);
do { do {
@ -1656,7 +1657,7 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
l = p+1; l = p+1;
} while (l < r); } while (l < r);
assert(l == r); assert(l == r);
/* The invariants still hold, so pivot >= all in [lo, l) and /* The invariants still hold, so pivot >= all in [lo.keys, l) and
pivot < all in [l, start), so pivot belongs at l. Note pivot < all in [l, start), so pivot belongs at l. Note
that if there are elements equal to pivot, l points to the that if there are elements equal to pivot, l points to the
first slot after them -- that's why this sort is stable. first slot after them -- that's why this sort is stable.
@ -1671,7 +1672,7 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
p = start + offset; p = start + offset;
pivot = *p; pivot = *p;
l += offset; l += offset;
for (p = start + offset; p > l; --p) for ( ; p > l; --p)
*p = *(p-1); *p = *(p-1);
*l = pivot; *l = pivot;
} }
@ -1682,56 +1683,115 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
return -1; return -1;
} }
static void
sortslice_reverse(sortslice *s, Py_ssize_t n)
{
reverse_slice(s->keys, &s->keys[n]);
if (s->values != NULL)
reverse_slice(s->values, &s->values[n]);
}
/* /*
Return the length of the run beginning at lo, in the slice [lo, hi). lo < hi Return the length of the run beginning at slo->keys, spanning no more than
is required on entry. "A run" is the longest ascending sequence, with nremaining elements. The run beginning there may be ascending or descending,
but the function permutes it in place, if needed, so that it's always ascending
lo[0] <= lo[1] <= lo[2] <= ... upon return.
or the longest descending sequence, with
lo[0] > lo[1] > lo[2] > ...
Boolean *descending is set to 0 in the former case, or to 1 in the latter.
For its intended use in a stable mergesort, the strictness of the defn of
"descending" is needed so that the caller can safely reverse a descending
sequence without violating stability (strict > ensures there are no equal
elements to get out of order).
Returns -1 in case of error. Returns -1 in case of error.
*/ */
static Py_ssize_t static Py_ssize_t
count_run(MergeState *ms, PyObject **lo, PyObject **hi, int *descending) count_run(MergeState *ms, sortslice *slo, Py_ssize_t nremaining)
{ {
Py_ssize_t k; Py_ssize_t k; /* used by IFLT macro expansion */
Py_ssize_t n; Py_ssize_t n;
PyObject ** const lo = slo->keys;
assert(lo < hi); /* In general, as things go on we've established that the slice starts
*descending = 0; with a monotone run of n elements, starting at lo. */
++lo;
if (lo == hi)
return 1;
n = 2; /* We're n elements into the slice, and the most recent neq+1 elments are
IFLT(*lo, *(lo-1)) { * all equal. This reverses them in-place, and resets neq for reuse.
*descending = 1; */
for (lo = lo+1; lo < hi; ++lo, ++n) { #define REVERSE_LAST_NEQ \
IFLT(*lo, *(lo-1)) if (neq) { \
; sortslice slice = *slo; \
else ++neq; \
sortslice_advance(&slice, n - neq); \
sortslice_reverse(&slice, neq); \
neq = 0; \
}
/* Sticking to only __lt__ compares is confusing and error-prone. But in
* this routine, almost all uses of IFLT can be captured by tiny macros
* giving mnemonic names to the intent. Note that inline functions don't
* work for this (IFLT expands to code including `goto fail`).
*/
#define IF_NEXT_LARGER IFLT(lo[n-1], lo[n])
#define IF_NEXT_SMALLER IFLT(lo[n], lo[n-1])
assert(nremaining);
/* try ascending run first */
for (n = 1; n < nremaining; ++n) {
IF_NEXT_SMALLER
break; break;
} }
if (n == nremaining)
return n;
/* lo[n] is strictly less */
/* If n is 1 now, then the first compare established it's a descending
* run, so fall through to the descending case. But if n > 1, there are
* n elements in an ascending run terminated by the strictly less lo[n].
* If the first key < lo[n-1], *somewhere* along the way the sequence
* increased, so we're done (there is no descending run).
* Else first key >= lo[n-1], which implies that the entire ascending run
* consists of equal elements. In that case, this is a descending run,
* and we reverse the all-equal prefix in-place.
*/
if (n > 1) {
IFLT(lo[0], lo[n-1])
return n;
sortslice_reverse(slo, n);
}
++n; /* in all cases it's been established that lo[n] has been resolved */
/* Finish descending run. All-squal subruns are reversed in-place on the
* fly. Their original order will be restored at the end by the whole-slice
* reversal.
*/
Py_ssize_t neq = 0;
for ( ; n < nremaining; ++n) {
IF_NEXT_SMALLER {
/* This ends the most recent run of equal elments, but still in
* the "descending" direction.
*/
REVERSE_LAST_NEQ
} }
else { else {
for (lo = lo+1; lo < hi; ++lo, ++n) { IF_NEXT_LARGER /* descending run is over */
IFLT(*lo, *(lo-1))
break; break;
else /* not x < y and not y < x implies x == y */
++neq;
} }
} }
REVERSE_LAST_NEQ
sortslice_reverse(slo, n); /* transform to ascending run */
/* And after reversing, it's possible this can be extended by a
* naturally increasing suffix; e.g., [3, 2, 3, 4, 1] makes an
* ascending run from the first 4 elements.
*/
for ( ; n < nremaining; ++n) {
IF_NEXT_SMALLER
break;
}
return n; return n;
fail: fail:
return -1; return -1;
#undef REVERSE_LAST_NEQ
#undef IF_NEXT_SMALLER
#undef IF_NEXT_LARGER
} }
/* /*
@ -2449,14 +2509,6 @@ merge_compute_minrun(Py_ssize_t n)
return n + r; return n + r;
} }
static void
reverse_sortslice(sortslice *s, Py_ssize_t n)
{
reverse_slice(s->keys, &s->keys[n]);
if (s->values != NULL)
reverse_slice(s->values, &s->values[n]);
}
/* Here we define custom comparison functions to optimize for the cases one commonly /* Here we define custom comparison functions to optimize for the cases one commonly
* encounters in practice: homogeneous lists, often of one of the basic types. */ * encounters in practice: homogeneous lists, often of one of the basic types. */
@ -2824,15 +2876,12 @@ list_sort_impl(PyListObject *self, PyObject *keyfunc, int reverse)
*/ */
minrun = merge_compute_minrun(nremaining); minrun = merge_compute_minrun(nremaining);
do { do {
int descending;
Py_ssize_t n; Py_ssize_t n;
/* Identify next run. */ /* Identify next run. */
n = count_run(&ms, lo.keys, lo.keys + nremaining, &descending); n = count_run(&ms, &lo, nremaining);
if (n < 0) if (n < 0)
goto fail; goto fail;
if (descending)
reverse_sortslice(&lo, n);
/* If short, extend to min(minrun, nremaining). */ /* If short, extend to min(minrun, nremaining). */
if (n < minrun) { if (n < minrun) {
const Py_ssize_t force = nremaining <= minrun ? const Py_ssize_t force = nremaining <= minrun ?

View File

@ -212,24 +212,43 @@ A detailed description of timsort follows.
Runs Runs
---- ----
count_run() returns the # of elements in the next run. A run is either count_run() returns the # of elements in the next run, and, if it's a
"ascending", which means non-decreasing: descending run, reverses it in-place. A run is either "ascending", which
means non-decreasing:
a0 <= a1 <= a2 <= ... a0 <= a1 <= a2 <= ...
or "descending", which means strictly decreasing: or "descending", which means non-increasing:
a0 > a1 > a2 > ... a0 >= a1 >= a2 >= ...
Note that a run is always at least 2 long, unless we start at the array's Note that a run is always at least 2 long, unless we start at the array's
last element. last element. If all elements in the array are equal, it can be viewed as
both ascending and descending. Upon return, the run count_run() identifies
is always ascending.
The definition of descending is strict, because the main routine reverses Reversal is done via the obvious fast "swap elements starting at each
a descending run in-place, transforming a descending run into an ascending end, and converge at the middle" method. That can violate stability if
run. Reversal is done via the obvious fast "swap elements starting at each the slice contains any equal elements. For that reason, for a long time
end, and converge at the middle" method, and that can violate stability if the code used strict inequality (">" rather than ">=") in its definition
the slice contains any equal elements. Using a strict definition of of descending.
descending ensures that a descending run contains distinct elements.
Removing that restriction required some complication: when processing a
descending run, all-equal sub-runs of elements are reversed in-place, on the
fly. Their original relative order is restored "by magic" via the final
"reverse the entire run" step.
This makes processing descending runs a little more costly. We only use
`__lt__` comparisons, so that `x == y` has to be deduced from
`not x < y and not y < x`. But so long as a run remains strictly decreasing,
only one of those compares needs to be done per loop iteration. So the primsry
extra cost is paid only when there are equal elements, and they get some
compensating benefit by not needing to end the descending run.
There's one more trick added since the original: after reversing a descending
run, it's possible that it can be extended by an adjacent ascending run. For
example, given [3, 2, 1, 3, 4, 5, 0], the 3-element descending prefix is
reversed in-place, and then extended by [3, 4, 5].
If an array is random, it's very unlikely we'll see long runs. If a natural If an array is random, it's very unlikely we'll see long runs. If a natural
run contains less than minrun elements (see next section), the main loop run contains less than minrun elements (see next section), the main loop