Skip to content

Commit 6a88fad

Browse files
authored
Merge pull request #36 from commit-0/aider
Aider add specification
2 parents e45d8db + 0a00562 commit 6a88fad

File tree

3 files changed

+47
-18
lines changed

3 files changed

+47
-18
lines changed

baselines/commit0_utils.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import git
22
import os
33
import re
4-
import subprocess
54
from dataclasses import asdict
65
from pathlib import Path
76
from typing import List
7+
import fitz
88

99
from baselines.class_types import AgentConfig
1010

@@ -13,7 +13,7 @@
1313
REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
1414
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
1515
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"
16-
16+
SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n"
1717
# prefix components:
1818
space = " "
1919
branch = "│ "
@@ -122,14 +122,14 @@ def get_target_edit_files(target_dir: str) -> list[str]:
122122
"""Find the files with the error 'NotImplementedError('IMPLEMENT ME
123123
HERE')'.
124124
"""
125-
# The grep command
126-
command = f"grep -R -l \"NotImplementedError('IMPLEMENT ME HERE')\" {target_dir}"
127-
128-
# Run the command and capture the output
129-
result = subprocess.run(command, shell=True, capture_output=True, text=True)
130-
131-
# Split the output into lines and remove the base_dir prefix
132-
files = result.stdout.strip().split("\n")
125+
files = []
126+
for root, _, filenames in os.walk(target_dir):
127+
for filename in filenames:
128+
if filename.endswith(".py"):
129+
file_path = os.path.join(root, filename)
130+
with open(file_path, "r") as file:
131+
if "NotImplementedError('IMPLEMENT ME HERE')" in file.read():
132+
files.append(file_path)
133133

134134
# Remove the base_dir prefix
135135
files = [file.replace(target_dir, "").lstrip("/") for file in files]
@@ -143,7 +143,8 @@ def get_target_edit_files(target_dir: str) -> list[str]:
143143
def get_message(
144144
agent_config: AgentConfig,
145145
repo_path: str,
146-
test_dir: str,
146+
test_dir: str | None = None,
147+
test_file: str | None = None,
147148
) -> str:
148149
"""Get the message to Aider."""
149150
prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt
@@ -157,6 +158,13 @@ def get_message(
157158
include_stubs=True,
158159
)[: agent_config.max_unit_tests_info_length]
159160
)
161+
elif agent_config.use_unit_tests_info and test_file:
162+
unit_tests_info = (
163+
f"\n{UNIT_TESTS_INFO_HEADER} "
164+
+ get_file_info(
165+
file_path=Path(os.path.join(repo_path, test_file)), prefix=""
166+
)[: agent_config.max_unit_tests_info_length]
167+
)
160168
else:
161169
unit_tests_info = ""
162170

@@ -171,15 +179,34 @@ def get_message(
171179
else:
172180
repo_info = ""
173181

174-
message_to_agent = prompt + repo_info + unit_tests_info
182+
if agent_config.use_spec_info:
183+
spec_info = (
184+
f"\n{SPEC_INFO_HEADER} "
185+
+ get_specification(specification_pdf_path=Path(repo_path, "spec.pdf"))[
186+
: agent_config.max_spec_info_length
187+
]
188+
)
189+
else:
190+
spec_info = ""
191+
192+
message_to_agent = prompt + repo_info + unit_tests_info + spec_info
175193

176194
return message_to_agent
177195

178196

179-
def get_reference(specification_pdf_path: str) -> str:
197+
def get_specification(specification_pdf_path: Path) -> str:
180198
"""Get the reference for a given specification PDF path."""
181199
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
182-
return f"/pdf {specification_pdf_path}"
200+
# Open the specified PDF file
201+
document = fitz.open(specification_pdf_path)
202+
text = ""
203+
204+
# Iterate through the pages
205+
for page_num in range(len(document)):
206+
page = document.load_page(page_num) # loads the specified page
207+
text += page.get_text() # type: ignore
208+
209+
return text
183210

184211

185212
def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None:

baselines/run_agent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,18 @@ def run_agent_for_repo(
8181
if latest_commit.hexsha != example["base_commit"]:
8282
local_repo.git.reset("--hard", example["base_commit"])
8383
target_edit_files = get_target_edit_files(repo_path)
84-
8584
with DirContext(repo_path):
8685
if commit0_config is None or agent_config is None:
8786
raise ValueError("Invalid input")
8887

89-
message = get_message(agent_config, repo_path, example["test"]["test_dir"])
9088
if agent_config.run_tests:
9189
# when unit test feedback is available, iterate over test files
9290
for test_file in test_files:
9391
test_cmd = f"python -m commit0 test {repo_path} {run_id} {test_file}"
9492
test_file_name = test_file.replace(".py", "").replace("/", "__")
9593
log_dir = RUN_AIDER_LOG_DIR / "with_tests" / test_file_name
9694
lint_cmd = get_lint_cmd(local_repo, agent_config.use_lint_info)
97-
95+
message = get_message(agent_config, repo_path, test_file=test_file)
9896
agent.run(
9997
message,
10098
test_cmd,
@@ -104,6 +102,9 @@ def run_agent_for_repo(
104102
)
105103
else:
106104
# when unit test feedback is not available, iterate over target files to edit
105+
message = get_message(
106+
agent_config, repo_path, test_dir=example["test"]["test_dir"]
107+
)
107108
for f in target_edit_files:
108109
file_name = f.replace(".py", "").replace("/", "__")
109110
log_dir = RUN_AIDER_LOG_DIR / "no_tests" / file_name

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ dependencies = [
1111
"ruff>=0.6.4",
1212
"pre-commit>=3.8.0",
1313
"hydra-core>=1.3.2",
14+
"PyMuPDF>=1.24.5",
15+
"aider-chat>=0.56.0",
1416
"modal>=0.64.95",
1517
"typer>=0.12.0",
16-
"aider-chat",
1718
"datasets>=3.0.0",
1819
"docker>=7.1.0",
1920
"fastcore>=1.7.8",

0 commit comments

Comments
 (0)