Skip to content

Commit df47275

Browse files
committed
fix bug and pass pyright
1 parent 5bd6b74 commit df47275

File tree

5 files changed

+87
-42
lines changed

5 files changed

+87
-42
lines changed

baselines/baseline_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,13 @@ def get_message_to_aider(
157157
) -> str:
158158
"""Get the message to Aider."""
159159
# support context for aider
160-
prompt = f"{PROMPT_HEADER} " + get_prompt(target_edit_files_cmd_args)
160+
if aider_config.use_user_prompt:
161+
assert (
162+
aider_config.user_prompt != ""
163+
), "You choose to use custom user prompt, but it is empty"
164+
prompt = f"{PROMPT_HEADER} " + aider_config.user_prompt
165+
else:
166+
prompt = f"{PROMPT_HEADER} " + get_prompt(target_edit_files_cmd_args)
161167

162168
if aider_config.use_unit_tests_info and ds["test"]["test_dir"]:
163169
unit_tests_info = (

baselines/class_types.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
from typing import Any, Dict, Union
1+
from dataclasses import dataclass
22

3-
from pydantic import BaseModel
43

5-
6-
class Commit0Config(BaseModel):
4+
@dataclass
5+
class Commit0Config:
76
base_dir: str
87
dataset_name: str
8+
dataset_split: str
99
repo_split: str
10+
num_workers: int
1011

1112

12-
class AiderConfig(BaseModel):
13+
@dataclass
14+
class AiderConfig:
1315
llm_name: str
16+
use_user_prompt: bool
17+
user_prompt: str
1418
use_repo_info: bool
1519
max_repo_info_length: int
1620
use_unit_tests_info: bool
@@ -19,15 +23,4 @@ class AiderConfig(BaseModel):
1923
max_reference_info_length: int
2024
use_lint_info: bool
2125
max_lint_info_length: int
22-
23-
24-
class BaselineConfig(BaseModel):
25-
config: Dict[str, Dict[str, Union[str, bool, int]]]
26-
27-
commit0_config: Commit0Config | None = None
28-
aider_config: AiderConfig | None = None
29-
30-
def model_post_init(self, __context: Any) -> None:
31-
"""Post-initialize the model."""
32-
self.commit0_config = Commit0Config(**self.config["commit0_config"])
33-
self.aider_config = AiderConfig(**self.config["aider_config"])
26+
pre_commit_config_path: str

baselines/config/aider.yaml renamed to baselines/configs/aider.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ defaults:
44
- _self_
55

66
aider_config:
7+
use_user_prompt: false
78
use_repo_info: false
89
use_unit_tests_info: false
910
use_reference_info: false

baselines/config/base.yaml renamed to baselines/configs/base.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ defaults:
66
commit0_config:
77
base_dir: /Users/willjiang/Desktop/ai2dev/commit0/repos
88
dataset_name: "wentingzhao/commit0_docstring"
9+
dataset_split: "test"
910
repo_split: "lite"
11+
num_workers: 10
1012

1113
aider_config:
1214
llm_name: "claude-3-5-sonnet-20240620"
15+
use_user_prompt: false
16+
user_prompt: ""
1317
use_repo_info: false
1418
max_repo_info_length: 10000
1519
use_unit_tests_info: false

baselines/run_aider.py

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44

55
import hydra
66
from datasets import load_dataset
7-
from omegaconf import OmegaConf
8-
import tarfile
7+
import traceback
98
from baselines.baseline_utils import (
109
get_message_to_aider,
1110
get_target_edit_files_cmd_args,
1211
)
13-
from baselines.class_types import AiderConfig, BaselineConfig, Commit0Config
12+
from hydra.core.config_store import ConfigStore
13+
from baselines.class_types import AiderConfig, Commit0Config
1414
from commit0.harness.constants import SPLIT
15-
# from aider.run_aider import get_aider_cmd
15+
from commit0.harness.get_pytest_ids import main as get_tests
16+
from tqdm import tqdm
17+
from concurrent.futures import ThreadPoolExecutor, as_completed
18+
1619

1720
logging.basicConfig(
1821
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -51,16 +54,11 @@ def run_aider_for_repo(
5154

5255
repo_name = repo_name.lower()
5356
repo_name = repo_name.replace(".", "-")
54-
with tarfile.open(f"commit0/data/test_ids/{repo_name}.tar.bz2", "r:bz2") as tar:
55-
for member in tar.getmembers():
56-
if member.isfile():
57-
file = tar.extractfile(member)
58-
if file:
59-
test_files_str = file.read().decode("utf-8")
60-
# print(content.decode("utf-8"))
6157

62-
test_files = test_files_str.split("\n") if isinstance(test_files_str, str) else []
63-
test_files = sorted(list(set([i.split(":")[0] for i in test_files])))
58+
# Call the commit0 get-tests command to retrieve test files
59+
test_files_str = get_tests(repo_name, stdout=True)
60+
61+
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
6462

6563
repo_path = os.path.join(commit0_config.base_dir, repo_name)
6664

@@ -112,30 +110,73 @@ def run_aider_for_repo(
112110
logger.error(f"OSError occurred: {e}")
113111

114112

115-
@hydra.main(version_base=None, config_path="config", config_name="aider")
116-
def main(config: BaselineConfig) -> None:
113+
def pre_aider_processing(aider_config: AiderConfig) -> None:
114+
"""Pre-process the Aider config."""
115+
if aider_config.use_user_prompt:
116+
# get user prompt from input
117+
aider_config.user_prompt = input("Enter the user prompt: ")
118+
119+
120+
def main() -> None:
117121
"""Main function to run Aider for a given repository.
118122
119123
Will run in parallel for each repo.
120124
"""
121-
config = BaselineConfig(config=OmegaConf.to_object(config))
122-
commit0_config = config.commit0_config
123-
aider_config = config.aider_config
125+
cs = ConfigStore.instance()
126+
cs.store(name="user", node=Commit0Config)
127+
cs.store(name="user", node=AiderConfig)
128+
129+
hydra.initialize(version_base=None, config_path="configs")
130+
config = hydra.compose(config_name="aider")
131+
132+
commit0_config = Commit0Config(**config.commit0_config)
133+
aider_config = AiderConfig(**config.aider_config)
124134

125135
if commit0_config is None or aider_config is None:
126136
raise ValueError("Invalid input")
127137

128-
dataset = load_dataset(commit0_config.dataset_name, split="test")
138+
dataset = load_dataset(
139+
commit0_config.dataset_name, split=commit0_config.dataset_split
140+
)
129141

130142
filtered_dataset = [
131143
example
132144
for example in dataset
133145
if commit0_config.repo_split == "all"
134-
or example["repo"].split("/")[-1] in SPLIT.get(commit0_config.repo_split, [])
135-
]
136-
137-
for example in filtered_dataset:
138-
run_aider_for_repo(commit0_config, aider_config, example)
146+
or (
147+
isinstance(example, dict)
148+
and "repo" in example
149+
and isinstance(example["repo"], str)
150+
and example["repo"].split("/")[-1]
151+
in SPLIT.get(commit0_config.repo_split, [])
152+
)
153+
][:1]
154+
155+
pre_aider_processing(aider_config)
156+
157+
with tqdm(
158+
total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos"
159+
) as pbar:
160+
with ThreadPoolExecutor(max_workers=commit0_config.num_workers) as executor:
161+
# Create a future for running Aider for each repo
162+
futures = {
163+
executor.submit(
164+
run_aider_for_repo,
165+
commit0_config,
166+
aider_config,
167+
example if isinstance(example, dict) else {},
168+
): example
169+
for example in filtered_dataset
170+
}
171+
# Wait for each future to complete
172+
for future in as_completed(futures):
173+
pbar.update(1)
174+
try:
175+
# Update progress bar, check if Aider ran successfully
176+
future.result()
177+
except Exception:
178+
traceback.print_exc()
179+
continue
139180

140181

141182
if __name__ == "__main__":

0 commit comments

Comments
 (0)