|
4 | 4 |
|
5 | 5 | import hydra
|
6 | 6 | from datasets import load_dataset
|
7 |
| -from omegaconf import OmegaConf |
8 |
| -import tarfile |
| 7 | +import traceback |
9 | 8 | from baselines.baseline_utils import (
|
10 | 9 | get_message_to_aider,
|
11 | 10 | get_target_edit_files_cmd_args,
|
12 | 11 | )
|
13 |
| -from baselines.class_types import AiderConfig, BaselineConfig, Commit0Config |
| 12 | +from hydra.core.config_store import ConfigStore |
| 13 | +from baselines.class_types import AiderConfig, Commit0Config |
14 | 14 | from commit0.harness.constants import SPLIT
|
15 |
| -# from aider.run_aider import get_aider_cmd |
| 15 | +from commit0.harness.get_pytest_ids import main as get_tests |
| 16 | +from tqdm import tqdm |
| 17 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 18 | + |
16 | 19 |
|
17 | 20 | logging.basicConfig(
|
18 | 21 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
@@ -51,16 +54,11 @@ def run_aider_for_repo(
|
51 | 54 |
|
52 | 55 | repo_name = repo_name.lower()
|
53 | 56 | repo_name = repo_name.replace(".", "-")
|
54 |
| - with tarfile.open(f"commit0/data/test_ids/{repo_name}.tar.bz2", "r:bz2") as tar: |
55 |
| - for member in tar.getmembers(): |
56 |
| - if member.isfile(): |
57 |
| - file = tar.extractfile(member) |
58 |
| - if file: |
59 |
| - test_files_str = file.read().decode("utf-8") |
60 |
| - # print(content.decode("utf-8")) |
61 | 57 |
|
62 |
| - test_files = test_files_str.split("\n") if isinstance(test_files_str, str) else [] |
63 |
| - test_files = sorted(list(set([i.split(":")[0] for i in test_files]))) |
| 58 | + # Call the commit0 get-tests command to retrieve test files |
| 59 | + test_files_str = get_tests(repo_name, stdout=True) |
| 60 | + |
| 61 | + test_files = sorted(list(set([i.split(":")[0] for i in test_files_str]))) |
64 | 62 |
|
65 | 63 | repo_path = os.path.join(commit0_config.base_dir, repo_name)
|
66 | 64 |
|
@@ -112,30 +110,73 @@ def run_aider_for_repo(
|
112 | 110 | logger.error(f"OSError occurred: {e}")
|
113 | 111 |
|
114 | 112 |
|
115 |
| -@hydra.main(version_base=None, config_path="config", config_name="aider") |
116 |
| -def main(config: BaselineConfig) -> None: |
| 113 | +def pre_aider_processing(aider_config: AiderConfig) -> None: |
| 114 | + """Pre-process the Aider config.""" |
| 115 | + if aider_config.use_user_prompt: |
| 116 | + # get user prompt from input |
| 117 | + aider_config.user_prompt = input("Enter the user prompt: ") |
| 118 | + |
| 119 | + |
| 120 | +def main() -> None: |
117 | 121 | """Main function to run Aider for a given repository.
|
118 | 122 |
|
119 | 123 | Will run in parallel for each repo.
|
120 | 124 | """
|
121 |
| - config = BaselineConfig(config=OmegaConf.to_object(config)) |
122 |
| - commit0_config = config.commit0_config |
123 |
| - aider_config = config.aider_config |
| 125 | + cs = ConfigStore.instance() |
| 126 | + cs.store(name="user", node=Commit0Config) |
| 127 | + cs.store(name="user", node=AiderConfig) |
| 128 | + |
| 129 | + hydra.initialize(version_base=None, config_path="configs") |
| 130 | + config = hydra.compose(config_name="aider") |
| 131 | + |
| 132 | + commit0_config = Commit0Config(**config.commit0_config) |
| 133 | + aider_config = AiderConfig(**config.aider_config) |
124 | 134 |
|
125 | 135 | if commit0_config is None or aider_config is None:
|
126 | 136 | raise ValueError("Invalid input")
|
127 | 137 |
|
128 |
| - dataset = load_dataset(commit0_config.dataset_name, split="test") |
| 138 | + dataset = load_dataset( |
| 139 | + commit0_config.dataset_name, split=commit0_config.dataset_split |
| 140 | + ) |
129 | 141 |
|
130 | 142 | filtered_dataset = [
|
131 | 143 | example
|
132 | 144 | for example in dataset
|
133 | 145 | if commit0_config.repo_split == "all"
|
134 |
| - or example["repo"].split("/")[-1] in SPLIT.get(commit0_config.repo_split, []) |
135 |
| - ] |
136 |
| - |
137 |
| - for example in filtered_dataset: |
138 |
| - run_aider_for_repo(commit0_config, aider_config, example) |
| 146 | + or ( |
| 147 | + isinstance(example, dict) |
| 148 | + and "repo" in example |
| 149 | + and isinstance(example["repo"], str) |
| 150 | + and example["repo"].split("/")[-1] |
| 151 | + in SPLIT.get(commit0_config.repo_split, []) |
| 152 | + ) |
| 153 | + ][:1] |
| 154 | + |
| 155 | + pre_aider_processing(aider_config) |
| 156 | + |
| 157 | + with tqdm( |
| 158 | + total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos" |
| 159 | + ) as pbar: |
| 160 | + with ThreadPoolExecutor(max_workers=commit0_config.num_workers) as executor: |
| 161 | + # Create a future for running Aider for each repo |
| 162 | + futures = { |
| 163 | + executor.submit( |
| 164 | + run_aider_for_repo, |
| 165 | + commit0_config, |
| 166 | + aider_config, |
| 167 | + example if isinstance(example, dict) else {}, |
| 168 | + ): example |
| 169 | + for example in filtered_dataset |
| 170 | + } |
| 171 | + # Wait for each future to complete |
| 172 | + for future in as_completed(futures): |
| 173 | + pbar.update(1) |
| 174 | + try: |
| 175 | + # Update progress bar, check if Aider ran successfully |
| 176 | + future.result() |
| 177 | + except Exception: |
| 178 | + traceback.print_exc() |
| 179 | + continue |
139 | 180 |
|
140 | 181 |
|
141 | 182 | if __name__ == "__main__":
|
|
0 commit comments