Skip to content

Make 'steal' command atomic #1144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog/1144.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The internal `steal` command is now atomic: it unschedules either all of the requested tests or none of them.

This is a prerequisite for group/scope support in the `worksteal` scheduler, so test groups won't be broken up incorrectly.
88 changes: 60 additions & 28 deletions src/xdist/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@

from __future__ import annotations

import collections
import contextlib
import enum
import os
import sys
import time
from typing import Any
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import TypedDict
from typing import Union
import warnings

from _pytest.config import _prepareconfig
Expand Down Expand Up @@ -66,7 +69,44 @@ def worker_title(title: str) -> None:

class Marker(enum.Enum):
    """Sentinel values that travel through the worker's test queue."""

    # Passed as the channel's end marker; the run loop stops when it is
    # the next queue entry.
    SHUTDOWN = 0
    # Put on a queue that has been swapped out, so that a reader still
    # waiting on it knows to fetch from the queue again.
    QUEUE_REPLACED = 1


class TestQueue:
    """A simple FIFO queue that can be inspected and modified atomically.

    ``get()`` blocks until an item is available. ``lock()`` hands out the
    underlying deque, so a caller can examine and change several items as
    a single step while the queue's lock is held.
    """

    # What the queue may hold: a test index or the shutdown marker.
    Item = Union[int, Literal[Marker.SHUTDOWN]]

    def __init__(self, execmodel: execnet.gateway_base.ExecModel):
        """Create an empty queue using the primitives of ``execmodel``.

        :param execmodel: provides the ``RLock`` and ``Event`` factories.
        """
        self._items: collections.deque[TestQueue.Item] = collections.deque()
        # ``put()`` and ``replace()`` are expected to work while the same
        # thread is already inside ``lock()``, hence the RLock.
        self._lock = execmodel.RLock()  # type: ignore[no-untyped-call]
        # Mirrors "queue is non-empty"; refreshed each time ``lock()`` exits.
        self._has_items_event = execmodel.Event()

    def get(self) -> Item:
        """Remove and return the oldest item, waiting while the queue is empty."""
        while True:
            with self.lock() as locked_items:
                if locked_items:
                    return locked_items.popleft()

            # Nothing available: wait until some ``lock()`` exit reports
            # items, then re-check under the lock.
            self._has_items_event.wait()

    def put(self, item: Item) -> None:
        """Append ``item`` to the end of the queue."""
        with self.lock() as locked_items:
            locked_items.append(item)

    def replace(self, iterable: Iterable[Item]) -> None:
        """Replace the whole content of the queue with the items of ``iterable``.

        The deque is updated in place, so a deque obtained from an enclosing
        ``lock()`` block keeps referring to the live queue afterwards.
        """
        with self.lock() as locked_items:
            # Snapshot first: ``iterable`` may be a generator that reads
            # from the very deque that is about to be cleared.
            new_items = list(iterable)
            locked_items.clear()
            locked_items.extend(new_items)

    @contextlib.contextmanager
    def lock(self) -> Generator[collections.deque[Item], None, None]:
        """Hold the queue lock and yield the underlying deque.

        On exit the "has items" event is set or cleared to match the
        content of the queue, which wakes up or parks ``get()`` callers.
        """
        with self._lock:
            try:
                yield self._items
            finally:
                if self._items:
                    self._has_items_event.set()
                else:
                    self._has_items_event.clear()


class WorkerInteractor:
Expand All @@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
self.testrunuid = workerinput["testrunuid"]
self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug)
self.channel = channel
self.torun = self._make_queue()
self.torun = TestQueue(self.channel.gateway.execmodel)
self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None
config.pluginmanager.register(self)

def _make_queue(self) -> Any:
return self.channel.gateway.execmodel.queue.Queue()

def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]:
    """Return the next entry of the test queue.

    ``Marker.QUEUE_REPLACED`` entries are skipped: they show up when the
    queue was replaced concurrently in another thread, in which case the
    entry is fetched again from ``self.torun``.
    """
    while True:
        candidate = self.torun.get()
        if candidate is not Marker.QUEUE_REPLACED:
            return candidate  # type: ignore[no-any-return]

def sendevent(self, name: str, **kwargs: object) -> None:
    """Log an event and send it over the channel as a ``(name, kwargs)`` tuple.

    :param name: name of the event.
    :param kwargs: payload of the event.
    """
    event = (name, kwargs)
    self.log("sending", *event)
    self.channel.send(event)
Expand Down Expand Up @@ -146,30 +174,34 @@ def handle_command(
self.steal(kwargs["indices"])

def steal(self, indices: Sequence[int]) -> None:
    """
    Remove tests from the queue.

    Removes either all requested tests, or none, if some of these tests
    are not in the queue (for example, if they were processed already).
    The outcome is reported with an ``unscheduled`` event that lists the
    removed indices (an empty list when nothing was removed).

    :param indices: indices of the tests to remove.
    """
    requested_set = set(indices)

    # Check and rewrite the queue under its lock, so that looking for the
    # requested tests and removing them happen as one step.
    with self.torun.lock() as locked_queue:
        stolen = [item for item in locked_queue if item in requested_set]

        # Stealing only if all requested tests are still pending. Compare
        # as sets: a repeated queue entry must not stand in for a missing one.
        if set(stolen) == requested_set:
            self.torun.replace(
                [item for item in locked_queue if item not in requested_set]
            )
        else:
            stolen = []

    self.sendevent("unscheduled", indices=stolen)

@pytest.hookimpl
def pytest_runtestloop(self, session: pytest.Session) -> bool:
self.log("entering main loop")
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
self.nextitem_index = self._get_next_item_index()
self.nextitem_index = self.torun.get()
while self.nextitem_index is not Marker.SHUTDOWN:
self.run_one_test()
if session.shouldfail or session.shouldstop:
Expand All @@ -179,7 +211,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
def run_one_test(self) -> None:
assert isinstance(self.nextitem_index, int)
self.item_index = self.nextitem_index
self.nextitem_index = self._get_next_item_index()
self.nextitem_index = self.torun.get()

items = self.session.items
item = items[self.item_index]
Expand Down
6 changes: 6 additions & 0 deletions testing/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,12 @@ def test_func4(): pass

worker.sendcommand("steal", indices=[1, 2])
ev = worker.popevent("unscheduled")
# Cannot steal index 1 because it is completed already, so do not steal any.
assert ev.kwargs["indices"] == []

# Index 2 can be stolen, as it is still pending.
worker.sendcommand("steal", indices=[2])
ev = worker.popevent("unscheduled")
assert ev.kwargs["indices"] == [2]

reports = [
Expand Down