Skip to content

Commit 10042de

Browse files
committed
Add matplotlib to requirements
1 parent 5426c87 commit 10042de

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
tabulate~=0.9.0
22
GitPython~=3.1.43
3+
matplotlib~=3.10.0

scripts/compare-llama-bench.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,6 @@ def valid_format(data_files: list[str]) -> bool:
511511

512512
name_compare = bench_data.get_commit_name(hexsha8_compare)
513513

514-
515514
# If the user provided columns to group the results by, use them:
516515
if known_args.show is not None:
517516
show = known_args.show.split(",")
@@ -556,6 +555,14 @@ def valid_format(data_files: list[str]) -> bool:
556555
show.remove(prop)
557556
except ValueError:
558557
pass
558+
559+
if known_args.plot:
560+
for k, v in PRETTY_NAMES.items():
561+
if v == known_args.plot_x and k not in show:
562+
logger.info(f"Adding {k} to --show")
563+
show.append(k)
564+
break
565+
559566
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
560567

561568
if not rows_show:
@@ -629,7 +636,6 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
629636
plot_x_label = plot_x_param
630637
else:
631638
logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
632-
logger.error(f"To plot by '{plot_x_param}', include it in --show parameter or ensure it varies in your data.")
633639
return
634640

635641
grouped_data = {}
@@ -671,7 +677,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
671677

672678
group_key_parts.append(f"Test={test_name}")
673679

674-
group_key = tuple(sorted(group_key_parts))
680+
group_key = tuple(group_key_parts)
675681

676682
if group_key not in grouped_data:
677683
grouped_data[group_key] = []
@@ -692,7 +698,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
692698
cols = 1 if num_groups == 1 else min(max_cols, num_groups)
693699
rows = ceil(num_groups / cols)
694700

695-
# scale figure size by grid dimensions
701+
# Scale figure size by grid dimensions
696702
w, h = base_size
697703
fig, ax_arr = plt.subplots(rows, cols,
698704
figsize=(w * cols, h * rows),
@@ -726,7 +732,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
726732
ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
727733
label=f'{compare_name}', linewidth=2, markersize=6)
728734

729-
if plot_x_param == "n_depth" and max(x_values) > 0 and max(x_values) > min(x_values) * 4:
735+
if plot_x_param == "n_depth" and min(x_values) > 0 and max(x_values) > min(x_values) * 4:
730736
ax.set_xscale('log', base=2)
731737
unique_x = sorted(set(x_values))
732738
ax.set_xticks(unique_x)
@@ -741,7 +747,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
741747
title = ', '.join(title_parts) if title_parts else "Performance comparison"
742748

743749
ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
744-
ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold')
750+
ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold')
745751
ax.set_title(title, fontsize=12, fontweight='bold')
746752
ax.legend(loc='best', fontsize=10)
747753
ax.grid(True, alpha=0.3)
@@ -751,7 +757,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
751757
for i in range(plot_idx, len(axes)):
752758
axes[i].set_visible(False)
753759

754-
fig.suptitle(f'Performance comparison: {compare_name} vs {baseline_name}',
760+
fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
755761
fontsize=14, fontweight='bold')
756762
fig.subplots_adjust(top=1)
757763

0 commit comments

Comments
 (0)