Skip to content

scripts: fix compare-llama-bench commit hash logic #11891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
45 changes: 29 additions & 16 deletions scripts/compare-llama-bench.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -124,9 +124,22 @@

connection = sqlite3.connect(input_file)
cursor = connection.cursor()

# The database may hold commit hashes that were stored at different abbreviation
# lengths; query the shortest and longest stored length to detect that.
build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]

if build_len_min != build_len_max:
    logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
                   "Try purging the database of old commits.")
    # Truncate every stored hash to the shortest length so that all hashes are
    # compared at the same length. SUBSTR() is used instead of SUBSTRING() because
    # the latter alias only exists in SQLite >= 3.34.0; the length is bound as a
    # parameter rather than formatted into the statement.
    cursor.execute("UPDATE test SET build_commit = SUBSTR(build_commit, 1, ?);", (build_len_min,))

# Length used below whenever a full git hexsha is shortened for comparison with the data.
build_len: int = build_len_min

builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]

commit_short_len = len(builds[0][0])
if not builds:
raise RuntimeError(f"{input_file} does not contain any builds.")

try:
repo = git.Repo(".", search_parent_directories=True)
Expand All @@ -140,11 +153,11 @@ def find_parent_in_data(commit: git.Commit):
seen_hexsha8 = set()
while heap:
depth, current_commit = heapq.heappop(heap)
current_hexsha8 = commit.hexsha[:commit_short_len]
if (current_hexsha8,) in builds:
current_hexsha8 = commit.hexsha[:build_len]
if current_hexsha8 in builds:
return current_hexsha8
for parent in commit.parents:
parent_hexsha8 = parent.hexsha[:commit_short_len]
parent_hexsha8 = parent.hexsha[:build_len]
if parent_hexsha8 not in seen_hexsha8:
seen_hexsha8.add(parent_hexsha8)
heapq.heappush(heap, (depth + 1, parent))
Expand All @@ -158,48 +171,48 @@ def get_all_parent_hexsha8s(commit: git.Commit):

while unvisited:
current_commit = unvisited.pop(0)
visited.append(current_commit.hexsha[:commit_short_len])
visited.append(current_commit.hexsha[:build_len])
for parent in current_commit.parents:
if parent.hexsha[:commit_short_len] not in visited:
if parent.hexsha[:build_len] not in visited:
unvisited.append(parent)

return visited


def get_commit_name(hexsha8):
def get_commit_name(hexsha8: str):
"""Helper function to find a human-readable name for a commit if possible."""
if repo is None:
return hexsha8
for h in repo.heads:
if h.commit.hexsha[:commit_short_len] == hexsha8:
if h.commit.hexsha[:build_len] == hexsha8:
return h.name
for t in repo.tags:
if t.commit.hexsha[:commit_short_len] == hexsha8:
if t.commit.hexsha[:build_len] == hexsha8:
return t.name
return hexsha8


def get_commit_hexsha8(name):
def get_commit_hexsha8(name: str):
"""Helper function to search for a commit given a human-readable name."""
if repo is None:
return None
for h in repo.heads:
if h.name == name:
return h.commit.hexsha[:commit_short_len]
return h.commit.hexsha[:build_len]
for t in repo.tags:
if t.name == name:
return t.commit.hexsha[:commit_short_len]
return t.commit.hexsha[:build_len]
for c in repo.iter_commits("--all"):
if c.hexsha[:commit_short_len] == name[:commit_short_len]:
return c.hexsha[:commit_short_len]
if c.hexsha[:build_len] == name[:build_len]:
return c.hexsha[:build_len]
return None


hexsha8_baseline = name_baseline = None

# If the user specified a baseline, try to find a commit for it:
if known_args.baseline is not None:
if (known_args.baseline,) in builds:
if known_args.baseline in builds:
hexsha8_baseline = known_args.baseline
if hexsha8_baseline is None:
hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
Expand Down Expand Up @@ -228,7 +241,7 @@ def get_commit_hexsha8(name):

# If the user has specified a compare value, try to find a corresponding commit:
if known_args.compare is not None:
if (known_args.compare,) in builds:
if known_args.compare in builds:
hexsha8_compare = known_args.compare
if hexsha8_compare is None:
hexsha8_compare = get_commit_hexsha8(known_args.compare)
Expand Down