Skip to content

Commit 4f56458

Browse files
Python script to compare commits with llama-bench (#4844)
1 parent 6efb8eb commit 4f56458

File tree

1 file changed

+356
-0
lines changed

1 file changed

+356
-0
lines changed

scripts/compare-llama-bench.py

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
#!/usr/bin/env python3
2+
3+
import argparse
4+
import heapq
5+
import sys
6+
import os
7+
from glob import glob
8+
import sqlite3
9+
10+
try:
11+
import git
12+
from tabulate import tabulate
13+
except ImportError:
14+
print("ERROR: the following Python libraries are required: GitPython, tabulate.")
15+
sys.exit(1)
16+
17+
# Properties by which to differentiate results per commit:
18+
KEY_PROPERTIES = [
19+
"cuda", "opencl", "metal", "gpu_blas", "blas", "cpu_info", "gpu_info", "model_filename",
20+
"model_type", "model_size", "model_n_params", "n_batch", "n_threads", "type_k", "type_v",
21+
"n_gpu_layers", "main_gpu", "no_kv_offload", "mul_mat_q", "tensor_split", "n_prompt", "n_gen"
22+
]
23+
24+
# Properties that are boolean and are converted to Yes/No for the table:
25+
BOOL_PROPERTIES = ["cuda", "opencl", "metal", "gpu_blas", "blas"]
26+
27+
# Header names for the table:
28+
PRETTY_NAMES = {
29+
"cuda": "CUDA", "opencl": "OpenCL", "metal": "Metal", "gpu_blas": "GPU BLAS", "blas": "BLAS",
30+
"cpu_info": "CPU", "gpu_info": "GPU", "model_filename": "File", "model_type": "Model",
31+
"model_size": "Model Size [GiB]", "model_n_params": "Num. of Parameters",
32+
"n_batch": "Batch size", "n_threads": "Threads", "type_k": "K type", "type_v": "V type",
33+
"n_gpu_layers": "GPU layers", "main_gpu": "Main GPU", "no_kv_offload": "NKVO",
34+
"mul_mat_q": "MMQ", "tensor_split": "Tensor split"
35+
}
36+
37+
DEFAULT_SHOW = ["model_type"] # Always show these properties by default.
38+
DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default.
39+
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables.
40+
41+
DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux):
42+
43+
$ git checkout master
44+
$ make clean && make llama-bench
45+
$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
46+
$ git checkout some_branch
47+
$ make clean && make llama-bench
48+
$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
49+
$ ./scripts/compare-llama-bench.py
50+
51+
Performance numbers from multiple runs per commit are averaged WITHOUT being weighted by the --repetitions parameter of llama-bench.
52+
"""
53+
54+
parser = argparse.ArgumentParser(
55+
description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter)
56+
help_b = (
57+
"The baseline commit to compare performance to. "
58+
"Accepts either a branch name, tag name, or commit hash. "
59+
"Defaults to latest master commit with data."
60+
)
61+
parser.add_argument("-b", "--baseline", help=help_b)
62+
help_c = (
63+
"The commit whose performance is to be compared to the baseline. "
64+
"Accepts either a branch name, tag name, or commit hash. "
65+
"Defaults to the non-master commit for which llama-bench was run most recently."
66+
)
67+
parser.add_argument("-c", "--compare", help=help_c)
68+
help_i = (
69+
"Input SQLite file for comparing commits. "
70+
"Defaults to 'llama-bench.sqlite' in the current working directory. "
71+
"If no such file is found and there is exactly one .sqlite file in the current directory, "
72+
"that file is instead used as input."
73+
)
74+
parser.add_argument("-i", "--input", help=help_i)
75+
help_o = (
76+
"Output format for the table. "
77+
"Defaults to 'pipe' (GitHub compatible). "
78+
"Also supports e.g. 'latex' or 'mediawiki'. "
79+
"See tabulate documentation for full list."
80+
)
81+
parser.add_argument("-o", "--output", help=help_o, default="pipe")
82+
help_s = (
83+
"Columns to add to the table. "
84+
"Accepts a comma-separated list of values. "
85+
f"Legal values: {', '.join(KEY_PROPERTIES[:-2])}. "
86+
"Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) "
87+
"plus any column where not all data points are the same. "
88+
"If the columns are manually specified, then the results for each unique combination of the "
89+
"specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench."
90+
)
91+
parser.add_argument("-s", "--show", help=help_s)
92+
93+
known_args, unknown_args = parser.parse_known_args()
94+
95+
if unknown_args:
96+
print(f"ERROR: Received unknown args: {unknown_args}.")
97+
print()
98+
parser.print_help()
99+
sys.exit(1)
100+
101+
input_file = known_args.input
102+
if input_file is None and os.path.exists("./llama-bench.sqlite"):
103+
input_file = "llama-bench.sqlite"
104+
if input_file is None:
105+
sqlite_files = glob("*.sqlite")
106+
if len(sqlite_files) == 1:
107+
input_file = sqlite_files[0]
108+
109+
if input_file is None:
110+
print("ERROR: Cannot find a suitable input file, please provide one.")
111+
print()
112+
parser.print_help()
113+
sys.exit(1)
114+
115+
connection = sqlite3.connect(input_file)
116+
cursor = connection.cursor()
117+
builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
118+
119+
try:
120+
repo = git.Repo(".", search_parent_directories=True)
121+
except git.exc.InvalidGitRepositoryError:
122+
repo = None
123+
124+
125+
def find_parent_in_data(commit):
126+
"""Helper function to find the most recent parent measured in number of commits for which there is data."""
127+
heap = [(0, commit)]
128+
seen_hexsha8 = set()
129+
while heap:
130+
depth, current_commit = heapq.heappop(heap)
131+
current_hexsha8 = commit.hexsha[:8]
132+
if (current_hexsha8,) in builds:
133+
return current_hexsha8
134+
for parent in commit.parents:
135+
parent_hexsha8 = parent.hexsha[:8]
136+
if parent_hexsha8 not in seen_hexsha8:
137+
seen_hexsha8.add(parent_hexsha8)
138+
heapq.heappush(heap, (depth + 1, parent))
139+
return None
140+
141+
142+
def get_all_parent_hexsha8s(commit):
143+
"""Helper function to recursively get hexsha8 values for all parents of a commit."""
144+
unvisited = [commit]
145+
visited = []
146+
147+
while unvisited:
148+
current_commit = unvisited.pop(0)
149+
visited.append(current_commit.hexsha[:8])
150+
for parent in current_commit.parents:
151+
if parent.hexsha[:8] not in visited:
152+
unvisited.append(parent)
153+
154+
return visited
155+
156+
157+
def get_commit_name(hexsha8):
158+
"""Helper function to find a human-readable name for a commit if possible."""
159+
if repo is None:
160+
return hexsha8
161+
for h in repo.heads:
162+
if h.commit.hexsha[:8] == hexsha8:
163+
return h.name
164+
for t in repo.tags:
165+
if t.commit.hexsha[:8] == hexsha8:
166+
return t.name
167+
return hexsha8
168+
169+
170+
def get_commit_hexsha8(name):
171+
"""Helper function to search for a commit given a human-readable name."""
172+
if repo is None:
173+
return None
174+
for h in repo.heads:
175+
if h.name == name:
176+
return h.commit.hexsha[:8]
177+
for t in repo.tags:
178+
if t.name == name:
179+
return t.commit.hexsha[:8]
180+
return None
181+
182+
183+
hexsha8_baseline = name_baseline = None
184+
185+
# If the user specified a baseline, try to find a commit for it:
186+
if known_args.baseline is not None:
187+
if (known_args.baseline,) in builds:
188+
hexsha8_baseline = known_args.baseline
189+
if hexsha8_baseline is None:
190+
hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
191+
name_baseline = known_args.baseline
192+
if hexsha8_baseline is None:
193+
print(f"ERROR: cannot find data for baseline={known_args.baseline}.")
194+
sys.exit(1)
195+
# Otherwise, search for the most recent parent of master for which there is data:
196+
elif repo is not None:
197+
hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)
198+
199+
if hexsha8_baseline is None:
200+
print("ERROR: No baseline was provided and did not find data for any master branch commits.")
201+
print()
202+
parser.print_help()
203+
sys.exit(1)
204+
else:
205+
print(
206+
"ERROR: No baseline was provided and the current working directory "
207+
"is not part of a git repository from which a baseline could be inferred."
208+
)
209+
print()
210+
parser.print_help()
211+
sys.exit(1)
212+
213+
214+
name_baseline = get_commit_name(hexsha8_baseline)
215+
216+
hexsha8_compare = name_compare = None
217+
218+
# If the user has specified a compare value, try to find a corresponding commit:
219+
if known_args.compare is not None:
220+
if (known_args.compare,) in builds:
221+
hexsha8_compare = known_args.compare
222+
if hexsha8_compare is None:
223+
hexsha8_compare = get_commit_hexsha8(known_args.compare)
224+
name_compare = known_args.compare
225+
if hexsha8_compare is None:
226+
print(f"ERROR: cannot find data for baseline={known_args.compare}.")
227+
sys.exit(1)
228+
# Otherwise, search for the commit for llama-bench was most recently run
229+
# and that is not a parent of master:
230+
elif repo is not None:
231+
hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit)
232+
builds_timestamp = cursor.execute(
233+
"SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
234+
for (hexsha8, _) in reversed(builds_timestamp):
235+
if hexsha8 not in hexsha8s_master:
236+
hexsha8_compare = hexsha8
237+
break
238+
239+
if hexsha8_compare is None:
240+
print("ERROR: No compare target was provided and did not find data for any non-master commits.")
241+
print()
242+
parser.print_help()
243+
sys.exit(1)
244+
else:
245+
print(
246+
"ERROR: No compare target was provided and the current working directory "
247+
"is not part of a git repository from which a compare target could be inferred."
248+
)
249+
print()
250+
parser.print_help()
251+
sys.exit(1)
252+
253+
name_compare = get_commit_name(hexsha8_compare)
254+
255+
256+
def get_rows(properties):
257+
"""
258+
Helper function that gets table rows for some list of properties.
259+
Rows are created by combining those where all provided properties are equal.
260+
The resulting rows are then grouped by the provided properties and the t/s values are averaged.
261+
The returned rows are unique in terms of property combinations.
262+
"""
263+
select_string = ", ".join(
264+
[f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
265+
equal_string = " AND ".join(
266+
[f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
267+
f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
268+
)
269+
group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt"])
270+
query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
271+
f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
272+
return cursor.execute(query).fetchall()
273+
274+
275+
# If the user provided columns to group the results by, use them:
276+
if known_args.show is not None:
277+
show = known_args.show.split(",")
278+
unknown_cols = []
279+
for prop in show:
280+
if prop not in KEY_PROPERTIES[:-2]: # Last two values are n_prompt, n_gen.
281+
unknown_cols.append(prop)
282+
if unknown_cols:
283+
print(f"ERROR: Unknown values for --show: {', '.join(unknown_cols)}")
284+
print()
285+
parser.print_usage()
286+
sys.exit(1)
287+
rows_show = get_rows(show)
288+
# Otherwise, select those columns where the values are not all the same:
289+
else:
290+
rows_full = get_rows(KEY_PROPERTIES)
291+
properties_different = []
292+
for i, kp_i in enumerate(KEY_PROPERTIES):
293+
if kp_i in DEFAULT_SHOW or kp_i == "n_prompt" or kp_i == "n_gen":
294+
continue
295+
for row_full in rows_full:
296+
if row_full[i] != rows_full[0][i]:
297+
properties_different.append(kp_i)
298+
break
299+
300+
show = []
301+
# Show CPU and/or GPU by default even if the hardware for all results is the same:
302+
if "gpu_blas" not in properties_different and "n_gpu_layers" not in properties_different:
303+
gpu_blas = bool(rows_full[0][KEY_PROPERTIES.index("gpu_blas")])
304+
ngl = int(rows_full[0][KEY_PROPERTIES.index("n_gpu_layers")])
305+
306+
if not gpu_blas or ngl != 99 and "cpu_info" not in properties_different:
307+
show.append("cpu_info")
308+
if gpu_blas and "gpu_info" not in properties_different:
309+
show.append("gpu_info")
310+
311+
show += DEFAULT_SHOW
312+
show += properties_different
313+
for prop in DEFAULT_HIDE:
314+
try:
315+
show.remove(prop)
316+
except ValueError:
317+
pass
318+
rows_show = get_rows(show)
319+
320+
table = []
321+
for row in rows_show:
322+
n_prompt = int(row[-4])
323+
n_gen = int(row[-3])
324+
assert n_prompt == 0 or n_gen == 0
325+
test_name = f"tg{n_gen}" if n_prompt == 0 else f"pp{n_prompt}"
326+
# Regular columns test name avg t/s values Speedup
327+
# VVVVVVVVVVVVV VVVVVVVVV VVVVVVVVVVVVVV VVVVVVV
328+
table.append(list(row[:-4]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])])
329+
330+
# Some a-posteriori fixes to make the table contents prettier:
331+
for bool_property in BOOL_PROPERTIES:
332+
if bool_property in show:
333+
ip = show.index(bool_property)
334+
for row_table in table:
335+
row_table[ip] = "Yes" if int(row_table[ip]) == 1 else "No"
336+
337+
if "model_size" in show:
338+
ip = show.index("model_size")
339+
for row_table in table:
340+
row_table[ip] = float(row_table[ip]) / 1024 ** 3
341+
342+
if "gpu_info" in show:
343+
ip = show.index("gpu_info")
344+
for gns in GPU_NAME_STRIP:
345+
for row_table in table:
346+
row_table[ip] = row_table[ip].replace(gns, "")
347+
348+
headers = [PRETTY_NAMES[p] for p in show]
349+
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
350+
351+
print(tabulate(
352+
table,
353+
headers=headers,
354+
floatfmt=".2f",
355+
tablefmt=known_args.output
356+
))

0 commit comments

Comments
 (0)