Skip to content

Commit 0b60e82

Browse files
authored
Merge pull request #34 from commit-0/integration
better integration between commit0 and baseline
2 parents a4d47af + cd4c3f7 commit 0b60e82

File tree

13 files changed

+155
-65
lines changed

13 files changed

+155
-65
lines changed

.github/workflows/system.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ jobs:
3232
env:
3333
GITHUB_TOKEN: ${{ secrets.MY_GITHUB_TOKEN }}
3434
run: |
35-
uv run commit0 save simpy test-save-commit0
35+
uv run commit0 save simpy test-save-commit0 master

baselines/class_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ class AgentConfig:
2020
max_repo_info_length: int
2121
use_unit_tests_info: bool
2222
max_unit_tests_info_length: int
23-
use_reference_info: bool
24-
max_reference_info_length: int
23+
use_spec_info: bool
24+
max_spec_info_length: int
2525
use_lint_info: bool
2626
max_lint_info_length: int
2727
pre_commit_config_path: str

baselines/commit0_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import git
12
import os
23
import re
34
import subprocess
5+
from dataclasses import asdict
46
from pathlib import Path
57
from typing import List
68

@@ -178,3 +180,65 @@ def get_reference(specification_pdf_path: str) -> str:
178180
"""Get the reference for a given specification PDF path."""
179181
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
180182
return f"/pdf {specification_pdf_path}"
183+
184+
185+
def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None:
186+
"""Create a new branch or switch to an existing branch.
187+
188+
Parameters
189+
----------
190+
repo : git.Repo
191+
The repository object.
192+
branch : str
193+
The name of the branch to create or switch to.
194+
from_commit : str
195+
The commit from which to create the branch.
196+
197+
Returns
198+
-------
199+
None
200+
201+
Raises
202+
------
203+
RuntimeError
204+
If creating or switching to the branch fails.
205+
206+
"""
207+
try:
208+
# Check if the branch already exists
209+
if branch in repo.heads:
210+
repo.git.checkout(branch)
211+
else:
212+
repo.git.checkout(from_commit)
213+
repo.git.checkout("-b", branch)
214+
except git.exc.GitCommandError as e: # type: ignore
215+
raise RuntimeError(f"Failed to create or switch to branch '{branch}': {e}")
216+
217+
218+
def args2string(agent_config: AgentConfig) -> str:
219+
"""Converts specific fields from an `AgentConfig` object into a formatted string.
220+
221+
Args:
222+
----
223+
agent_config (AgentConfig): A dataclass object containing configuration
224+
options for an agent.
225+
226+
Returns:
227+
-------
228+
str: A string representing the selected key-value pairs from the `AgentConfig`
229+
object, joined by double underscores.
230+
231+
"""
232+
arg_dict = asdict(agent_config)
233+
result_list = []
234+
keys_to_collect = ["model_name", "run_tests", "use_lint_info", "use_spec_info"]
235+
for key in keys_to_collect:
236+
value = arg_dict[key]
237+
if isinstance(value, bool):
238+
if value:
239+
value = "1"
240+
else:
241+
value = "0"
242+
result_list.append(f"{key}-{value}")
243+
concatenated_string = "__".join(result_list)
244+
return concatenated_string

baselines/configs/agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ agent_config:
1010
use_user_prompt: false
1111
use_repo_info: false
1212
use_unit_tests_info: false
13-
use_reference_info: false
13+
use_spec_info: false
1414
use_lint_info: false
1515
pre_commit_config_path: .pre-commit-config.yaml
1616
run_tests: false

baselines/configs/base.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ agent_config:
1717
user_prompt: "Here is your task:\nYou need to implement all functions with 'NotImplementedError('IMPLEMENT ME HERE')' and pass the unit tests.\nDo not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.\nWhen you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code."
1818
use_repo_info: false
1919
use_unit_tests_info: false
20-
use_reference_info: false
20+
use_spec_info: false
2121
use_lint_info: false
2222
pre_commit_config_path: .pre-commit-config.yaml
2323
run_tests: True
2424
max_repo_info_length: 10000
2525
max_unit_tests_info_length: 10000
26-
max_reference_info_length: 10000
26+
max_spec_info_length: 10000
2727
max_lint_info_length: 10000
2828
max_iteration: 3
2929

baselines/run_agent.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import os
22
import sys
33
import hydra
4-
from datasets import load_dataset
54
import traceback
5+
from datasets import load_dataset
6+
from git import Repo
67
from baselines.commit0_utils import (
8+
args2string,
9+
create_branch,
710
get_message,
811
get_target_edit_files,
912
)
@@ -54,6 +57,12 @@ def run_agent_for_repo(
5457

5558
repo_path = os.path.join(commit0_config.base_dir, repo_name)
5659
repo_path = os.path.abspath(repo_path)
60+
try:
61+
local_repo = Repo(repo_path)
62+
except Exception:
63+
raise Exception(
64+
f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
65+
)
5766

5867
target_edit_files = get_target_edit_files(repo_path)
5968

@@ -64,6 +73,16 @@ def run_agent_for_repo(
6473
f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
6574
)
6675

76+
run_id = args2string(agent_config)
77+
print(f"Agent is coding on branch: {run_id}", file=sys.stderr)
78+
create_branch(local_repo, run_id, example["base_commit"])
79+
latest_commit = local_repo.commit(run_id)
80+
# in cases where the latest commit of branch is not commit 0
81+
# set it back to commit 0
82+
# TODO: ask user for permission
83+
if latest_commit.hexsha != example["base_commit"]:
84+
local_repo.git.reset("--hard", example["base_commit"])
85+
6786
with DirContext(repo_path):
6887
if commit0_config is None or agent_config is None:
6988
raise ValueError("Invalid input")
@@ -78,7 +97,7 @@ def run_agent_for_repo(
7897
if agent_config.run_tests:
7998
# when unit test feedback is available, iterate over test files
8099
for test_file in test_files:
81-
test_cmd = f"python -m commit0 test {repo_path} {test_file}"
100+
test_cmd = f"python -m commit0 test {repo_path} {run_id} {test_file}"
82101
test_file_name = test_file.replace(".py", "").replace("/", "__")
83102
log_dir = RUN_AIDER_LOG_DIR / "with_tests" / test_file_name
84103

commit0/__main__.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,21 @@ def main() -> None:
4646
config.base_dir = os.path.abspath(config.base_dir)
4747

4848
if command == "clone":
49+
if len(sys.argv) != 3:
50+
raise ValueError(
51+
"You provided an incorrect number of arguments.\nUsage: commit0 clone {repo_split}"
52+
)
4953
commit0.harness.setup.main(
5054
config.dataset_name,
5155
config.dataset_split,
5256
config.repo_split,
5357
config.base_dir,
54-
config.branch,
5558
)
5659
elif command == "build":
60+
if len(sys.argv) != 3:
61+
raise ValueError(
62+
"You provided an incorrect number of arguments.\nUsage: commit0 build {repo_split}"
63+
)
5764
commit0.harness.build.main(
5865
config.dataset_name,
5966
config.dataset_split,
@@ -62,20 +69,37 @@ def main() -> None:
6269
config.backend,
6370
)
6471
elif command == "get-tests":
72+
if len(sys.argv) != 3:
73+
raise ValueError(
74+
"You provided an incorrect number of arguments.\nUsage: commit0 get-tests {repo_name}"
75+
)
6576
repo = sys.argv[2]
6677
commit0.harness.get_pytest_ids.main(repo, stdout=True)
6778
elif command == "test" or command == "test-reference":
6879
# this command assume execution in arbitrary working directory
6980
repo_or_repo_path = sys.argv[2]
70-
test_ids = sys.argv[3]
7181
if command == "test-reference":
72-
config.branch = "reference"
82+
if len(sys.argv) != 4:
83+
raise ValueError(
84+
"You provided an incorrect number of arguments.\nUsage: commit0 test-reference {repo_dir} {test_ids}"
85+
)
86+
branch = "reference"
87+
test_ids = sys.argv[3]
88+
else:
89+
if len(sys.argv) != 5:
90+
raise ValueError(
91+
"You provided an incorrect number of arguments.\nUsage: commit0 test {repo_dir} {branch} {test_ids}"
92+
)
93+
branch = sys.argv[3]
94+
test_ids = sys.argv[4]
95+
if branch.startswith("branch="):
96+
branch = branch[len("branch=") :]
7397
commit0.harness.run_pytest_ids.main(
7498
config.dataset_name,
7599
config.dataset_split,
76100
config.base_dir,
77101
repo_or_repo_path,
78-
config.branch,
102+
branch,
79103
test_ids,
80104
config.backend,
81105
config.timeout,
@@ -84,27 +108,46 @@ def main() -> None:
84108
)
85109
elif command == "evaluate" or command == "evaluate-reference":
86110
if command == "evaluate-reference":
87-
config.branch = "reference"
111+
if len(sys.argv) != 3:
112+
raise ValueError(
113+
"You provided an incorrect number of arguments.\nUsage: commit0 evaluate-reference {repo_split}"
114+
)
115+
branch = "reference"
116+
else:
117+
if len(sys.argv) != 4:
118+
raise ValueError(
119+
"You provided an incorrect number of arguments.\nUsage: commit0 evaluate {repo_split} {branch}"
120+
)
121+
branch = sys.argv[3]
122+
if branch.startswith("branch="):
123+
branch = branch[len("branch=") :]
88124
commit0.harness.evaluate.main(
89125
config.dataset_name,
90126
config.dataset_split,
91127
config.repo_split,
92128
config.base_dir,
93-
config.branch,
129+
branch,
94130
config.backend,
95131
config.timeout,
96132
config.num_cpus,
97133
config.num_workers,
98134
)
99135
elif command == "save":
100-
organization = sys.argv[3]
136+
if len(sys.argv) != 5:
137+
raise ValueError(
138+
"You provided an incorrect number of arguments.\nUsage: commit0 save {repo_split} {owner} {branch}"
139+
)
140+
owner = sys.argv[3]
141+
branch = sys.argv[4]
142+
if branch.startswith("branch="):
143+
branch = branch[len("branch=") :]
101144
commit0.harness.save.main(
102145
config.dataset_name,
103146
config.dataset_split,
104147
config.repo_split,
105148
config.base_dir,
106-
organization,
107-
config.branch,
149+
owner,
150+
branch,
108151
config.github_token,
109152
)
110153

commit0/configs/base.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ num_workers: 8
1414

1515
# test related
1616
backend: local
17-
branch: ai
1817
timeout: 1_800
1918
num_cpus: 1
2019

commit0/configs/config_class.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ class Commit0Config:
1818

1919
# test related
2020
backend: str
21-
# which branch to work on
22-
branch: str
2321
# timeout for running pytest
2422
timeout: int
2523
num_cpus: int

commit0/harness/run_pytest_ids.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,15 @@ def main(
7777
local_repo = git.Repo(repo_dir)
7878
except git.exc.NoSuchPathError: # type: ignore
7979
raise Exception(
80-
f"{repo_dir} and {repo_or_repo_dir} are not git directories.\nUsage: commit0 test {{repo_dir}} {test_ids}"
80+
f"{repo_dir} and {repo_or_repo_dir} are not git directories.\nUsage: commit0 test {{repo_dir}} {{branch}} {{test_ids}}"
8181
)
8282
except Exception as e:
8383
raise e
8484
if branch == "reference":
8585
commit_id = example["reference_commit"]
8686
else:
87+
if branch not in local_repo.branches:
88+
raise Exception(f"Branch {branch} does not exist.")
8789
local_branch = local_repo.branches[branch]
8890
commit_id = local_branch.commit.hexsha
8991
patch = generate_patch_between_commits(

commit0/harness/save.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def main(
2020
dataset_split: str,
2121
repo_split: str,
2222
base_dir: str,
23-
organization: str,
23+
owner: str,
2424
branch: str,
2525
github_token: str,
2626
) -> None:
@@ -33,7 +33,7 @@ def main(
3333
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
3434
continue
3535
local_repo_path = f"{base_dir}/{repo_name}"
36-
github_repo_url = f"https://github.com/{organization}/{repo_name}.git"
36+
github_repo_url = f"https://github.com/{owner}/{repo_name}.git"
3737
github_repo_url = github_repo_url.replace(
3838
"https://", f"https://x-access-token:{github_token}@"
3939
)
@@ -46,7 +46,7 @@ def main(
4646

4747
# create Github repo
4848
create_repo_on_github(
49-
organization=organization, repo=repo_name, logger=logger, token=github_token
49+
organization=owner, repo=repo_name, logger=logger, token=github_token
5050
)
5151
# Add your remote repository URL
5252
remote_name = "progress-tracker"
@@ -75,9 +75,7 @@ def main(
7575
origin.push(refspec=f"{branch}:{branch}")
7676
logger.info(f"Pushed to {github_repo_url} on branch {branch}")
7777
except Exception as e:
78-
raise Exception(
79-
f"Push {branch} to {organization}/{repo_name} fails.\n{str(e)}"
80-
)
78+
raise Exception(f"Push {branch} to {owner}/{repo_name} fails.\n{str(e)}")
8179

8280

8381
__all__ = []

commit0/harness/setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Iterator
77
from commit0.harness.utils import (
88
clone_repo,
9-
create_branch,
109
)
1110
from commit0.harness.constants import RepoInstance, SPLIT
1211

@@ -18,7 +17,10 @@
1817

1918

2019
def main(
21-
dataset_name: str, dataset_split: str, repo_split: str, base_dir: str, branch: str
20+
dataset_name: str,
21+
dataset_split: str,
22+
repo_split: str,
23+
base_dir: str,
2224
) -> None:
2325
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
2426
for example in dataset:
@@ -27,8 +29,7 @@ def main(
2729
continue
2830
clone_url = f"https://github.com/{example['repo']}.git"
2931
clone_dir = os.path.abspath(os.path.join(base_dir, repo_name))
30-
local_repo = clone_repo(clone_url, clone_dir, example["base_commit"], logger)
31-
create_branch(local_repo, branch, logger)
32+
clone_repo(clone_url, clone_dir, example["base_commit"], logger)
3233

3334

3435
__all__ = []

0 commit comments

Comments
 (0)