8
8
9
9
from __future__ import annotations
10
10
11
+ import collections
11
12
import contextlib
12
13
import enum
13
14
import os
14
15
import sys
15
16
import time
16
17
from typing import Any
17
18
from typing import Generator
19
+ from typing import Iterable
18
20
from typing import Literal
19
21
from typing import Sequence
20
22
from typing import TypedDict
@@ -66,7 +68,44 @@ def worker_title(title: str) -> None:
66
68
67
69
class Marker (enum .Enum ):
68
70
SHUTDOWN = 0
69
- QUEUE_REPLACED = 1
71
+
72
+
73
+ class TestQueue :
74
+ """A simple queue that can be inspected and modified while the lock is held."""
75
+
76
+ Item = int | Literal [Marker .SHUTDOWN ]
77
+
78
+ def __init__ (self , execmodel : execnet .gateway_base .ExecModel ):
79
+ self ._items : collections .deque [TestQueue .Item ] = collections .deque ()
80
+ self ._lock = execmodel .RLock () # type: ignore[no-untyped-call]
81
+ self ._has_items_event = execmodel .Event ()
82
+
83
+ def get (self ) -> Item :
84
+ while True :
85
+ with self .lock () as locked_items :
86
+ if locked_items :
87
+ return locked_items .popleft ()
88
+
89
+ self ._has_items_event .wait ()
90
+
91
+ def put (self , item : Item ) -> None :
92
+ with self .lock () as locked_items :
93
+ locked_items .append (item )
94
+
95
+ def replace (self , iterable : Iterable [Item ]) -> None :
96
+ with self .lock ():
97
+ self ._items = collections .deque (iterable )
98
+
99
+ @contextlib .contextmanager
100
+ def lock (self ) -> Generator [collections .deque [Item ], None , None ]:
101
+ with self ._lock :
102
+ try :
103
+ yield self ._items
104
+ finally :
105
+ if self ._items :
106
+ self ._has_items_event .set ()
107
+ else :
108
+ self ._has_items_event .clear ()
70
109
71
110
72
111
class WorkerInteractor :
@@ -77,22 +116,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
77
116
self .testrunuid = workerinput ["testrunuid" ]
78
117
self .log = Producer (f"worker-{ self .workerid } " , enabled = config .option .debug )
79
118
self .channel = channel
80
- self .torun = self ._make_queue ( )
119
+ self .torun = TestQueue ( self .channel . gateway . execmodel )
81
120
self .nextitem_index : int | None | Literal [Marker .SHUTDOWN ] = None
82
121
config .pluginmanager .register (self )
83
122
84
- def _make_queue (self ) -> Any :
85
- return self .channel .gateway .execmodel .queue .Queue ()
86
-
87
- def _get_next_item_index (self ) -> int | Literal [Marker .SHUTDOWN ]:
88
- """Gets the next item from test queue. Handles the case when the queue
89
- is replaced concurrently in another thread.
90
- """
91
- result = self .torun .get ()
92
- while result is Marker .QUEUE_REPLACED :
93
- result = self .torun .get ()
94
- return result # type: ignore[no-any-return]
95
-
96
123
def sendevent (self , name : str , ** kwargs : object ) -> None :
97
124
self .log ("sending" , name , kwargs )
98
125
self .channel .send ((name , kwargs ))
@@ -146,30 +173,25 @@ def handle_command(
146
173
self .steal (kwargs ["indices" ])
147
174
148
175
def steal (self , indices : Sequence [int ]) -> None :
149
- indices_set = set (indices )
150
- stolen = []
151
-
152
- old_queue , self .torun = self .torun , self ._make_queue ()
153
-
154
- def old_queue_get_nowait_noraise () -> int | None :
155
- with contextlib .suppress (self .channel .gateway .execmodel .queue .Empty ):
156
- return old_queue .get_nowait () # type: ignore[no-any-return]
157
- return None
158
-
159
- for i in iter (old_queue_get_nowait_noraise , None ):
160
- if i in indices_set :
161
- stolen .append (i )
176
+ with self .torun .lock () as locked_queue :
177
+ requested_set = set (indices )
178
+ stolen = list (item for item in locked_queue if item in requested_set )
179
+
180
+ # Stealing only if all requested tests are still pending
181
+ if len (stolen ) == len (requested_set ):
182
+ self .torun .replace (
183
+ item for item in locked_queue if item not in requested_set
184
+ )
162
185
else :
163
- self . torun . put ( i )
186
+ stolen = []
164
187
165
188
self .sendevent ("unscheduled" , indices = stolen )
166
- old_queue .put (Marker .QUEUE_REPLACED )
167
189
168
190
@pytest .hookimpl
169
191
def pytest_runtestloop (self , session : pytest .Session ) -> bool :
170
192
self .log ("entering main loop" )
171
193
self .channel .setcallback (self .handle_command , endmarker = Marker .SHUTDOWN )
172
- self .nextitem_index = self ._get_next_item_index ()
194
+ self .nextitem_index = self .torun . get ()
173
195
while self .nextitem_index is not Marker .SHUTDOWN :
174
196
self .run_one_test ()
175
197
if session .shouldfail or session .shouldstop :
@@ -179,7 +201,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
179
201
def run_one_test (self ) -> None :
180
202
assert isinstance (self .nextitem_index , int )
181
203
self .item_index = self .nextitem_index
182
- self .nextitem_index = self ._get_next_item_index ()
204
+ self .nextitem_index = self .torun . get ()
183
205
184
206
items = self .session .items
185
207
item = items [self .item_index ]
0 commit comments