compare llama-bench: add option to plot #14169

Merged: 7 commits, Jun 14, 2025
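With this change, a typical invocation could look like the following sketch (the existing arguments for selecting the input database and the baseline/compare commits are unchanged by this PR and omitted here):

python3 scripts/compare-llama-bench.py --plot plot.png --plot_x n_depth --plot_log_scale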
Changes from all commits
1 change: 1 addition & 0 deletions requirements/requirements-compare-llama-bench.txt
@@ -1,2 +1,3 @@
tabulate~=0.9.0
GitPython~=3.1.43
matplotlib~=3.10.0
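To pull in the new matplotlib dependency, installing from the updated requirements file should be enough:

pip install -r requirements/requirements-compare-llama-bench.txt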
169 changes: 168 additions & 1 deletion scripts/compare-llama-bench.py
@@ -19,6 +19,7 @@
print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100
raise e


logger = logging.getLogger("compare-llama-bench")

# All llama-bench SQL fields
@@ -122,11 +123,15 @@
parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
parser.add_argument("-s", "--show", help=help_s)
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")

known_args, unknown_args = parser.parse_known_args()

logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)


if known_args.check:
# Check if all required Python libraries are installed. Would have failed earlier if not.
sys.exit(0)
@@ -499,7 +504,6 @@ def valid_format(data_files: list[str]) -> bool:

name_compare = bench_data.get_commit_name(hexsha8_compare)


# If the user provided columns to group the results by, use them:
if known_args.show is not None:
show = known_args.show.split(",")
@@ -544,6 +548,14 @@ def valid_format(data_files: list[str]) -> bool:
show.remove(prop)
except ValueError:
pass

# Add plot_x parameter to parameters to show if it's not already present:
if known_args.plot:
for k, v in PRETTY_NAMES.items():
if v == known_args.plot_x and k not in show:
show.append(k)
break

rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)

if not rows_show:
@@ -600,6 +612,161 @@ def valid_format(data_files: list[str]) -> bool:
headers = [PRETTY_NAMES[p] for p in show]
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]

if known_args.plot:
def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False):
try:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
except ImportError as e:
logger.error("matplotlib is required for --plot.")
raise e

data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
plot_x_index = None
plot_x_label = plot_x_param

if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param)
if pretty_name in data_headers:
plot_x_index = data_headers.index(pretty_name)
plot_x_label = pretty_name
elif plot_x_param in data_headers:
plot_x_index = data_headers.index(plot_x_param)
plot_x_label = plot_x_param
else:
logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
return

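        # Group rows by every parameter except the x-axis value so each subplot compares like-for-like configurations.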
grouped_data = {}

for i, row in enumerate(table_data):
group_key_parts = []
test_name = row[-4]

base_test = ""
x_value = None

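            # The built-in sweep parameters are encoded in the llama-bench test name (e.g. "pp512", "tg128", "pp512@d4096"),
            # so the x value is parsed out of the name; any other plot_x parameter is read from its table column instead.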
if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
for j, val in enumerate(row[:-4]):
header_name = data_headers[j]
if val is not None and str(val).strip():
group_key_parts.append(f"{header_name}={val}")

if plot_x_param == "n_prompt" and "pp" in test_name:
base_test = test_name.split("@")[0]
x_value = base_test
elif plot_x_param == "n_gen" and "tg" in test_name:
x_value = test_name.split("@")[0]
elif plot_x_param == "n_depth" and "@d" in test_name:
base_test = test_name.split("@d")[0]
x_value = int(test_name.split("@d")[1])
else:
base_test = test_name

if base_test.strip():
group_key_parts.append(f"Test={base_test}")
else:
for j, val in enumerate(row[:-4]):
if j != plot_x_index:
header_name = data_headers[j]
if val is not None and str(val).strip():
group_key_parts.append(f"{header_name}={val}")
else:
x_value = val

group_key_parts.append(f"Test={test_name}")

group_key = tuple(group_key_parts)

if group_key not in grouped_data:
grouped_data[group_key] = []

grouped_data[group_key].append({
'x_value': x_value,
'baseline': float(row[-3]),
'compare': float(row[-2]),
'speedup': float(row[-1])
})

if not grouped_data:
logger.error("No data available for plotting")
return

def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
from math import ceil
cols = 1 if num_groups == 1 else min(max_cols, num_groups)
rows = ceil(num_groups / cols)

# Scale figure size by grid dimensions
w, h = base_size
fig, ax_arr = plt.subplots(rows, cols,
figsize=(w * cols, h * rows),
squeeze=False)

axes = ax_arr.flatten()[:num_groups]
return fig, axes

num_groups = len(grouped_data)
fig, axes = make_axes(num_groups)

plot_idx = 0

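        # Draw one subplot per group, with the baseline and compare runs as separate lines.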
for group_key, points in grouped_data.items():
if plot_idx >= len(axes):
break
ax = axes[plot_idx]

try:
points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
except ValueError:
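                # Non-numeric x values: the sort key is constant, so the stable sort keeps the original row order.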
points_sorted = sorted(points, key=lambda p: group_key)
x_values = [p['x_value'] for p in points_sorted]

baseline_vals = [p['baseline'] for p in points_sorted]
compare_vals = [p['compare'] for p in points_sorted]

ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
label=f'{baseline_name}', linewidth=2, markersize=6)
ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
label=f'{compare_name}', linewidth=2, markersize=6)

if log_scale:
ax.set_xscale('log', base=2)
unique_x = sorted(set(x_values))
ax.set_xticks(unique_x)
ax.set_xticklabels([str(int(x)) for x in unique_x])

title_parts = []
for part in group_key:
if '=' in part:
key, value = part.split('=', 1)
title_parts.append(f"{key}: {value}")

title = ', '.join(title_parts) if title_parts else "Performance comparison"

ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold')
ax.set_title(title, fontsize=12, fontweight='bold')
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3)

plot_idx += 1

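        # Hide any unused axes left over in the subplot grid.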
for i in range(plot_idx, len(axes)):
axes[i].set_visible(False)

fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
fontsize=14, fontweight='bold')
fig.subplots_adjust(top=1)

plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches='tight')
plt.close()

create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale)

print(tabulate( # noqa: NP100
table,
headers=headers,