@@ -1,16 +1,15 @@
-import logging
 import sys
 import tqdm
-import dspy
 import signal
+import logging
 import threading
 import traceback
 import contextlib
 
+from contextvars import copy_context
 from tqdm.contrib.logging import logging_redirect_tqdm
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-
 logger = logging.getLogger(__name__)
 
@@ -23,6 +22,8 @@ def __init__(
         provide_traceback=False,
         compare_results=False,
     ):
+        """Offers isolation between tasks (dspy.settings) regardless of whether num_threads == 1 or > 1."""
+
         self.num_threads = num_threads
         self.disable_progress_bar = disable_progress_bar
         self.max_errors = max_errors
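
For context on the isolation claim in the new docstring: `dspy.settings` is presumably backed by `contextvars`, so running each task under its own `copy_context()` gives it a private snapshot of the caller's state. A minimal sketch of that behavior (`current_lm` is a hypothetical stand-in, not dspy's actual settings object):

```python
from contextvars import ContextVar, copy_context

# Hypothetical stand-in for the kind of per-task state dspy.settings manages.
current_lm = ContextVar("current_lm", default="default-lm")

def task(name):
    current_lm.set(name)  # this write lands in the copied context only
    return current_lm.get()

ctx_a, ctx_b = copy_context(), copy_context()
print(ctx_a.run(task, "lm-a"))  # lm-a
print(ctx_b.run(task, "lm-b"))  # lm-b
print(current_lm.get())         # default-lm -- the caller's context is untouched
```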
@@ -33,34 +34,18 @@ def __init__(
         self.error_lock = threading.Lock()
         self.cancel_jobs = threading.Event()
 
-
     def execute(self, function, data):
         wrapped_function = self._wrap_function(function)
         if self.num_threads == 1:
-            return self._execute_single_thread(wrapped_function, data)
+            return self._execute_isolated_single_thread(wrapped_function, data)
         else:
             return self._execute_multi_thread(wrapped_function, data)
 
-
     def _wrap_function(self, function):
-        # Wrap the function with threading context and error handling
-        def wrapped(item, parent_id=None):
-            thread_stacks = dspy.settings.stack_by_thread
-            current_thread_id = threading.get_ident()
-            creating_new_thread = current_thread_id not in thread_stacks
-
-            assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid
-
-            if creating_new_thread:
-                # If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
-                if parent_id and parent_id in thread_stacks:
-                    thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
-                else:
-                    thread_stacks[current_thread_id] = list(dspy.settings.main_stack)
-
-                # TODO: Consider the behavior below.
-                # import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))
-
+        # Wrap the function with error handling
+        def wrapped(item):
+            if self.cancel_jobs.is_set():
+                return None
             try:
                 return function(item)
             except Exception as e:
@@ -79,45 +64,53 @@ def wrapped(item, parent_id=None):
                     f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
                 )
             return None
-            finally:
-                if creating_new_thread:
-                    del thread_stacks[threading.get_ident()]
         return wrapped
 
-
-    def _execute_single_thread(self, function, data):
+    def _execute_isolated_single_thread(self, function, data):
         results = []
         pbar = tqdm.tqdm(
             total=len(data),
             dynamic_ncols=True,
             disable=self.disable_progress_bar,
-            file=sys.stdout,
+            file=sys.stdout
         )
+
         for item in data:
             with logging_redirect_tqdm():
                 if self.cancel_jobs.is_set():
                     break
-                result = function(item)
+
+                # Create an isolated context for each task
+                task_ctx = copy_context()
+                result = task_ctx.run(function, item)
                 results.append(result)
+
                 if self.compare_results:
                     # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in data if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in data if r is not None]),
+                    )
                 else:
                     self._update_progress(pbar, len(results), len(data))
+
         pbar.close()
+
         if self.cancel_jobs.is_set():
             logger.warning("Execution was cancelled due to errors.")
             raise Exception("Execution was cancelled due to errors.")
-        return results
 
+        return results
 
     def _update_progress(self, pbar, nresults, ntotal):
         if self.compare_results:
-            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({round(100 * nresults / ntotal, 1) if ntotal > 0 else 0}%)")
+            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
+            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
        else:
             pbar.set_description(f"Processed {nresults} / {ntotal} examples")
-        pbar.update()
 
+        pbar.update()
 
     def _execute_multi_thread(self, function, data):
         results = [None] * len(data)  # Pre-allocate results list to maintain order
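
A note on why `copy_context()` is called inside the loop rather than once before it: changes a function makes through `Context.run` persist in that context object, so a single shared copy would let state accumulate across items. A fresh copy per task means every run starts from the caller's snapshot. A quick illustration (`counter` is a hypothetical variable):

```python
from contextvars import ContextVar, copy_context

counter = ContextVar("counter", default=0)  # hypothetical variable

def bump():
    counter.set(counter.get() + 1)
    return counter.get()

shared = copy_context()
print(shared.run(bump), shared.run(bump))  # 1 2 -- state accumulates in a reused copy

print(copy_context().run(bump), copy_context().run(bump))  # 1 1 -- fresh copy per task
```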
@@ -132,6 +125,7 @@ def interrupt_handler_manager():
             def interrupt_handler(sig, frame):
                 self.cancel_jobs.set()
                 logger.warning("Received SIGINT. Cancelling execution.")
+                # Re-raise the signal to allow default behavior
                 default_handler(sig, frame)
 
             signal.signal(signal.SIGINT, interrupt_handler)
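
The chained call works because `default_handler` was presumably captured earlier with `signal.getsignal(signal.SIGINT)`, which for the default SIGINT disposition returns the callable `signal.default_int_handler` (it raises `KeyboardInterrupt`). A minimal sketch of the install/chain/restore pattern, assuming the saved handler is a Python callable rather than `SIG_DFL`/`SIG_IGN`:

```python
import signal
import logging
import threading
import contextlib

logger = logging.getLogger(__name__)

@contextlib.contextmanager
def interrupt_handler_manager(cancel_event):
    # Only the main thread may install signal handlers.
    if threading.current_thread() is threading.main_thread():
        default_handler = signal.getsignal(signal.SIGINT)

        def interrupt_handler(sig, frame):
            cancel_event.set()
            logger.warning("Received SIGINT. Cancelling execution.")
            # Chain to the saved handler so Ctrl-C still raises KeyboardInterrupt.
            default_handler(sig, frame)

        signal.signal(signal.SIGINT, interrupt_handler)
        try:
            yield
        finally:
            # Restore the original disposition on exit.
            signal.signal(signal.SIGINT, default_handler)
    else:
        # If not in the main thread, skip setting signal handlers.
        yield
```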
@@ -143,37 +137,53 @@ def interrupt_handler(sig, frame):
                 # If not in the main thread, skip setting signal handlers
                 yield
 
-        def cancellable_function(index_item, parent_id=None):
+        def cancellable_function(index_item):
             index, item = index_item
             if self.cancel_jobs.is_set():
                 return index, job_cancelled
-            return index, function(item, parent_id)
-
-        parent_id = threading.get_ident() if threading.current_thread() is not threading.main_thread() else None
+            return index, function(item)
 
         with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
-            futures = {executor.submit(cancellable_function, pair, parent_id): pair for pair in enumerate(data)}
+            futures = {}
+            for pair in enumerate(data):
+                # Capture the context for each task
+                task_ctx = copy_context()
+                future = executor.submit(task_ctx.run, cancellable_function, pair)
+                futures[future] = pair
+
             pbar = tqdm.tqdm(
                 total=len(data),
                 dynamic_ncols=True,
                 disable=self.disable_progress_bar,
-                file=sys.stdout,
+                file=sys.stdout
             )
 
             for future in as_completed(futures):
                 index, result = future.result()
-
+
                 if result is job_cancelled:
                     continue
+
                 results[index] = result
 
                 if self.compare_results:
                     # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in results if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in results if r is not None]),
+                    )
                 else:
-                    self._update_progress(pbar, len([r for r in results if r is not None]), len(data))
+                    self._update_progress(
+                        pbar,
+                        len([r for r in results if r is not None]),
+                        len(data),
+                    )
+
             pbar.close()
+
         if self.cancel_jobs.is_set():
             logger.warning("Execution was cancelled due to errors.")
             raise Exception("Execution was cancelled due to errors.")
+
         return results
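
For reference, `executor.submit(task_ctx.run, cancellable_function, pair)` works because a `Context` object's `run` method is an ordinary callable that takes the target function and its arguments positionally, so the pool thread executes the task inside the context copied from the submitting thread. A self-contained sketch of the same submit-in-context, index-to-restore-order pattern (`work`, `tag`, and the sample data are hypothetical):

```python
from contextvars import ContextVar, copy_context
from concurrent.futures import ThreadPoolExecutor, as_completed

tag = ContextVar("tag", default="main")  # hypothetical per-task variable

def work(index_item):
    index, item = index_item
    tag.set(f"task-{index}")  # visible only inside this task's context copy
    return index, f"{item}:{tag.get()}"

data = ["a", "b", "c"]
results = [None] * len(data)  # pre-allocate so completion order doesn't matter

with ThreadPoolExecutor(max_workers=2) as executor:
    # Context.run is submitted as the callable; it invokes work(pair)
    # inside a context copied from this (submitting) thread.
    futures = [executor.submit(copy_context().run, work, pair) for pair in enumerate(data)]
    for future in as_completed(futures):
        index, value = future.result()
        results[index] = value

print(results)  # ['a:task-0', 'b:task-1', 'c:task-2']
```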