Skip to content

Commit aa1324c

Browse files
committed
Make 'steal' command atomic
Either unschedule all requested tests, or none if it's not possible - if some of the requested tests have already been processed by the time the request arrives. It may happen if the worker runs tests faster than the controller receives and processes status updates. But in this case maybe it's just better to let the worker keep running. This is a prerequisite for group/scope support in worksteal scheduler - so they won't be broken up incorrectly. This change could break schedulers that use "steal" command. However: 1) worksteal scheduler doesn't need any adjustments. 2) I'm not aware of any external schedulers relying on this command yet. So I think it's better to keep the protocol simple, not complicate it for imaginary compatibility with some unknown and likely non-existent schedulers.
1 parent 9c24f0f commit aa1324c

File tree

3 files changed

+58
-31
lines changed

3 files changed

+58
-31
lines changed

changelog/1144.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make "steal" command atomic - make it unschedule either all requested tests or none.

src/xdist/remote.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88

99
from __future__ import annotations
1010

11+
import collections
1112
import contextlib
1213
import enum
1314
import os
1415
import sys
1516
import time
1617
from typing import Any
1718
from typing import Generator
19+
from typing import Iterable
1820
from typing import Literal
1921
from typing import Sequence
2022
from typing import TypedDict
@@ -66,7 +68,44 @@ def worker_title(title: str) -> None:
6668

6769
class Marker(enum.Enum):
6870
SHUTDOWN = 0
69-
QUEUE_REPLACED = 1
71+
72+
73+
class TestQueue:
74+
"""A simple queue that can be inspected and modified while the lock is held."""
75+
76+
Item = int | Literal[Marker.SHUTDOWN]
77+
78+
def __init__(self, execmodel: execnet.gateway_base.ExecModel):
79+
self._items: collections.deque[TestQueue.Item] = collections.deque()
80+
self._lock = execmodel.RLock() # type: ignore[no-untyped-call]
81+
self._has_items_event = execmodel.Event()
82+
83+
def get(self) -> Item:
84+
while True:
85+
with self.lock() as locked_items:
86+
if locked_items:
87+
return locked_items.popleft()
88+
89+
self._has_items_event.wait()
90+
91+
def put(self, item: Item) -> None:
92+
with self.lock() as locked_items:
93+
locked_items.append(item)
94+
95+
def replace(self, iterable: Iterable[Item]) -> None:
96+
with self.lock():
97+
self._items = collections.deque(iterable)
98+
99+
@contextlib.contextmanager
100+
def lock(self) -> Generator[collections.deque[Item], None, None]:
101+
with self._lock:
102+
try:
103+
yield self._items
104+
finally:
105+
if self._items:
106+
self._has_items_event.set()
107+
else:
108+
self._has_items_event.clear()
70109

71110

72111
class WorkerInteractor:
@@ -77,22 +116,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
77116
self.testrunuid = workerinput["testrunuid"]
78117
self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug)
79118
self.channel = channel
80-
self.torun = self._make_queue()
119+
self.torun = TestQueue(self.channel.gateway.execmodel)
81120
self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None
82121
config.pluginmanager.register(self)
83122

84-
def _make_queue(self) -> Any:
85-
return self.channel.gateway.execmodel.queue.Queue()
86-
87-
def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]:
88-
"""Gets the next item from test queue. Handles the case when the queue
89-
is replaced concurrently in another thread.
90-
"""
91-
result = self.torun.get()
92-
while result is Marker.QUEUE_REPLACED:
93-
result = self.torun.get()
94-
return result # type: ignore[no-any-return]
95-
96123
def sendevent(self, name: str, **kwargs: object) -> None:
97124
self.log("sending", name, kwargs)
98125
self.channel.send((name, kwargs))
@@ -146,30 +173,25 @@ def handle_command(
146173
self.steal(kwargs["indices"])
147174

148175
def steal(self, indices: Sequence[int]) -> None:
149-
indices_set = set(indices)
150-
stolen = []
151-
152-
old_queue, self.torun = self.torun, self._make_queue()
153-
154-
def old_queue_get_nowait_noraise() -> int | None:
155-
with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty):
156-
return old_queue.get_nowait() # type: ignore[no-any-return]
157-
return None
158-
159-
for i in iter(old_queue_get_nowait_noraise, None):
160-
if i in indices_set:
161-
stolen.append(i)
176+
with self.torun.lock() as locked_queue:
177+
requested_set = set(indices)
178+
stolen = list(item for item in locked_queue if item in requested_set)
179+
180+
# Stealing only if all requested tests are still pending
181+
if len(stolen) == len(requested_set):
182+
self.torun.replace(
183+
item for item in locked_queue if item not in requested_set
184+
)
162185
else:
163-
self.torun.put(i)
186+
stolen = []
164187

165188
self.sendevent("unscheduled", indices=stolen)
166-
old_queue.put(Marker.QUEUE_REPLACED)
167189

168190
@pytest.hookimpl
169191
def pytest_runtestloop(self, session: pytest.Session) -> bool:
170192
self.log("entering main loop")
171193
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
172-
self.nextitem_index = self._get_next_item_index()
194+
self.nextitem_index = self.torun.get()
173195
while self.nextitem_index is not Marker.SHUTDOWN:
174196
self.run_one_test()
175197
if session.shouldfail or session.shouldstop:
@@ -179,7 +201,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
179201
def run_one_test(self) -> None:
180202
assert isinstance(self.nextitem_index, int)
181203
self.item_index = self.nextitem_index
182-
self.nextitem_index = self._get_next_item_index()
204+
self.nextitem_index = self.torun.get()
183205

184206
items = self.session.items
185207
item = items[self.item_index]

testing/test_remote.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,10 @@ def test_func4(): pass
267267

268268
worker.sendcommand("steal", indices=[1, 2])
269269
ev = worker.popevent("unscheduled")
270+
assert ev.kwargs["indices"] == []
271+
272+
worker.sendcommand("steal", indices=[2])
273+
ev = worker.popevent("unscheduled")
270274
assert ev.kwargs["indices"] == [2]
271275

272276
reports = [

0 commit comments

Comments
 (0)