Skip to content

Commit deeaecf

Browse files
committed
Add matplotlib to requirements
1 parent 5426c87 commit deeaecf

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
tabulate~=0.9.0
22
GitPython~=3.1.43
3+
matplotlib~=3.10.0

scripts/compare-llama-bench.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@
129129

130130
logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
131131

132-
# Check for matplotlib if plotting is requested
133132
if known_args.plot:
134133
try:
135134
import matplotlib.pyplot as plt
@@ -511,7 +510,6 @@ def valid_format(data_files: list[str]) -> bool:
511510

512511
name_compare = bench_data.get_commit_name(hexsha8_compare)
513512

514-
515513
# If the user provided columns to group the results by, use them:
516514
if known_args.show is not None:
517515
show = known_args.show.split(",")
@@ -556,6 +554,14 @@ def valid_format(data_files: list[str]) -> bool:
556554
show.remove(prop)
557555
except ValueError:
558556
pass
557+
558+
# add plot_x parameter to if it's not already there
559+
if known_args.plot:
560+
for k, v in PRETTY_NAMES.items():
561+
if v == known_args.plot_x and k not in show:
562+
show.append(k)
563+
break
564+
559565
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
560566

561567
if not rows_show:
@@ -629,7 +635,6 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
629635
plot_x_label = plot_x_param
630636
else:
631637
logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
632-
logger.error(f"To plot by '{plot_x_param}', include it in --show parameter or ensure it varies in your data.")
633638
return
634639

635640
grouped_data = {}
@@ -671,7 +676,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
671676

672677
group_key_parts.append(f"Test={test_name}")
673678

674-
group_key = tuple(sorted(group_key_parts))
679+
group_key = tuple(group_key_parts)
675680

676681
if group_key not in grouped_data:
677682
grouped_data[group_key] = []
@@ -692,7 +697,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
692697
cols = 1 if num_groups == 1 else min(max_cols, num_groups)
693698
rows = ceil(num_groups / cols)
694699

695-
# scale figure size by grid dimensions
700+
# Scale figure size by grid dimensions
696701
w, h = base_size
697702
fig, ax_arr = plt.subplots(rows, cols,
698703
figsize=(w * cols, h * rows),
@@ -726,7 +731,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
726731
ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
727732
label=f'{compare_name}', linewidth=2, markersize=6)
728733

729-
if plot_x_param == "n_depth" and max(x_values) > 0 and max(x_values) > min(x_values) * 4:
734+
if plot_x_param == "n_depth" and min(x_values) > 0 and max(x_values) > min(x_values) * 4:
730735
ax.set_xscale('log', base=2)
731736
unique_x = sorted(set(x_values))
732737
ax.set_xticks(unique_x)
@@ -741,7 +746,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
741746
title = ', '.join(title_parts) if title_parts else "Performance comparison"
742747

743748
ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
744-
ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold')
749+
ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold')
745750
ax.set_title(title, fontsize=12, fontweight='bold')
746751
ax.legend(loc='best', fontsize=10)
747752
ax.grid(True, alpha=0.3)
@@ -751,7 +756,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
751756
for i in range(plot_idx, len(axes)):
752757
axes[i].set_visible(False)
753758

754-
fig.suptitle(f'Performance comparison: {compare_name} vs {baseline_name}',
759+
fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
755760
fontsize=14, fontweight='bold')
756761
fig.subplots_adjust(top=1)
757762

0 commit comments

Comments
 (0)