Skip to content

Commit 8805035

Browse files
Making bootstrap_iters an arg (#697)
* Making bootstrap iters an arg * another hard coded value in info loggers * styling plus handling of 0 iter case
1 parent 9619194 commit 8805035

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/lighteval/logging/info_loggers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,13 @@ def aggregate(self, task_dict: dict[str, LightevalTask], bootstrap_iters: int =
523523
else:
524524
self.metric_aggregated[task_name][metric_name] = metric_result
525525

526-
if isinstance(metric_result, dict):
527-
stderr = None # We skip stderr for some corpus metrics that return dicts
526+
if isinstance(metric_result, dict) or bootstrap_iters == 0:
527+
stderr = (
528+
None # We skip stderr for some corpus metrics that return dicts, or if bootstrap_iters is 0
529+
)
528530
else:
529531
aggregation = task.aggregation()[metric_name]
530-
stderr = get_stderr_function(aggregation=aggregation, number_experiments=1000)
532+
stderr = get_stderr_function(aggregation=aggregation, number_experiments=bootstrap_iters)
531533
if stderr is not None and len(metric_values) > 1:
532534
try:
533535
self.metric_aggregated[task_name][f"{metric_name}_stderr"] = stderr(metric_values)

src/lighteval/pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class PipelineParameters:
107107
system_prompt: str | None = None
108108
cot_prompt: str | None = None
109109
load_responses_from_details_date_id: str | None = None
110+
bootstrap_iters: int = 1000
110111

111112
def __post_init__(self): # noqa C901
112113
if self.launcher_type == ParallelismManager.ACCELERATE:
@@ -293,7 +294,9 @@ def evaluate(self):
293294

294295
if self.is_main_process():
295296
self.evaluation_tracker.general_config_logger.log_end_time()
296-
self.evaluation_tracker.metrics_logger.aggregate(task_dict=self.task_dict, bootstrap_iters=1000)
297+
self.evaluation_tracker.metrics_logger.aggregate(
298+
task_dict=self.task_dict, bootstrap_iters=self.pipeline_parameters.bootstrap_iters
299+
)
297300
self.evaluation_tracker.details_logger.aggregate()
298301

299302
def _unpack(self, x):

0 commit comments

Comments
 (0)