Skip to content

Commit 953f4b6

Browse files
committed
add test_file support
1 parent 299ecfe commit 953f4b6

File tree

5 files changed

+63
-19
lines changed

5 files changed

+63
-19
lines changed

baselines/baseline_utils.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
1111
REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
1212
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
13-
EDIT_HISTORY_HEADER = "\n\n>>> Here is the Edit History:\n"
13+
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"
1414

1515
# prefix components:
1616
space = " "
@@ -154,31 +154,54 @@ def get_message_to_aider(
154154
prompt = f"{PROMPT_HEADER} " + get_prompt(target_edit_files_cmd_args)
155155

156156
if aider_config.use_unit_tests_info and ds["test"]["test_dir"]:
157-
unit_tests_info = f"\n{UNIT_TESTS_INFO_HEADER} " + get_dir_info(
158-
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
159-
prefix="",
160-
include_stubs=True,
157+
unit_tests_info = (
158+
f"\n{UNIT_TESTS_INFO_HEADER} "
159+
+ get_dir_info(
160+
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
161+
prefix="",
162+
include_stubs=True,
163+
)[: aider_config.max_unit_tests_info_length]
161164
)
162165
else:
163166
unit_tests_info = ""
164167

165168
# TODO: assuming we have specification, which we currently do not have
166169
if aider_config.use_reference_info and ds["specification"]:
167-
reference = f"\n{REFERENCE_HEADER} " + get_reference(ds["specification"])
170+
reference = (
171+
f"\n{REFERENCE_HEADER} "
172+
+ get_reference(ds["specification"])[
173+
: aider_config.max_reference_info_length
174+
]
175+
)
168176
else:
169177
reference = ""
178+
170179
if aider_config.use_repo_info:
171-
repo_info = f"\n{REPO_INFO_HEADER} " + get_dir_info(
172-
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
180+
repo_info = (
181+
f"\n{REPO_INFO_HEADER} "
182+
+ get_dir_info(
183+
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
184+
)[: aider_config.max_repo_info_length]
173185
)
174186
else:
175187
repo_info = ""
176188

177-
message_to_aider = prompt + reference + repo_info + unit_tests_info
189+
if aider_config.use_lint_info:
190+
lint_info = (
191+
f"\n{LINT_INFO_HEADER} "
192+
+ subprocess.run(
193+
["pre-commit", "run", "--all-files"], capture_output=True, text=True
194+
).stdout
195+
)[: aider_config.max_lint_info_length]
196+
else:
197+
lint_info = ""
198+
199+
message_to_aider = prompt + reference + repo_info + unit_tests_info + lint_info
178200

179201
return message_to_aider
180202

181203

182-
def get_reference(specification_url: str) -> str:
183-
"""Get the reference for a given specification URL."""
184-
return f"/web {specification_url}"
204+
def get_reference(specification_pdf_path: str) -> str:
205+
"""Get the reference for a given specification PDF path."""
206+
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
207+
return f"/pdf {specification_pdf_path}"

baselines/class_types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,17 @@ class Commit0Config(BaseModel):
1111
class AiderConfig(BaseModel):
1212
llm_name: str
1313
use_repo_info: bool
14+
max_repo_info_length: int
1415
use_unit_tests_info: bool
16+
max_unit_tests_info_length: int
1517
use_reference_info: bool
18+
max_reference_info_length: int
19+
use_lint_info: bool
20+
max_lint_info_length: int
1621

1722

1823
class BaselineConfig(BaseModel):
19-
config: Dict[str, Dict[str, Union[str, bool]]]
24+
config: Dict[str, Dict[str, Union[str, bool, int]]]
2025

2126
commit0_config: Commit0Config | None = None
2227
aider_config: AiderConfig | None = None

baselines/config/aider.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ defaults:
44
- _self_
55

66
aider_config:
7-
use_repo_info: true
8-
use_unit_tests_info: true
9-
use_reference_info: false
7+
use_repo_info: false
8+
use_unit_tests_info: false
9+
use_reference_info: false
10+
use_lint_info: true

baselines/config/base.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@ commit0_config:
1010
aider_config:
1111
llm_name: "claude-3-5-sonnet-20240620"
1212
use_repo_info: false
13+
max_repo_info_length: 10000
1314
use_unit_tests_info: false
15+
max_unit_tests_info_length: 10000
1416
use_reference_info: false
17+
max_reference_info_length: 10000
18+
use_lint_info: false
19+
max_lint_info_length: 10000
1520

1621
hydra:
1722
run:

baselines/run_aider.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datasets import load_dataset
88
from omegaconf import OmegaConf
99
from tqdm.contrib.concurrent import thread_map
10-
10+
import tarfile
1111
from baselines.baseline_utils import (
1212
get_message_to_aider,
1313
get_target_edit_files_cmd_args,
@@ -46,8 +46,18 @@ def run_aider_for_repo(
4646
# get repo info
4747
_, repo_name = ds["repo"].split("/")
4848

49-
# TODO: assuming we have all test_files, which we currently do not have
50-
test_files = ds["test_files"]
49+
repo_name = repo_name.lower()
50+
repo_name = repo_name.replace(".", "-")
51+
with tarfile.open(f"commit0/data/test_ids/{repo_name}.tar.bz2", "r:bz2") as tar:
52+
for member in tar.getmembers():
53+
if member.isfile():
54+
file = tar.extractfile(member)
55+
if file:
56+
test_files_str = file.read().decode("utf-8")
57+
# print(content.decode("utf-8"))
58+
59+
test_files = test_files_str.split("\n") if isinstance(test_files_str, str) else []
60+
test_files = sorted(list(set([i.split(":")[0] for i in test_files])))
5161

5262
repo_path = os.path.join(commit0_config.base_dir, repo_name)
5363

0 commit comments

Comments
 (0)