Skip to content

Commit d06836e

Browse files
authored
Merge pull request #33 from commit-0/aider_reorg
Aider reorg
2 parents fb74253 + 2850ac0 commit d06836e

File tree

11 files changed

+1232
-982
lines changed

11 files changed

+1232
-982
lines changed

baselines/agents.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from abc import ABC, abstractmethod
2+
from pathlib import Path
3+
4+
from aider.coders import Coder
5+
from aider.models import Model
6+
from aider.io import InputOutput
7+
8+
9+
class Agents(ABC):
    """Abstract interface implemented by every coding-agent backend."""

    @abstractmethod
    def run(
        self,
        message: str,
        test_cmd: str,
        lint_cmd: str,
        fnames: list[str],
        log_dir: Path,
    ) -> None:
        """Start agent.

        The parameters mirror ``AiderAgents.run`` in this file so that
        concrete agents override this method with a compatible signature.

        Args:
            message: Prompt/instructions handed to the agent.
            test_cmd: Command used for test feedback (may be empty).
            lint_cmd: Command used for lint feedback (may be empty).
            fnames: Files the agent is allowed to edit.
            log_dir: Directory where the agent writes its logs.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
14+
15+
16+
class AiderAgents(Agents):
    """Agent backend that drives an Aider ``Coder``."""

    def __init__(self, model_name: str):
        """Build the Aider ``Model`` for ``model_name``."""
        self.model = Model(model_name)

    def run(
        self,
        message: str,
        test_cmd: str,
        lint_cmd: str,
        fnames: list[str],
        log_dir: Path,
    ) -> None:
        """Start aider agent.

        Args:
            message: Prompt passed to ``coder.run``.
            test_cmd: Test command; automatic testing is enabled only when
                this is non-empty.
            lint_cmd: Lint command; automatic linting is enabled only when
                this is non-empty.
            fnames: Files handed to the coder for editing.
            log_dir: Directory (created if missing) that receives Aider's
                input and chat history files.
        """
        # An empty command string switches the matching feedback loop off.
        auto_test = bool(test_cmd)
        auto_lint = bool(lint_cmd)

        log_dir.mkdir(parents=True, exist_ok=True)
        input_history_file = log_dir / ".aider.input.history"
        chat_history_file = log_dir / ".aider.chat.history.md"
        io = InputOutput(
            yes=True,
            input_history_file=input_history_file,
            chat_history_file=chat_history_file,
        )

        # NOTE(review): Aider's Coder is understood to take ``lint_cmds`` as a
        # mapping of language -> command (it iterates ``.items()``), so a bare
        # string would fail as soon as linting is enabled. Confirm against the
        # pinned aider version.
        lint_cmds = {"python": lint_cmd} if lint_cmd else None

        coder = Coder.create(
            main_model=self.model,
            fnames=fnames,
            auto_lint=auto_lint,
            auto_test=auto_test,
            lint_cmds=lint_cmds,
            test_cmd=test_cmd,
            io=io,
        )
        coder.run(message)

baselines/class_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ class Commit0Config:
1111

1212

1313
@dataclass
14-
class AiderConfig:
15-
llm_name: str
14+
class AgentConfig:
15+
agent_name: str
16+
model_name: str
1617
use_user_prompt: bool
1718
user_prompt: str
1819
use_repo_info: bool

baselines/baseline_utils.py renamed to baselines/commit0_utils.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import re
33
import subprocess
44
from pathlib import Path
5-
from typing import Any, Dict, List
5+
from typing import List
66

7-
from baselines.class_types import AiderConfig
7+
from baselines.class_types import AgentConfig
88

99
PROMPT_HEADER = ">>> Here is the Task:\n"
1010
REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
@@ -116,7 +116,7 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
116116
return "\n".join(filter(None, tree_string))
117117

118118

119-
def get_target_edit_files_cmd_args(target_dir: str) -> str:
119+
def get_target_edit_files(target_dir: str) -> list[str]:
120120
"""Find the files with the error 'NotImplementedError('IMPLEMENT ME
121121
HERE')'.
122122
"""
@@ -135,54 +135,43 @@ def get_target_edit_files_cmd_args(target_dir: str) -> str:
135135
# Only keep python files
136136
files = [file for file in files if file.endswith(".py")]
137137

138-
return " ".join(files)
138+
return files
139139

140140

141-
def get_message_to_aider(
142-
aider_config: AiderConfig,
143-
target_edit_files_cmd_args: str,
141+
def get_message(
142+
agent_config: AgentConfig,
144143
repo_path: str,
145-
ds: Dict[str, Any],
144+
test_dir: str,
146145
) -> str:
147146
"""Get the message to Aider."""
148-
prompt = f"{PROMPT_HEADER} " + aider_config.user_prompt
147+
prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt
149148

150-
if aider_config.use_unit_tests_info and ds["test"]["test_dir"]:
149+
if agent_config.use_unit_tests_info and test_dir:
151150
unit_tests_info = (
152151
f"\n{UNIT_TESTS_INFO_HEADER} "
153152
+ get_dir_info(
154-
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
153+
dir_path=Path(os.path.join(repo_path, test_dir)),
155154
prefix="",
156155
include_stubs=True,
157-
)[: aider_config.max_unit_tests_info_length]
156+
)[: agent_config.max_unit_tests_info_length]
158157
)
159158
else:
160159
unit_tests_info = ""
161160

162161
# TODO: assuming we have specification, which we currently do not have
163-
if aider_config.use_reference_info and ds["specification"]:
164-
reference = (
165-
f"\n{REFERENCE_HEADER} "
166-
+ get_reference(ds["specification"])[
167-
: aider_config.max_reference_info_length
168-
]
169-
)
170-
else:
171-
reference = ""
172-
173-
if aider_config.use_repo_info:
162+
if agent_config.use_repo_info:
174163
repo_info = (
175164
f"\n{REPO_INFO_HEADER} "
176165
+ get_dir_info(
177166
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
178-
)[: aider_config.max_repo_info_length]
167+
)[: agent_config.max_repo_info_length]
179168
)
180169
else:
181170
repo_info = ""
182171

183-
message_to_aider = prompt + reference + repo_info + unit_tests_info
172+
message_to_agent = prompt + repo_info + unit_tests_info
184173

185-
return message_to_aider
174+
return message_to_agent
186175

187176

188177
def get_reference(specification_pdf_path: str) -> str:

baselines/configs/aider.yaml renamed to baselines/configs/agent.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ defaults:
33
- base
44
- _self_
55

6-
aider_config:
6+
commit0_config:
7+
repo_split: minitorch
8+
9+
agent_config:
710
use_user_prompt: false
811
use_repo_info: false
912
use_unit_tests_info: false
1013
use_reference_info: false
11-
use_lint_info: true
14+
use_lint_info: false
1215
pre_commit_config_path: .pre-commit-config.yaml
13-
run_tests: true
16+
run_tests: false

baselines/configs/base.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ commit0_config:
1010
repo_split: "simpy"
1111
num_workers: 10
1212

13-
aider_config:
14-
llm_name: "claude-3-5-sonnet-20240620"
13+
agent_config:
14+
agent_name: "aider"
15+
model_name: "claude-3-5-sonnet-20240620"
1516
use_user_prompt: false
16-
user_prompt: "Here is the Task:\n Your task is to iteratively implement the each function that is 'NotImplementedError('IMPLEMENT ME HERE')' in these files until there are no more 'NotImplementedError('IMPLEMENT ME HERE')' and pass the unit tests.\nMake sure you read the files carefully.\nYour output should be the edited code files.\nUse the above instructions to modify the supplied files.\nDo not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.\nOnly use standard python libraries, do not suggest installing any packages."
17+
user_prompt: "Here is your task:\nYou need to implement all functions with 'NotImplementedError('IMPLEMENT ME HERE')' and pass the unit tests.\nDo not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.\nWhen you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code."
1718
use_repo_info: false
1819
use_unit_tests_info: false
1920
use_reference_info: false
@@ -27,4 +28,4 @@ aider_config:
2728

2829
hydra:
2930
run:
30-
dir: ./hydra_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
31+
dir: ./hydra_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}

baselines/run_agent.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import os
2+
import hydra
3+
from datasets import load_dataset
4+
import traceback
5+
from baselines.commit0_utils import (
6+
get_message,
7+
get_target_edit_files,
8+
)
9+
from baselines.agents import AiderAgents
10+
from typing import Optional, Type
11+
from types import TracebackType
12+
from hydra.core.config_store import ConfigStore
13+
from baselines.class_types import AgentConfig, Commit0Config
14+
from commit0.harness.constants import SPLIT
15+
from commit0.harness.get_pytest_ids import main as get_tests
16+
from commit0.harness.constants import RUN_AIDER_LOG_DIR, RepoInstance
17+
from tqdm import tqdm
18+
from concurrent.futures import ThreadPoolExecutor, as_completed
19+
20+
21+
class DirContext:
    """Context manager that temporarily changes the working directory.

    NOTE: ``os.chdir`` changes the working directory of the whole process,
    not only of the calling thread.
    """

    def __init__(self, d: str):
        """Remember the target directory ``d``."""
        self.dir = d
        # Kept so the attribute exists before entry; refreshed in __enter__.
        self.cwd = os.getcwd()

    def __enter__(self) -> "DirContext":
        """Switch into the target directory and return this context."""
        # Capture the directory at entry time (not construction time) so the
        # context restores the right place even if the cwd changed in between
        # or the same instance is entered more than once.
        self.cwd = os.getcwd()
        os.chdir(self.dir)
        return self

    def __exit__(
        self,
        exctype: Optional[Type[BaseException]],
        excinst: Optional[BaseException],
        exctb: Optional[TracebackType],
    ) -> None:
        """Restore the directory that was current on entry."""
        os.chdir(self.cwd)
36+
37+
38+
def run_agent_for_repo(
    commit0_config: Commit0Config,
    agent_config: AgentConfig,
    example: RepoInstance,
) -> None:
    """Run Aider for a given repository.

    Args:
        commit0_config: Supplies ``base_dir``, the directory containing the
            repository checkout.
        agent_config: Selects the agent/model and toggles lint and test
            feedback.
        example: Dataset entry; ``example["repo"]`` and
            ``example["test"]["test_dir"]`` are read.

    Raises:
        ValueError: If either config is ``None``.
        NotImplementedError: If ``agent_config.agent_name`` is not supported.
    """
    # Validate up front: both configs are dereferenced below, so a later
    # check could never be reached with a None value.
    if commit0_config is None or agent_config is None:
        raise ValueError("Invalid input")

    # get repo info; take the last path component, matching how main()
    # filters the dataset
    repo_name = example["repo"].split("/")[-1]
    repo_name = repo_name.lower().replace(".", "-")

    # Call the commit0 get-tests command to retrieve test files
    test_files_str = get_tests(repo_name, stdout=False)
    # Keep the part before the first ":" of each id, de-duplicated and sorted.
    test_files = sorted({test_id.split(":")[0] for test_id in test_files_str})

    repo_path = os.path.abspath(os.path.join(commit0_config.base_dir, repo_name))

    target_edit_files = get_target_edit_files(repo_path)

    if agent_config.agent_name == "aider":
        agent = AiderAgents(agent_config.model_name)
    else:
        raise NotImplementedError(
            f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
        )

    with DirContext(repo_path):
        message = get_message(agent_config, repo_path, example["test"]["test_dir"])

        if agent_config.use_lint_info:
            # The config path is relative to the repository directory entered
            # above.
            lint_cmd = "pre-commit run --config ../../.pre-commit-config.yaml --files"
        else:
            lint_cmd = ""

        if agent_config.run_tests:
            # when unit test feedback is available, iterate over test files
            for test_file in test_files:
                test_cmd = f"python -m commit0 test {repo_path} {test_file}"
                test_file_name = test_file.replace(".py", "").replace("/", "__")
                log_dir = RUN_AIDER_LOG_DIR / "with_tests" / test_file_name

                agent.run(
                    message,
                    test_cmd,
                    lint_cmd,
                    target_edit_files,
                    log_dir,
                )
        else:
            # when unit test feedback is not available, iterate over target files to edit
            for f in target_edit_files:
                file_name = f.replace(".py", "").replace("/", "__")
                log_dir = RUN_AIDER_LOG_DIR / "no_tests" / file_name

                agent.run(message, "", lint_cmd, [f], log_dir)
98+
99+
100+
def main() -> None:
    """Main function to run Aider for a given repository.

    Will run in parallel for each repo.

    Raises:
        ValueError: If no dataset example matches the configured repo split.
    """
    # Local import: only this function needs the process-based executor.
    from concurrent.futures import ProcessPoolExecutor

    cs = ConfigStore.instance()
    # NOTE(review): both nodes are stored under the same name "user", so the
    # second call presumably replaces the first — confirm this is intended.
    cs.store(name="user", node=Commit0Config)
    cs.store(name="user", node=AgentConfig)
    hydra.initialize(version_base=None, config_path="configs")
    config = hydra.compose(config_name="agent")
    commit0_config = Commit0Config(**config.commit0_config)
    agent_config = AgentConfig(**config.agent_config)

    dataset = load_dataset(
        commit0_config.dataset_name, split=commit0_config.dataset_split
    )
    # Keep every example for the "all" split; otherwise keep only examples
    # whose repository name is listed in SPLIT for the configured split.
    filtered_dataset = [
        example
        for example in dataset
        if commit0_config.repo_split == "all"
        or (
            isinstance(example, dict)
            and "repo" in example
            and isinstance(example["repo"], str)
            and example["repo"].split("/")[-1]
            in SPLIT.get(commit0_config.repo_split, [])
        )
    ]
    # Explicit raise instead of assert so the check survives ``python -O``.
    if not filtered_dataset:
        raise ValueError("No examples available")

    with tqdm(
        total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos"
    ) as pbar:
        # Each worker changes its working directory (os.chdir) while handling
        # a repo. The cwd is process-wide, so threads would overwrite each
        # other's directory; separate processes keep it isolated per repo.
        with ProcessPoolExecutor(max_workers=commit0_config.num_workers) as executor:
            # Create a future for running Aider for each repo
            futures = {
                executor.submit(
                    run_agent_for_repo,
                    commit0_config,
                    agent_config,
                    example,  # type: ignore
                ): example
                for example in filtered_dataset
            }
            # Wait for each future to complete
            for future in as_completed(futures):
                pbar.update(1)
                try:
                    # Surface any exception raised while running the agent
                    future.result()
                except Exception:
                    # Best effort: report the failure and keep processing the
                    # remaining repositories.
                    traceback.print_exc()
                    continue
153+
154+
155+
if __name__ == "__main__":
156+
main()

0 commit comments

Comments
 (0)