Skip to content

Commit bb5f0d1

Browse files
authored
MIPROv2 Updates (#8166)
* adding in error messages & timeout for user permission message * ruff fix
1 parent 524a58f commit bb5f0d1

File tree

1 file changed

+47
-19
lines changed

1 file changed

+47
-19
lines changed

dspy/teleprompt/mipro_optimizer_v2.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import textwrap
44
from collections import defaultdict
55
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
6+
import select
7+
import sys
8+
import time
69

710
import numpy as np
811
import optuna
912
from optuna.distributions import CategoricalDistribution
10-
1113
import dspy
1214
from dspy.evaluate.evaluate import Evaluate
1315
from dspy.propose import GroundedProposer
@@ -53,10 +55,8 @@ def __init__(
5355
teacher_settings: Dict = {},
5456
max_bootstrapped_demos: int = 4,
5557
max_labeled_demos: int = 4,
56-
auto: Optional[Literal["light", "medium", "heavy"]] = "medium",
57-
num_candidates: int = 10,
58-
num_fewshot_candidates: Optional[int] = None,
59-
num_instruct_candidates: Optional[int] = None,
58+
auto: Optional[Literal["light", "medium", "heavy"]] = "light",
59+
num_candidates: Optional[int] = None,
6060
num_threads: Optional[int] = None,
6161
max_errors: int = 10,
6262
seed: int = 9,
@@ -71,9 +71,9 @@ def __init__(
7171
if auto not in allowed_modes:
7272
raise ValueError(f"Invalid value for auto: {auto}. Must be one of {allowed_modes}.")
7373
self.auto = auto
74-
75-
self.num_fewshot_candidates = num_fewshot_candidates or num_candidates
76-
self.num_instruct_candidates = num_instruct_candidates or num_candidates
74+
self.num_fewshot_candidates = num_candidates
75+
self.num_instruct_candidates = num_candidates
76+
self.num_candidates = num_candidates
7777
self.metric = metric
7878
self.init_temperature = init_temperature
7979
self.task_model = task_model if task_model else dspy.settings.lm
@@ -99,7 +99,7 @@ def compile(
9999
trainset: List,
100100
teacher: Any = None,
101101
valset: Optional[List] = None,
102-
num_trials: int = 30,
102+
num_trials: Optional[int] = None,
103103
max_bootstrapped_demos: Optional[int] = None,
104104
max_labeled_demos: Optional[int] = None,
105105
seed: Optional[int] = None,
@@ -114,6 +114,21 @@ def compile(
114114
requires_permission_to_run: bool = True,
115115
provide_traceback: Optional[bool] = None,
116116
) -> Any:
117+
118+
zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0)
119+
120+
# If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value
121+
if self.auto is None and (self.num_candidates is not None and num_trials is None):
122+
raise ValueError(f"If auto is None, num_trials must also be provided. Given num_candidates={self.num_candidates}, we'd recommend setting num_trials to ~{self._set_num_trials_from_num_candidates(student, zeroshot_opt, self.num_candidates)}.")
123+
124+
# If auto is None, and num_candidates or num_trials is None, raise an error
125+
if self.auto is None and (self.num_candidates is None or num_trials is None):
126+
raise ValueError("If auto is None, num_candidates must also be provided.")
127+
128+
# If auto is provided, and either num_candidates or num_trials is not None, raise an error
129+
if self.auto is not None and (self.num_candidates is not None or num_trials is not None):
130+
raise ValueError("If auto is not None, num_candidates and num_trials cannot be set, since they would be overrided by the auto settings. Please either set auto to None, or do not specify num_candidates and num_trials.")
131+
117132
# Set random seeds
118133
seed = seed or self.seed
119134
self._set_random_seeds(seed)
@@ -128,7 +143,6 @@ def compile(
128143
trainset, valset = self._set_and_validate_datasets(trainset, valset)
129144

130145
# Set hyperparameters based on run mode (if set)
131-
zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0)
132146
num_trials, valset, minibatch = self._set_hyperparams_from_run_mode(
133147
student, num_trials, minibatch, zeroshot_opt, valset
134148
)
@@ -204,6 +218,15 @@ def _set_random_seeds(self, seed):
204218
self.rng = random.Random(seed)
205219
np.random.seed(seed)
206220

221+
def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candidates):
222+
num_vars = len(program.predictors())
223+
if not zeroshot_opt:
224+
num_vars *= 2 # Account for few-shot examples + instruction variables
225+
# Trials = MAX(c*M*log(N), c=2, 3/2*N)
226+
num_trials = int(max(2 * num_vars * np.log2(num_candidates), 1.5 * num_candidates))
227+
228+
return num_trials
229+
207230
def _set_hyperparams_from_run_mode(
208231
self,
209232
program: Any,
@@ -226,11 +249,7 @@ def _set_hyperparams_from_run_mode(
226249
self.num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5)
227250
self.num_fewshot_candidates = auto_settings["n"]
228251

229-
num_vars = len(program.predictors())
230-
if not zeroshot_opt:
231-
num_vars *= 2 # Account for few-shot examples + instruction variables
232-
# Trials = MAX(c*M*log(N), c=2, 3/2*N)
233-
num_trials = max(2 * num_vars * np.log(auto_settings["n"]), 1.5 * auto_settings["n"])
252+
num_trials = self._set_num_trials_from_num_candidates(program, zeroshot_opt, auto_settings["n"])
234253

235254
return num_trials, valset, minibatch
236255

@@ -353,17 +372,26 @@ def _get_user_confirmation(
353372
user_confirmation_message = textwrap.dedent(
354373
f"""\
355374
To proceed with the execution of this program, please confirm by typing {BLUE}'y'{ENDC} for yes or {BLUE}'n'{ENDC} for no.
375+
If no input is received within 20 seconds, the program will proceed automatically.
356376
357377
If you would like to bypass this confirmation step in future executions, set the {YELLOW}`requires_permission_to_run`{ENDC} flag to {YELLOW}`False`{ENDC} when calling compile.
358378
359379
{YELLOW}Awaiting your input...{ENDC}
360380
"""
361381
)
362382

363-
user_input = (
364-
input(f"{user_message}\n{user_confirmation_message}\nDo you wish to continue? (y/n): ").strip().lower()
365-
)
366-
return user_input == "y"
383+
print(f"{user_message}\n{user_confirmation_message}\nDo you wish to continue? (y/n): ", end='', flush=True)
384+
385+
# Wait for input with timeout
386+
start_time = time.time()
387+
while time.time() - start_time < 20:
388+
if select.select([sys.stdin], [], [], 0.1)[0]:
389+
user_input = sys.stdin.readline().strip().lower()
390+
return user_input == "y"
391+
time.sleep(0.1)
392+
393+
print("\nNo input received within 20 seconds. Proceeding with execution...")
394+
return True
367395

368396
def _bootstrap_fewshot_examples(self, program: Any, trainset: List, seed: int, teacher: Any) -> Optional[List]:
369397
logger.info("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==")

0 commit comments

Comments
 (0)