Skip to content

Commit 0533c1f

Browse files
authored
gh-123471: Make itertools.chain thread-safe (#135689)
1 parent 536a5ff commit 0533c1f

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

Lib/test/test_free_threading/test_itertools.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
from threading import Thread, Barrier
3-
from itertools import batched, cycle
3+
from itertools import batched, chain, cycle
44
from test.support import threading_helper
55

66

@@ -17,7 +17,7 @@ def work(it):
1717
barrier.wait()
1818
while True:
1919
try:
20-
_ = next(it)
20+
next(it)
2121
except StopIteration:
2222
break
2323

@@ -62,6 +62,34 @@ def work(it):
6262

6363
barrier.reset()
6464

65+
@threading_helper.reap_threads
66+
def test_chain(self):
67+
number_of_threads = 6
68+
number_of_iterations = 20
69+
70+
barrier = Barrier(number_of_threads)
71+
def work(it):
72+
barrier.wait()
73+
while True:
74+
try:
75+
next(it)
76+
except StopIteration:
77+
break
78+
79+
data = [(1, )] * 200
80+
for it in range(number_of_iterations):
81+
chain_iterator = chain(*data)
82+
worker_threads = []
83+
for ii in range(number_of_threads):
84+
worker_threads.append(
85+
Thread(target=work, args=[chain_iterator]))
86+
87+
with threading_helper.start_threads(worker_threads):
88+
pass
89+
90+
barrier.reset()
91+
92+
6593

6694
if __name__ == "__main__":
6795
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make concurrent iterations over :class:`itertools.chain` safe under :term:`free threading`.

Modules/itertoolsmodule.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,8 +1880,8 @@ chain_traverse(PyObject *op, visitproc visit, void *arg)
18801880
return 0;
18811881
}
18821882

1883-
static PyObject *
1884-
chain_next(PyObject *op)
1883+
static inline PyObject *
1884+
chain_next_lock_held(PyObject *op)
18851885
{
18861886
chainobject *lz = chainobject_CAST(op);
18871887
PyObject *item;
@@ -1919,6 +1919,16 @@ chain_next(PyObject *op)
19191919
return NULL;
19201920
}
19211921

1922+
static PyObject *
1923+
chain_next(PyObject *op)
1924+
{
1925+
PyObject *result;
1926+
Py_BEGIN_CRITICAL_SECTION(op);
1927+
result = chain_next_lock_held(op);
1928+
Py_END_CRITICAL_SECTION()
1929+
return result;
1930+
}
1931+
19221932
PyDoc_STRVAR(chain_doc,
19231933
"chain(*iterables)\n\
19241934
--\n\

0 commit comments

Comments
 (0)