Skip to content

Commit daa0b11

Browse files
authored
Merge pull request #2433 from pbalcer/bench-stddev
[benchmarks] add support for stddev
2 parents aa72577 + 1e4e23a commit daa0b11

File tree

4 files changed

+33
-11
lines changed

4 files changed

+33
-11
lines changed

scripts/benchmarks/benches/compute.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def run(self, env_vars) -> list[Result]:
118118
result = self.run_bench(command, env_vars)
119119
parsed_results = self.parse_output(result)
120120
ret = []
121-
for label, mean, unit in parsed_results:
121+
for label, median, stddev, unit in parsed_results:
122122
extra_label = " CPU count" if parse_unit_type(unit) == "instr" else ""
123-
ret.append(Result(label=self.name() + extra_label, value=mean, command=command, env=env_vars, stdout=result, unit=parse_unit_type(unit)))
123+
ret.append(Result(label=self.name() + extra_label, value=median, stddev=stddev, command=command, env=env_vars, stdout=result, unit=parse_unit_type(unit)))
124124
return ret
125125

126126
def parse_output(self, output):
@@ -135,8 +135,11 @@ def parse_output(self, output):
135135
try:
136136
label = data_row[0]
137137
mean = float(data_row[1])
138+
median = float(data_row[2])
139+
# compute benchmarks report stddev as %
140+
stddev = mean * (float(data_row[3].strip('%')) / 100.0)
138141
unit = data_row[7]
139-
results.append((label, mean, unit))
142+
results.append((label, median, stddev, unit))
140143
except (ValueError, IndexError) as e:
141144
raise ValueError(f"Error parsing output: {e}")
142145
if len(results) == 0:

scripts/benchmarks/benches/result.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ class Result:
1818
stdout: str
1919
passed: bool = True
2020
unit: str = ""
21-
# values should not be set by the benchmark
21+
# stddev can be optionally set by the benchmark,
22+
# if not set, it will be calculated automatically.
23+
stddev: float = 0.0
24+
# values below should not be set by the benchmark
2225
name: str = ""
2326
lower_is_better: bool = True
2427
git_hash: str = ''
2528
date: Optional[datetime] = None
26-
stddev: float = 0.0
2729

2830
@dataclass_json
2931
@dataclass

scripts/benchmarks/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ def process_results(results: dict[str, list[Result]]) -> tuple[bool, list[Result
103103
rlist.sort(key=lambda res: res.value)
104104
median_index = len(rlist) // 2
105105
median_result = rlist[median_index]
106-
median_result.stddev = stddev
106+
107+
# only override the stddev if not already set
108+
if median_result.stddev == 0.0:
109+
median_result.stddev = stddev
107110

108111
processed.append(median_result)
109112

@@ -160,7 +163,6 @@ def main(directory, additional_env_vars, save_name, compare_names, filter):
160163
if valid:
161164
break
162165
results += processed
163-
164166
except Exception as e:
165167
if options.exit_on_failure:
166168
raise e

scripts/benchmarks/output_html.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,32 @@ def create_time_series_chart(benchmarks: list[BenchmarkSeries], github_repo: str
3232

3333
num_benchmarks = len(benchmarks)
3434
if num_benchmarks == 0:
35-
return
35+
return []
3636

3737
html_charts = []
3838

3939
for _, benchmark in enumerate(benchmarks):
4040
fig, ax = plt.subplots(figsize=(10, 4))
4141

42+
all_values = []
43+
all_stddevs = []
44+
4245
for run in benchmark.runs:
4346
sorted_points = sorted(run.results, key=lambda x: x.date)
4447
dates = [point.date for point in sorted_points]
4548
values = [point.value for point in sorted_points]
49+
stddevs = [point.stddev for point in sorted_points]
50+
51+
all_values.extend(values)
52+
all_stddevs.extend(stddevs)
4653

47-
ax.plot_date(dates, values, '-', label=run.name, alpha=0.5)
54+
ax.errorbar(dates, values, yerr=stddevs, fmt='-', label=run.name, alpha=0.5)
4855
scatter = ax.scatter(dates, values, picker=True)
4956

5057
tooltip_labels = [
5158
f"Date: {point.date.strftime('%Y-%m-%d %H:%M:%S')}\n"
52-
f"Value: {point.value:.2f}\n"
59+
f"Value: {point.value:.2f} {benchmark.metadata.unit}\n"
60+
f"Stddev: {point.stddev:.2f} {benchmark.metadata.unit}\n"
5361
f"Git Hash: {point.git_hash}"
5462
for point in sorted_points
5563
]
@@ -62,6 +70,13 @@ def create_time_series_chart(benchmarks: list[BenchmarkSeries], github_repo: str
6270
targets=targets)
6371
mpld3.plugins.connect(fig, tooltip)
6472

73+
# This is so that the stddev doesn't fill the entire y axis on the chart
74+
if all_values and all_stddevs:
75+
max_value = max(all_values)
76+
min_value = min(all_values)
77+
max_stddev = max(all_stddevs)
78+
ax.set_ylim(min_value - 3 * max_stddev, max_value + 3 * max_stddev)
79+
6580
ax.set_title(benchmark.label, pad=20)
6681
performance_indicator = "lower is better" if benchmark.metadata.lower_is_better else "higher is better"
6782
ax.text(0.5, 1.05, f"({performance_indicator})",
@@ -79,7 +94,7 @@ def create_time_series_chart(benchmarks: list[BenchmarkSeries], github_repo: str
7994
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter('%Y-%m-%d %H:%M:%S'))
8095

8196
plt.tight_layout()
82-
html_charts.append(BenchmarkTimeSeries(html= mpld3.fig_to_html(fig), label= benchmark.label))
97+
html_charts.append(BenchmarkTimeSeries(html=mpld3.fig_to_html(fig), label=benchmark.label))
8398
plt.close(fig)
8499

85100
return html_charts

0 commit comments

Comments (0)