Skip to content

Commit 3860555

Browse files
authored
Merge pull request #46 from commit-0/tweaks
Tweaks
2 parents 4bd2185 + ca81e37 commit 3860555

File tree

6 files changed

+81
-34
lines changed

6 files changed

+81
-34
lines changed

commit0/cli.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import commit0.harness.lint
1111
import commit0.harness.save
1212
from commit0.harness.constants import SPLIT, SPLIT_ALL
13+
from commit0.harness.utils import get_active_branch
1314
import subprocess
1415
import yaml
1516
import os
@@ -245,14 +246,13 @@ def test(
245246

246247
commit0_config = read_commit0_dot_file(commit0_dot_file_path)
247248

248-
if not branch and not reference:
249-
raise typer.BadParameter(
250-
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name.",
251-
param_hint="BRANCH",
252-
)
253249
if reference:
254250
branch = "reference"
255-
assert branch is not None, "branch is not specified"
251+
if branch is None and not reference:
252+
git_path = os.path.join(
253+
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
254+
)
255+
branch = get_active_branch(git_path)
256256

257257
if verbose == 2:
258258
typer.echo(f"Running tests for repository: {repo_or_repo_path}")
@@ -264,7 +264,7 @@ def test(
264264
commit0_config["dataset_split"],
265265
commit0_config["base_dir"],
266266
repo_or_repo_path,
267-
branch,
267+
branch, # type: ignore
268268
test_ids,
269269
backend,
270270
timeout,
@@ -294,14 +294,8 @@ def evaluate(
294294
) -> None:
295295
"""Evaluate Commit0 split you choose in Setup Stage."""
296296
check_commit0_path()
297-
if not branch and not reference:
298-
raise typer.BadParameter(
299-
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name",
300-
param_hint="BRANCH",
301-
)
302297
if reference:
303298
branch = "reference"
304-
assert branch is not None, "branch is not specified"
305299

306300
commit0_config = read_commit0_dot_file(commit0_dot_file_path)
307301
check_valid(commit0_config["repo_split"], SPLIT)

commit0/harness/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class Files(TypedDict):
1616
patch: Dict[str, Path]
1717

1818

19+
BASE_BRANCH = "commit0"
20+
1921
# Constants - Evaluation Log Directories
2022
BASE_IMAGE_BUILD_DIR = Path("logs/build_images/base")
2123
REPO_IMAGE_BUILD_DIR = Path("logs/build_images/repo")

commit0/harness/evaluate.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from concurrent.futures import ThreadPoolExecutor, as_completed
66
from datasets import load_dataset
77
from tqdm import tqdm
8-
from typing import Iterator
8+
from typing import Iterator, Union
99

1010
from commit0.harness.run_pytest_ids import main as run_tests
1111
from commit0.harness.get_pytest_ids import main as get_tests
1212
from commit0.harness.constants import RepoInstance, SPLIT, RUN_PYTEST_LOG_DIR
13-
from commit0.harness.utils import get_hash_string
13+
from commit0.harness.utils import get_hash_string, get_active_branch
1414

1515
logging.basicConfig(
1616
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -23,7 +23,7 @@ def main(
2323
dataset_split: str,
2424
repo_split: str,
2525
base_dir: str,
26-
branch: str,
26+
branch: Union[str, None],
2727
backend: str,
2828
timeout: int,
2929
num_cpus: int,
@@ -32,16 +32,19 @@ def main(
3232
) -> None:
3333
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
3434
repos = SPLIT[repo_split]
35-
pairs = []
35+
triples = []
3636
log_dirs = []
3737
for example in dataset:
3838
repo_name = example["repo"].split("/")[-1]
3939
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
4040
continue
41-
pairs.append((repo_name, example["test"]["test_dir"]))
4241
hashed_test_ids = get_hash_string(example["test"]["test_dir"])
42+
if branch is None:
43+
git_path = os.path.join(base_dir, repo_name)
44+
branch = get_active_branch(git_path)
4345
log_dir = RUN_PYTEST_LOG_DIR / repo_name / branch / hashed_test_ids
4446
log_dirs.append(str(log_dir))
47+
triples.append((repo_name, example["test"]["test_dir"], branch))
4548

4649
with tqdm(total=len(repos), smoothing=0, desc="Evaluating repos") as pbar:
4750
with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -61,7 +64,7 @@ def main(
6164
rebuild_image=rebuild_image,
6265
verbose=0,
6366
): None
64-
for repo, test_dir in pairs
67+
for repo, test_dir, branch in triples
6568
}
6669
# Wait for each future to complete
6770
for future in as_completed(futures):

commit0/harness/run_pytest_ids.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,30 @@ def main(
8282
)
8383
except Exception as e:
8484
raise e
85+
commit_id = ""
8586
if branch == "reference":
8687
commit_id = example["reference_commit"]
8788
else:
88-
try:
89-
local_repo.git.checkout(branch)
90-
local_branch = local_repo.branches[branch]
91-
commit_id = local_branch.commit.hexsha
92-
except Exception as e:
93-
raise Exception(f"Problem checking out branch {branch}.\n{e}")
89+
# Check if it's a local branch
90+
if branch in local_repo.branches:
91+
commit_id = local_repo.commit(branch).hexsha
92+
else:
93+
found_remote_branch = False
94+
for remote in local_repo.remotes:
95+
remote.fetch() # Fetch latest updates from each remote
96+
97+
# Check if the branch exists in this remote
98+
for ref in remote.refs:
99+
if (
100+
ref.remote_head == branch
101+
): # Compare branch name without remote prefix
102+
commit_id = local_repo.commit(ref.name).hexsha
103+
found_remote_branch = True
104+
break # Branch found, no need to keep checking this remote
105+
if found_remote_branch:
106+
break # Stop checking other remotes if branch is found
107+
if not found_remote_branch:
108+
raise Exception(f"Branch {branch} does not exist locally or remotely.")
94109
patch = generate_patch_between_commits(
95110
local_repo, example["base_commit"], commit_id
96111
)

commit0/harness/setup.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from commit0.harness.utils import (
88
clone_repo,
99
)
10-
from commit0.harness.constants import RepoInstance, SPLIT
10+
from commit0.harness.constants import BASE_BRANCH, RepoInstance, SPLIT
1111

1212

1313
logging.basicConfig(
@@ -29,7 +29,12 @@ def main(
2929
continue
3030
clone_url = f"https://github.com/{example['repo']}.git"
3131
clone_dir = os.path.abspath(os.path.join(base_dir, repo_name))
32-
clone_repo(clone_url, clone_dir, example["base_commit"], logger)
32+
branch = dataset_name.split("/")[-1]
33+
repo = clone_repo(clone_url, clone_dir, branch, logger)
34+
if BASE_BRANCH in repo.branches:
35+
repo.git.branch("-d", BASE_BRANCH)
36+
repo.git.checkout("-b", BASE_BRANCH)
37+
logger.info("Checked out the base commit: commit 0")
3338

3439

3540
__all__ = []

commit0/harness/utils.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import time
77
import sys
88
from pathlib import Path
9-
from typing import Optional
9+
from typing import Optional, Union
1010

1111
from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore
1212
from ghapi.core import GhApi
@@ -85,7 +85,7 @@ def extract_test_output(ss: str, pattern: str) -> str:
8585

8686

8787
def clone_repo(
88-
clone_url: str, clone_dir: str, commit: str, logger: logging.Logger
88+
clone_url: str, clone_dir: str, branch: str, logger: logging.Logger
8989
) -> git.Repo:
9090
"""Clone repo into the specified directory if it does not already exist.
9191
@@ -98,8 +98,8 @@ def clone_repo(
9898
URL of the repository to clone.
9999
clone_dir : str
100100
Directory where the repository will be cloned.
101-
commit : str
102-
The commit hash or branch/tag name to checkout.
101+
branch : str
102+
The branch/tag name to checkout.
103103
logger : logging.Logger
104104
The logger object.
105105
@@ -129,11 +129,10 @@ def clone_repo(
129129
except git.exc.GitCommandError as e:
130130
raise RuntimeError(f"Failed to clone repository: {e}")
131131

132-
logger.info(f"Checking out {commit}")
133132
try:
134-
repo.git.checkout(commit)
133+
repo.git.checkout(branch)
135134
except git.exc.GitCommandError as e:
136-
raise RuntimeError(f"Failed to check out {commit}: {e}")
135+
raise RuntimeError(f"Failed to check out {branch}: {e}")
137136

138137
return repo
139138

@@ -190,4 +189,33 @@ def generate_patch_between_commits(
190189
raise Exception(f"Error generating patch: {e}")
191190

192191

192+
def get_active_branch(repo_path: Union[str, Path]) -> str:
193+
"""Retrieve the current active branch of a Git repository.
194+
195+
Args:
196+
----
197+
repo_path (Path): The path to git repo.
198+
199+
Returns:
200+
-------
201+
str: The name of the active branch.
202+
203+
Raises:
204+
------
205+
Exception: If the repository is in a detached HEAD state.
206+
207+
"""
208+
repo = git.Repo(repo_path)
209+
try:
210+
# Get the current active branch
211+
branch = repo.active_branch.name
212+
except TypeError as e:
213+
raise Exception(
214+
f"{e}\nThis means the repository is in a detached HEAD state. "
215+
"To proceed, please specify a valid branch by using --branch {branch}."
216+
)
217+
218+
return branch
219+
220+
193221
__all__ = []

0 commit comments

Comments
 (0)