Skip to content

Commit 1690177

Browse files
committed
compare llama-bench: add option to plot
1 parent d714dad commit 1690177

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed

scripts/compare-llama-bench.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,23 @@
122122
parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
parser.add_argument("-s", "--show", help=help_s)
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
# Plotting options: --plot names the output image, --plot_x picks the x-axis parameter.
parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
parser.add_argument("--plot_x", help="parameter to use as x-axis for plotting (default: n_depth)", default="n_depth")

known_args, unknown_args = parser.parse_known_args()

logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)

# Check for matplotlib if plotting is requested.
# matplotlib is imported only when --plot is given, so it stays an optional dependency.
if known_args.plot:
    try:
        import matplotlib
        # Select the non-interactive Agg backend before pyplot is imported so that
        # no other backend is resolved first; the plot is only ever written to a file.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        # Report through the logger (configured just above) instead of a bare print,
        # then re-raise the original exception unchanged.
        logger.error("matplotlib is required for --plot.")
        raise
141+
130142
if known_args.check:
131143
# Check if all required Python libraries are installed. Would have failed earlier if not.
132144
sys.exit(0)
@@ -600,6 +612,157 @@ def valid_format(data_files: list[str]) -> bool:
600612
headers = [PRETTY_NAMES[p] for p in show]
601613
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
602614

615+
if known_args.plot:
    def create_performance_plot(table_data, headers, baseline_name, compare_name, output_file, plot_x_param):
        """Plot baseline vs. compare throughput per group of rows and save the figure.

        Rows that share every column except the x-axis parameter form one group,
        and each group is drawn in its own subplot with one line for the baseline
        and one for the compare run.

        Args:
            table_data: Table rows. The last four entries of every row are the test
                name, the baseline t/s, the compare t/s and the speedup; the entries
                before them line up with ``headers[:-4]``.
            headers: Column headers of ``table_data`` (including the last four).
            baseline_name: Legend label for the baseline series.
            compare_name: Legend label for the compare series.
            output_file: Path the figure is written to.
            plot_x_param: Parameter used as x-axis. "n_prompt", "n_gen" and
                "n_depth" are derived from the test name; any other value must be
                one of the table columns (by raw or pretty name).

        Returns:
            None. If ``plot_x_param`` is not available, or there is nothing to plot,
            an error is logged and no file is written.
        """
        # Parameters whose x value is parsed out of the test name rather than read from a column.
        test_derived_params = ("n_prompt", "n_gen", "n_depth")

        data_headers = headers[:-4]  # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
        plot_x_index = None
        plot_x_label = plot_x_param

        if plot_x_param not in test_derived_params:
            pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param)
            if pretty_name in data_headers:
                plot_x_index = data_headers.index(pretty_name)
                plot_x_label = pretty_name
            elif plot_x_param in data_headers:
                plot_x_index = data_headers.index(plot_x_param)
                plot_x_label = plot_x_param
            else:
                logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
                logger.error(f"To plot by '{plot_x_param}', include it in --show parameter or ensure it varies in your data.")
                return

        grouped_data = {}

        for row in table_data:
            group_key_parts = []
            test_name = row[-4]

            # Reset per row: not every branch below assigns both values, and a row
            # must never inherit them from the previous row. A None x value is
            # plotted at 0 further down.
            base_test = ""
            x_value = None

            if plot_x_param in test_derived_params:
                for header_name, val in zip(data_headers, row[:-4]):
                    if val is not None and str(val).strip():
                        group_key_parts.append(f"{header_name}={val}")

                if plot_x_param == "n_prompt" and "pp" in test_name:
                    base_test = test_name.split("@")[0]
                    x_value = base_test
                elif plot_x_param == "n_gen" and "tg" in test_name:
                    x_value = test_name.split("@")[0]
                elif plot_x_param == "n_depth" and "@d" in test_name:
                    name_parts = test_name.split("@d")
                    base_test = name_parts[0]
                    x_value = int(name_parts[1])
                else:
                    # The test name carries no value for the requested axis.
                    base_test = test_name

                if base_test.strip():
                    group_key_parts.append(f"Test={base_test}")
            else:
                for j, val in enumerate(row[:-4]):
                    if j == plot_x_index:
                        x_value = val
                        continue
                    if val is not None and str(val).strip():
                        group_key_parts.append(f"{data_headers[j]}={val}")

                group_key_parts.append(f"Test={test_name}")

            group_key = tuple(sorted(group_key_parts))

            grouped_data.setdefault(group_key, []).append({
                'x_value': x_value,
                'baseline': float(row[-3]),
                'compare': float(row[-2]),
                'speedup': float(row[-1]),
            })

        if not grouped_data:
            logger.error("No data available for plotting")
            return

        def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
            """Create a subplot grid with room for ``num_groups`` plots.

            Returns the figure and ALL axes of the grid as a flat array; the grid
            may hold more axes than ``num_groups`` when the last row is not full.
            """
            from math import ceil
            cols = 1 if num_groups == 1 else min(max_cols, num_groups)
            rows = ceil(num_groups / cols)

            # scale figure size by grid dimensions
            w, h = base_size
            fig, ax_arr = plt.subplots(rows, cols,
                                       figsize=(w * cols, h * rows),
                                       squeeze=False)

            return fig, ax_arr.flatten()

        num_groups = len(grouped_data)
        fig, axes = make_axes(num_groups)

        for ax, (group_key, points) in zip(axes, grouped_data.items()):
            try:
                keyed_points = [
                    (0.0 if p['x_value'] is None else float(p['x_value']), p)
                    for p in points
                ]
            except (TypeError, ValueError):
                # Non-numeric x values (e.g. "pp512"): keep the table's row order
                # and let matplotlib treat them as categories.
                numeric_x = False
                points_sorted = list(points)
                x_values = [str(p['x_value']) for p in points_sorted]
            else:
                numeric_x = True
                keyed_points.sort(key=lambda item: item[0])
                x_values = [x for x, _ in keyed_points]
                points_sorted = [p for _, p in keyed_points]

            baseline_vals = [p['baseline'] for p in points_sorted]
            compare_vals = [p['compare'] for p in points_sorted]

            ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
                    label=f'{baseline_name}', linewidth=2, markersize=6)
            ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
                    label=f'{compare_name}', linewidth=2, markersize=6)

            # Use a log2 axis for widely spread depths, but only when every x is
            # positive: a log axis cannot place a point at 0.
            if (plot_x_param == "n_depth" and numeric_x
                    and min(x_values) > 0 and max(x_values) > min(x_values) * 4):
                ax.set_xscale('log', base=2)
                unique_x = sorted(set(x_values))
                ax.set_xticks(unique_x)
                ax.set_xticklabels([str(int(x)) for x in unique_x])

            title_parts = []
            for part in group_key:
                if '=' in part:
                    key, value = part.split('=', 1)
                    title_parts.append(f"{key}: {value}")

            title = ', '.join(title_parts) if title_parts else "Performance Comparison"

            ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
            ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.legend(loc='best', fontsize=10)
            ax.grid(True, alpha=0.3)

        # Hide the grid cells that did not receive a group (odd group counts).
        for unused_ax in axes[num_groups:]:
            unused_ax.set_visible(False)

        fig.suptitle(f'Performance Comparison: {compare_name} vs {baseline_name}',
                     fontsize=14, fontweight='bold')
        # NOTE(review): tight_layout() below recomputes the subplot parameters, so
        # this adjustment is likely redundant - confirm before removing.
        fig.subplots_adjust(top=1)

        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)

    create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x)
765+
603766
print(tabulate( # noqa: NP100
604767
table,
605768
headers=headers,

0 commit comments

Comments
 (0)