Skip to content

Commit 2e42be4

Browse files
authored
compare-llama-bench: add option to plot (#14169)
* compare llama-bench: add option to plot * Address review comments: convert case + add type hints * Add matplotlib to requirements * fix tests * Improve comment and fix assert condition for test * Add back default test_name, add --plot_log_scale * use log_scale regardless of x_values
1 parent fb85a28 commit 2e42be4

File tree

2 files changed

+169
-1
lines changed

2 files changed

+169
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
tabulate~=0.9.0
22
GitPython~=3.1.43
3+
matplotlib~=3.10.0

scripts/compare-llama-bench.py

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100
2020
raise e
2121

22+
2223
logger = logging.getLogger("compare-llama-bench")
2324

2425
# All llama-bench SQL fields
@@ -122,11 +123,15 @@
122123
parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
123124
parser.add_argument("-s", "--show", help=help_s)
124125
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
126+
parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
127+
parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
128+
parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")
125129

126130
known_args, unknown_args = parser.parse_known_args()
127131

128132
logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
129133

134+
130135
if known_args.check:
131136
# Check if all required Python libraries are installed. Would have failed earlier if not.
132137
sys.exit(0)
@@ -499,7 +504,6 @@ def valid_format(data_files: list[str]) -> bool:
499504

500505
name_compare = bench_data.get_commit_name(hexsha8_compare)
501506

502-
503507
# If the user provided columns to group the results by, use them:
504508
if known_args.show is not None:
505509
show = known_args.show.split(",")
@@ -544,6 +548,14 @@ def valid_format(data_files: list[str]) -> bool:
544548
show.remove(prop)
545549
except ValueError:
546550
pass
551+
552+
# Add plot_x parameter to parameters to show if it's not already present:
553+
if known_args.plot:
554+
for k, v in PRETTY_NAMES.items():
555+
if v == known_args.plot_x and k not in show:
556+
show.append(k)
557+
break
558+
547559
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
548560

549561
if not rows_show:
@@ -600,6 +612,161 @@ def valid_format(data_files: list[str]) -> bool:
600612
headers = [PRETTY_NAMES[p] for p in show]
601613
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
602614

615+
if known_args.plot:
616+
def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False):
617+
try:
618+
import matplotlib.pyplot as plt
619+
import matplotlib
620+
matplotlib.use('Agg')
621+
except ImportError as e:
622+
logger.error("matplotlib is required for --plot.")
623+
raise e
624+
625+
data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
626+
plot_x_index = None
627+
plot_x_label = plot_x_param
628+
629+
if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
630+
pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param)
631+
if pretty_name in data_headers:
632+
plot_x_index = data_headers.index(pretty_name)
633+
plot_x_label = pretty_name
634+
elif plot_x_param in data_headers:
635+
plot_x_index = data_headers.index(plot_x_param)
636+
plot_x_label = plot_x_param
637+
else:
638+
logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
639+
return
640+
641+
grouped_data = {}
642+
643+
for i, row in enumerate(table_data):
644+
group_key_parts = []
645+
test_name = row[-4]
646+
647+
base_test = ""
648+
x_value = None
649+
650+
if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
651+
for j, val in enumerate(row[:-4]):
652+
header_name = data_headers[j]
653+
if val is not None and str(val).strip():
654+
group_key_parts.append(f"{header_name}={val}")
655+
656+
if plot_x_param == "n_prompt" and "pp" in test_name:
657+
base_test = test_name.split("@")[0]
658+
x_value = base_test
659+
elif plot_x_param == "n_gen" and "tg" in test_name:
660+
x_value = test_name.split("@")[0]
661+
elif plot_x_param == "n_depth" and "@d" in test_name:
662+
base_test = test_name.split("@d")[0]
663+
x_value = int(test_name.split("@d")[1])
664+
else:
665+
base_test = test_name
666+
667+
if base_test.strip():
668+
group_key_parts.append(f"Test={base_test}")
669+
else:
670+
for j, val in enumerate(row[:-4]):
671+
if j != plot_x_index:
672+
header_name = data_headers[j]
673+
if val is not None and str(val).strip():
674+
group_key_parts.append(f"{header_name}={val}")
675+
else:
676+
x_value = val
677+
678+
group_key_parts.append(f"Test={test_name}")
679+
680+
group_key = tuple(group_key_parts)
681+
682+
if group_key not in grouped_data:
683+
grouped_data[group_key] = []
684+
685+
grouped_data[group_key].append({
686+
'x_value': x_value,
687+
'baseline': float(row[-3]),
688+
'compare': float(row[-2]),
689+
'speedup': float(row[-1])
690+
})
691+
692+
if not grouped_data:
693+
logger.error("No data available for plotting")
694+
return
695+
696+
def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
697+
from math import ceil
698+
cols = 1 if num_groups == 1 else min(max_cols, num_groups)
699+
rows = ceil(num_groups / cols)
700+
701+
# Scale figure size by grid dimensions
702+
w, h = base_size
703+
fig, ax_arr = plt.subplots(rows, cols,
704+
figsize=(w * cols, h * rows),
705+
squeeze=False)
706+
707+
axes = ax_arr.flatten()[:num_groups]
708+
return fig, axes
709+
710+
num_groups = len(grouped_data)
711+
fig, axes = make_axes(num_groups)
712+
713+
plot_idx = 0
714+
715+
for group_key, points in grouped_data.items():
716+
if plot_idx >= len(axes):
717+
break
718+
ax = axes[plot_idx]
719+
720+
try:
721+
points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
722+
x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
723+
except ValueError:
724+
points_sorted = sorted(points, key=lambda p: group_key)
725+
x_values = [p['x_value'] for p in points_sorted]
726+
727+
baseline_vals = [p['baseline'] for p in points_sorted]
728+
compare_vals = [p['compare'] for p in points_sorted]
729+
730+
ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
731+
label=f'{baseline_name}', linewidth=2, markersize=6)
732+
ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
733+
label=f'{compare_name}', linewidth=2, markersize=6)
734+
735+
if log_scale:
736+
ax.set_xscale('log', base=2)
737+
unique_x = sorted(set(x_values))
738+
ax.set_xticks(unique_x)
739+
ax.set_xticklabels([str(int(x)) for x in unique_x])
740+
741+
title_parts = []
742+
for part in group_key:
743+
if '=' in part:
744+
key, value = part.split('=', 1)
745+
title_parts.append(f"{key}: {value}")
746+
747+
title = ', '.join(title_parts) if title_parts else "Performance comparison"
748+
749+
ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
750+
ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold')
751+
ax.set_title(title, fontsize=12, fontweight='bold')
752+
ax.legend(loc='best', fontsize=10)
753+
ax.grid(True, alpha=0.3)
754+
755+
plot_idx += 1
756+
757+
for i in range(plot_idx, len(axes)):
758+
axes[i].set_visible(False)
759+
760+
fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
761+
fontsize=14, fontweight='bold')
762+
fig.subplots_adjust(top=1)
763+
764+
plt.tight_layout()
765+
plt.savefig(output_file, dpi=300, bbox_inches='tight')
766+
plt.close()
767+
768+
create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale)
769+
603770
print(tabulate( # noqa: NP100
604771
table,
605772
headers=headers,

0 commit comments

Comments
 (0)