Commit 5c2ee6a

Convert dspy.settings to a ContextVar, improve ParallelExecutor (isolate even if 1 thread), and permit user-launched threads (#1852)
* Convert dspy.settings to a ContextVar, improve ParallelExecutor (isolate even if 1 thread), and permit user-launched threads
* Fixes
1 parent 0eb1e04 commit 5c2ee6a
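In practice, the new semantics are: configure() in the main thread updates the global base config, while user-launched threads see that base config and can layer their own ContextVar overrides on top without affecting anyone else. A minimal usage sketch of that behavior (the lm string values below are illustrative placeholders, not real LM objects):

    import threading
    import dspy

    dspy.settings.configure(lm="base-lm")  # main thread: updates the global base config

    def user_thread():
        # A user-launched thread inherits the main thread's config...
        assert dspy.settings.lm == "base-lm"
        # ...and can override it locally without leaking into other threads.
        with dspy.settings.context(lm="thread-local-lm"):
            assert dspy.settings.lm == "thread-local-lm"
        assert dspy.settings.lm == "base-lm"

    t = threading.Thread(target=user_thread)
    t.start()
    t.join()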

File tree

3 files changed: +124 -118 lines changed

- dsp/utils/settings.py
- dspy/utils/asyncify.py
- dspy/utils/parallelizer.py

dsp/utils/settings.py

Lines changed: 67 additions & 56 deletions
@@ -1,7 +1,8 @@
+import copy
 import threading
-from contextlib import contextmanager
-from copy import deepcopy
 
+from contextlib import contextmanager
+from contextvars import ContextVar
 from dsp.utils.utils import dotdict
 
 DEFAULT_CONFIG = dotdict(
@@ -27,85 +28,95 @@
     async_max_workers=8,
 )
 
+# Global base configuration
+main_thread_config = copy.deepcopy(DEFAULT_CONFIG)
+
+# Initialize the context variable with an empty dict as default
+dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict())
+
 
 class Settings:
-    """DSP configuration settings."""
+    """
+    A singleton class for DSPy configuration settings.
+
+    This is thread-safe. User threads are supported both through ParallelExecutor and native threading.
+    - If native threading is used, the thread inherits the initial config from the main thread.
+    - If ParallelExecutor is used, the thread inherits the initial config from its parent thread.
+    """
 
     _instance = None
 
     def __new__(cls):
-        """
-        Singleton Pattern. See https://python-patterns.guide/gang-of-four/singleton/
-        """
-
         if cls._instance is None:
             cls._instance = super().__new__(cls)
-            cls._instance.lock = threading.Lock()
-            cls._instance.main_tid = threading.get_ident()
-            cls._instance.main_stack = []
-            cls._instance.stack_by_thread = {}
-            cls._instance.stack_by_thread[threading.get_ident()] = cls._instance.main_stack
+            cls._instance.lock = threading.Lock()  # maintained here for assertions
+        return cls._instance
 
-            # TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
-            # eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
-            # downstream operations like dsp.retrieve would use configs from the defined pipeline.
+    def __getattr__(self, name):
+        overrides = dspy_ctx_overrides.get()
+        if name in overrides:
+            return overrides[name]
+        elif name in main_thread_config:
+            return main_thread_config[name]
+        else:
+            raise AttributeError(f"'Settings' object has no attribute '{name}'")
 
-            # make a deepcopy of the default config to avoid modifying the default config
-            cls._instance.__append(deepcopy(DEFAULT_CONFIG))
+    def __setattr__(self, name, value):
+        if name in ('_instance',):
+            super().__setattr__(name, value)
+        else:
+            self.configure(**{name: value})
 
-        return cls._instance
+    # Dictionary-like access
 
-    @property
-    def config(self):
-        thread_id = threading.get_ident()
-        if thread_id not in self.stack_by_thread:
-            self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
-        return self.stack_by_thread[thread_id][-1]
+    def __getitem__(self, key):
+        return self.__getattr__(key)
 
-    def __getattr__(self, name):
-        if hasattr(self.config, name):
-            return getattr(self.config, name)
+    def __setitem__(self, key, value):
+        self.__setattr__(key, value)
 
-        if name in self.config:
-            return self.config[name]
+    def __contains__(self, key):
+        overrides = dspy_ctx_overrides.get()
+        return key in overrides or key in main_thread_config
 
-        super().__getattr__(name)
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except AttributeError:
+            return default
 
-    def __append(self, config):
-        thread_id = threading.get_ident()
-        if thread_id not in self.stack_by_thread:
-            self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
-        self.stack_by_thread[thread_id].append(config)
+    def copy(self):
+        overrides = dspy_ctx_overrides.get()
+        return dotdict({**main_thread_config, **overrides})
 
-    def __pop(self):
-        thread_id = threading.get_ident()
-        if thread_id in self.stack_by_thread:
-            self.stack_by_thread[thread_id].pop()
+    # Configuration methods
 
-    def configure(self, inherit_config: bool = True, **kwargs):
-        """Set configuration settings.
+    def configure(self, return_token=False, **kwargs):
+        global main_thread_config
+        overrides = dspy_ctx_overrides.get()
+        new_overrides = dotdict({**main_thread_config, **overrides, **kwargs})
+        token = dspy_ctx_overrides.set(new_overrides)
 
-        Args:
-            inherit_config (bool, optional): Set configurations for the given, and use existing configurations for the rest. Defaults to True.
-        """
-        if inherit_config:
-            config = {**self.config, **kwargs}
-        else:
-            config = {**kwargs}
+        # Update main_thread_config, in the main thread only
+        if threading.current_thread() is threading.main_thread():
+            main_thread_config = new_overrides
 
-        self.__append(config)
+        if return_token:
+            return token
 
     @contextmanager
-    def context(self, inherit_config=True, **kwargs):
-        self.configure(inherit_config=inherit_config, **kwargs)
-
+    def context(self, **kwargs):
+        """Context manager for temporary configuration changes."""
+        token = self.configure(return_token=True, **kwargs)
         try:
             yield
         finally:
-            self.__pop()
+            dspy_ctx_overrides.reset(token)
 
-    def __repr__(self) -> str:
-        return repr(self.config)
+    def __repr__(self):
+        overrides = dspy_ctx_overrides.get()
+        combined_config = {**main_thread_config, **overrides}
+        return repr(combined_config)
 
 
-settings = Settings()
+settings = Settings()
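The mechanism above boils down to a module-level base dict plus a ContextVar of overrides: configure() publishes a merged dict with set(), and context() undoes it with the returned token. Here is a self-contained sketch of just that mechanism, with the singleton and dotdict machinery stripped out (the names base_config, overrides_var, and the placeholder keys are ours, not the commit's):

    import copy
    import threading
    from contextlib import contextmanager
    from contextvars import ContextVar

    DEFAULTS = {"lm": None, "trace": []}
    base_config = copy.deepcopy(DEFAULTS)
    overrides_var = ContextVar("overrides_var", default={})

    def configure(return_token=False, **kwargs):
        global base_config
        # Merge the global base, the current context's overrides, and the new kwargs.
        new_overrides = {**base_config, **overrides_var.get(), **kwargs}
        token = overrides_var.set(new_overrides)
        # Only the main thread mutates the global base, so user-launched threads
        # inherit the main thread's config rather than a sibling's.
        if threading.current_thread() is threading.main_thread():
            base_config = new_overrides
        if return_token:
            return token

    @contextmanager
    def context(**kwargs):
        token = configure(return_token=True, **kwargs)
        try:
            yield
        finally:
            overrides_var.reset(token)  # restore the previous overrides

    configure(lm="gpt-base")
    with context(lm="gpt-temporary"):
        print(overrides_var.get()["lm"])  # gpt-temporary inside the block
    print(overrides_var.get()["lm"])      # gpt-base again

A fresh thread starts with an empty context, so overrides_var.get() falls back to its default there; that gap is exactly why the class keeps main_thread_config as a second lookup tier in __getattr__.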

dspy/utils/asyncify.py

Lines changed: 3 additions & 18 deletions
@@ -24,22 +24,7 @@ def get_limiter():
 
 
 def asyncify(program):
-    import dspy
     import threading
-
-    assert threading.get_ident() == dspy.settings.main_tid, "asyncify can only be called from the main thread"
-
-    def wrapped(*args, **kwargs):
-        thread_stacks = dspy.settings.stack_by_thread
-        current_thread_id = threading.get_ident()
-        creating_new_thread = current_thread_id not in thread_stacks
-
-        assert creating_new_thread
-        thread_stacks[current_thread_id] = list(dspy.settings.main_stack)
-
-        try:
-            return program(*args, **kwargs)
-        finally:
-            del thread_stacks[threading.get_ident()]
-
-    return asyncer.asyncify(wrapped, abandon_on_cancel=True, limiter=get_limiter())
+    assert threading.current_thread() is threading.main_thread(), "asyncify can only be called from the main thread"
+    # NOTE: To allow this to be nested, we'd need behavior with contextvars like parallelizer.py
+    return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter())
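asyncer.asyncify runs the wrapped program on a worker thread of the running event loop, so the old per-thread stack bookkeeping could be deleted outright; the one remaining assertion just pins the call site to the main thread. A hedged usage sketch (dspy.asyncify as a top-level export and the "question -> answer" signature are assumptions about DSPy's public API, not shown in this diff):

    import asyncio
    import dspy

    async def main():
        program = dspy.ChainOfThought("question -> answer")
        async_program = dspy.asyncify(program)  # must be called from the main thread
        result = await async_program(question="What does a ContextVar do?")
        print(result.answer)

    asyncio.run(main())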

dspy/utils/parallelizer.py

Lines changed: 54 additions & 44 deletions
@@ -1,16 +1,15 @@
-import logging
 import sys
 import tqdm
-import dspy
 import signal
+import logging
 import threading
 import traceback
 import contextlib
 
+from contextvars import copy_context
 from tqdm.contrib.logging import logging_redirect_tqdm
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -23,6 +22,8 @@ def __init__(
         provide_traceback=False,
         compare_results=False,
     ):
+        """Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""
+
         self.num_threads = num_threads
         self.disable_progress_bar = disable_progress_bar
         self.max_errors = max_errors
@@ -33,34 +34,18 @@
         self.error_lock = threading.Lock()
         self.cancel_jobs = threading.Event()
 
-
     def execute(self, function, data):
         wrapped_function = self._wrap_function(function)
         if self.num_threads == 1:
-            return self._execute_single_thread(wrapped_function, data)
+            return self._execute_isolated_single_thread(wrapped_function, data)
         else:
             return self._execute_multi_thread(wrapped_function, data)
 
-
     def _wrap_function(self, function):
-        # Wrap the function with threading context and error handling
-        def wrapped(item, parent_id=None):
-            thread_stacks = dspy.settings.stack_by_thread
-            current_thread_id = threading.get_ident()
-            creating_new_thread = current_thread_id not in thread_stacks
-
-            assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid
-
-            if creating_new_thread:
-                # If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
-                if parent_id and parent_id in thread_stacks:
-                    thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
-                else:
-                    thread_stacks[current_thread_id] = list(dspy.settings.main_stack)
-
-                # TODO: Consider the behavior below.
-                # import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))
-
+        # Wrap the function with error handling
+        def wrapped(item):
+            if self.cancel_jobs.is_set():
+                return None
             try:
                 return function(item)
             except Exception as e:
@@ -79,45 +64,53 @@ def wrapped(item, parent_id=None):
                         f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
                     )
                 return None
-            finally:
-                if creating_new_thread:
-                    del thread_stacks[threading.get_ident()]
         return wrapped
 
-
-    def _execute_single_thread(self, function, data):
+    def _execute_isolated_single_thread(self, function, data):
        results = []
        pbar = tqdm.tqdm(
            total=len(data),
            dynamic_ncols=True,
            disable=self.disable_progress_bar,
-            file=sys.stdout,
+            file=sys.stdout
        )
+
        for item in data:
            with logging_redirect_tqdm():
                if self.cancel_jobs.is_set():
                    break
-                result = function(item)
+
+                # Create an isolated context for each task
+                task_ctx = copy_context()
+                result = task_ctx.run(function, item)
                results.append(result)
+
                if self.compare_results:
                    # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in data if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in data if r is not None]),
+                    )
                else:
                    self._update_progress(pbar, len(results), len(data))
+
        pbar.close()
+
        if self.cancel_jobs.is_set():
            logger.warning("Execution was cancelled due to errors.")
            raise Exception("Execution was cancelled due to errors.")
-        return results
 
+        return results
 
    def _update_progress(self, pbar, nresults, ntotal):
        if self.compare_results:
-            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({round(100 * nresults / ntotal, 1) if ntotal > 0 else 0}%)")
+            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
+            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
        else:
            pbar.set_description(f"Processed {nresults} / {ntotal} examples")
-        pbar.update()
 
+        pbar.update()
 
    def _execute_multi_thread(self, function, data):
        results = [None] * len(data)  # Pre-allocate results list to maintain order
@@ -132,6 +125,7 @@ def interrupt_handler_manager():
            def interrupt_handler(sig, frame):
                self.cancel_jobs.set()
                logger.warning("Received SIGINT. Cancelling execution.")
+                # Re-raise the signal to allow default behavior
                default_handler(sig, frame)
 
            signal.signal(signal.SIGINT, interrupt_handler)
@@ -143,37 +137,53 @@ def interrupt_handler(sig, frame):
            # If not in the main thread, skip setting signal handlers
            yield
 
-        def cancellable_function(index_item, parent_id=None):
+        def cancellable_function(index_item):
            index, item = index_item
            if self.cancel_jobs.is_set():
                return index, job_cancelled
-            return index, function(item, parent_id)
-
-        parent_id = threading.get_ident() if threading.current_thread() is not threading.main_thread() else None
+            return index, function(item)
 
        with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
-            futures = {executor.submit(cancellable_function, pair, parent_id): pair for pair in enumerate(data)}
+            futures = {}
+            for pair in enumerate(data):
+                # Capture the context for each task
+                task_ctx = copy_context()
+                future = executor.submit(task_ctx.run, cancellable_function, pair)
+                futures[future] = pair
+
            pbar = tqdm.tqdm(
                total=len(data),
                dynamic_ncols=True,
                disable=self.disable_progress_bar,
-                file=sys.stdout,
+                file=sys.stdout
            )
 
            for future in as_completed(futures):
                index, result = future.result()
 
                if result is job_cancelled:
                    continue
+
                results[index] = result
 
                if self.compare_results:
                    # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in results if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in results if r is not None]),
+                    )
                else:
-                    self._update_progress(pbar, len([r for r in results if r is not None]), len(data))
+                    self._update_progress(
+                        pbar,
+                        len([r for r in results if r is not None]),
+                        len(data),
+                    )
+
            pbar.close()
+
            if self.cancel_jobs.is_set():
                logger.warning("Execution was cancelled due to errors.")
                raise Exception("Execution was cancelled due to errors.")
+
            return results
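Both execution paths now get their isolation from contextvars.copy_context(): each task runs inside its own snapshot of the caller's context, so ContextVar writes made by one task (e.g. via dspy.settings.configure) never leak into another, even when num_threads == 1. A standalone sketch of that mechanism (the overrides variable and tag values are illustrative):

    from contextvars import ContextVar, copy_context
    from concurrent.futures import ThreadPoolExecutor

    overrides = ContextVar("overrides", default={})

    def task(tag):
        # This write is confined to the task's own copied context.
        overrides.set({**overrides.get(), "lm": tag})
        return overrides.get()["lm"]

    # Single-threaded but still isolated: each call runs in a fresh snapshot.
    print([copy_context().run(task, f"model-{i}") for i in range(3)])

    # Multi-threaded: submit ctx.run so each worker starts from the caller's context.
    with ThreadPoolExecutor(max_workers=2) as ex:
        futures = [ex.submit(copy_context().run, task, f"model-{i}") for i in range(3)]
        print([f.result() for f in futures])

    print(overrides.get())  # still {} in the parent context

Because the context is captured when the task is created, a task inherits whatever its parent thread had configured at scheduling time, which matches the inheritance rule stated in the Settings docstring.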
