[3.13] gh-117657: Fix itertools.count thread safety (GH-119268) (#120007)

Fix itertools.count in free-threading mode
(cherry picked from commit 87939bd579)

Co-authored-by: Arnon Yaari <wiggin15@yahoo.com>
This commit is contained in:
Sam Gross 2024-06-03 18:47:34 -04:00 committed by GitHub
parent ae705319fc
commit 79fae3b0a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 11 deletions

View File

@ -644,7 +644,7 @@ class TestBasicOps(unittest.TestCase):
count(1, maxsize+5); sys.exc_info() count(1, maxsize+5); sys.exc_info()
@pickle_deprecated @pickle_deprecated
def test_count_with_stride(self): def test_count_with_step(self):
self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)]) self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
self.assertEqual(lzip('abc',count(start=2,step=3)), self.assertEqual(lzip('abc',count(start=2,step=3)),
[('a', 2), ('b', 5), ('c', 8)]) [('a', 2), ('b', 5), ('c', 8)])
@ -699,6 +699,28 @@ class TestBasicOps(unittest.TestCase):
for proto in range(pickle.HIGHEST_PROTOCOL + 1): for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, count(i, j)) self.pickletest(proto, count(i, j))
@threading_helper.requires_working_threading()
def test_count_threading(self, step=1):
# this test verifies multithreading consistency, which is
# mostly for testing builds without GIL, but nice to test anyway
count_to = 10_000
num_threads = 10
c = count(step=step)
def counting_thread():
for i in range(count_to):
next(c)
threads = []
for i in range(num_threads):
thread = threading.Thread(target=counting_thread)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
self.assertEqual(next(c), count_to * num_threads * step)
def test_count_with_step_threading(self):
self.test_count_threading(step=5)
def test_cycle(self): def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), []) self.assertEqual(list(cycle('')), [])

View File

@ -1,13 +1,14 @@
#include "Python.h" #include "Python.h"
#include "pycore_call.h" // _PyObject_CallNoArgs() #include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin() #include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_long.h" // _PyLong_GetZero() #include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
#include "pycore_moduleobject.h" // _PyModule_GetState() #include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_typeobject.h" // _PyType_GetModuleState() #include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_object.h" // _PyObject_GC_TRACK() #include "pycore_typeobject.h" // _PyType_GetModuleState()
#include "pycore_tuple.h" // _PyTuple_ITEMS() #include "pycore_object.h" // _PyObject_GC_TRACK()
#include "pycore_tuple.h" // _PyTuple_ITEMS()
#include <stddef.h> // offsetof() #include <stddef.h> // offsetof()
/* Itertools module written and maintained /* Itertools module written and maintained
by Raymond D. Hettinger <python@rcn.com> by Raymond D. Hettinger <python@rcn.com>
@ -4037,7 +4038,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1)); assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
Advances with: cnt += 1 Advances with: cnt += 1
When count hits Y_SSIZE_T_MAX, switch to slow_mode. When count hits PY_SSIZE_T_MAX, switch to slow_mode.
slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float. slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
@ -4186,9 +4187,30 @@ count_nextlong(countobject *lz)
static PyObject * static PyObject *
count_next(countobject *lz) count_next(countobject *lz)
{ {
#ifndef Py_GIL_DISABLED
if (lz->cnt == PY_SSIZE_T_MAX) if (lz->cnt == PY_SSIZE_T_MAX)
return count_nextlong(lz); return count_nextlong(lz);
return PyLong_FromSsize_t(lz->cnt++); return PyLong_FromSsize_t(lz->cnt++);
#else
// free-threading version
// fast mode uses compare-exchange loop
// slow mode uses a critical section
PyObject *returned;
Py_ssize_t cnt;
cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
for (;;) {
if (cnt == PY_SSIZE_T_MAX) {
Py_BEGIN_CRITICAL_SECTION(lz);
returned = count_nextlong(lz);
Py_END_CRITICAL_SECTION();
return returned;
}
if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
return PyLong_FromSsize_t(cnt);
}
}
#endif
} }
static PyObject * static PyObject *

View File

@ -47,7 +47,6 @@ race_top:_PyImport_AcquireLock
race_top:_Py_dict_lookup_threadsafe race_top:_Py_dict_lookup_threadsafe
race_top:_imp_release_lock race_top:_imp_release_lock
race_top:_multiprocessing_SemLock_acquire_impl race_top:_multiprocessing_SemLock_acquire_impl
race_top:count_next
race_top:dictiter_new race_top:dictiter_new
race_top:dictresize race_top:dictresize
race_top:insert_to_emptydict race_top:insert_to_emptydict