#!/usr/bin/env python
"""This is the harness for running an AI agent on the SWE Bench dataset."""

import json
import pprint
import random
import subprocess
import sys
from pathlib import Path

import lox

from codegen import Codebase
from codegen.agents.code_agent import CodeAgent
from codegen.configs.models.codebase import CodebaseConfig
from codegen.extensions.swebench.utils import (
    SweBenchExample,
    get_swe_bench_examples,
    load_predictions,
)

PARENT_DIR = Path(__file__).parent

PREDS_DNAME = PARENT_DIR / "predictions"


def diff_versus_commit(git_dname, commit):
    """Take a diff of `git_dname` current contents versus the `commit`."""
    diff_cmd = f"git -C {git_dname} diff {commit}"
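    # NOTE: shelling out with .split() assumes `git_dname` contains no spaces.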
    diff_output = subprocess.check_output(diff_cmd.split()).decode()
    return diff_output


def files_in_patch(patch):
    """Extract the list of modified files from a unified diff patch string."""
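    # Each modified file appears in headers such as (hypothetical paths):
    #   --- a/src/foo.py
    #   +++ b/src/foo.py
    # Scanning both prefixes also catches added files (--- /dev/null)
    # and deleted files (+++ /dev/null).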
    files = []
    for line in patch.split("\n"):
        if line.startswith("--- a/") or line.startswith("+++ b/"):
            fname = line.split("/", 1)[1]
            if fname not in files:
                files.append(fname)
    return files


def show_problems(dataset):
    """Print out all the instance_id and problem_descriptions."""
    for inst, entry in dataset.items():
        problem = entry.problem_statement.splitlines()[0]
        print(f"{inst}: {problem}")


def run_agent_on_entry(entry: SweBenchExample, model: str, codebase: Codebase | None = None, run_id: str | None = None):
    """Process one `entry` from SWE Bench using the LLM `model`.

    If `codebase` is None, the repo is checked out fresh at the entry's
    base commit. Returns a result dict for the eval logs.
    """
    instance_id = entry.instance_id
    base_commit = entry.base_commit

    print("=" * 60)
    pprint.pprint(instance_id)
    print("=" * 60)
    problem_statement = entry.problem_statement
    print(problem_statement)

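    # Files touched by the reference (gold) patch; used for logging and
    # the result stats below, not included in the agent's prompt.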
    gold_files = files_in_patch(entry.patch)

    if codebase is None:
        config = CodebaseConfig(
            disable_file_parse=True,  # Disable the graph AND disable file parsing (file.edit only)
        )
        codebase = Codebase.from_repo(repo_full_name=entry.repo, commit=base_commit, language="python", config=config)  # check out the repo

    metadata = {"run_id": run_id, "instance_id": instance_id, "difficulty": f"difficulty_{entry.difficulty}"}
    tags = [str(value) for value in metadata.values()]
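    # The agent mutates `codebase` in place; the patch is recovered
    # afterwards by diffing against `base_commit`.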
    agent = CodeAgent(codebase=codebase, tags=tags, metadata=metadata)

    pprint.pprint(instance_id)
    pprint.pprint(gold_files)

    message = """Below is a real GitHub issue from a popular GitHub repository.
The issue was filed some time ago.
The repo has been checked out at the commit that existed at the moment the issue was filed.
If you are already familiar with this repo, be cautious!
You are working with an old version of the repo!
Filenames, directory names, file contents, etc. may be different from what you're used to.

Propose changes to update the repo to fix the problem below.
*** IMPORTANT: *** DO NOT MODIFY ANY TESTS!
*** IMPORTANT: *** DO NOT ADD ANY TESTS!

Before committing to any modifications, double check your work with the Reflection tool.
You can also use that tool to check your work after you think you are done.
If you ever get stuck using other tools, use the Reflection tool to reassess your situation.
After every file edit, use the Reflection tool to check your work and sanity check yourself.
After editing a file, double check your work with the ViewFiles tool to make sure you didn't break anything and that your edits are indeed correct.

You should follow the advice of the Reflection tool whenever it seems reasonable.

Also DO NOT ADD OR EDIT ANY TESTS!

"""
    message += problem_statement

    try:
        agent.run(prompt=message)
    except Exception as agent_error:
        pprint.pprint(f"Instance ID: {instance_id} terminated with error: {agent_error}")
        raise

    # Get the diff between the current state and the original commit
    model_patch = codebase.get_diff(base=base_commit)
    pprint.pprint(model_patch)

    # Record the results for the logs
    result = dict(
        # Required args for running eval tests
        instance_id=instance_id,
        model_patch=model_patch,
        # For computing stats
        gold_files=gold_files,
        edited_files=files_in_patch(model_patch),
    )

    # Did we get a successful patch?
    if not model_patch:
        pprint.pprint("=" * 60)
        pprint.pprint("Failed to generate a patch")
        pprint.pprint("=" * 60)

    return result


def process_instances(dataset: dict[str, SweBenchExample], threads: int, model: str):
    """Run the agent on every instance in `dataset` that lacks a prediction.

    dataset - The subset of the SWE Bench dataset to process.
    threads - How many problems to attempt concurrently.
    model - The LLM model id passed through to `run_agent_on_entry`.

    Instances that already have a result json in the predictions
    directory are skipped, so an interrupted run can be resumed.
    """
    # Create the predictions directory if it doesn't exist
    PREDS_DNAME.mkdir(exist_ok=True)
    out_dname = PREDS_DNAME / "results"
    out_dname.mkdir(exist_ok=True)
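    # Each completed instance gets its own json file in results/.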

    pprint.pprint(out_dname)

    # If we are restarting this run, figure out which instances are already done.
    done_preds = load_predictions([out_dname])
    done_instances = set(done_preds.keys())
    pprint.pprint(len(done_instances))

    all_instances = set(dataset.keys())

    remaining_instances = all_instances - done_instances
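
    # Randomize the order so that concurrent or restarted runs are
    # unlikely to duplicate effort on the same instances.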
    remaining_instances = list(remaining_instances)
    random.shuffle(remaining_instances)

    pprint.pprint(sorted(remaining_instances))
    pprint.pprint(len(remaining_instances))

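    # Give the operator a chance to review the worklist (and abort)
    # before the run starts.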
    print()
    print("press enter...")
    input()

    if threads > 1:
        process_one_instance_lox = lox.process(threads)(run_agent_on_entry)
        process_one_instance_func = process_one_instance_lox.scatter
        gather = process_one_instance_lox.gather
    else:
        process_one_instance_func = run_agent_on_entry

    for instance_id in remaining_instances:
        if instance_id in done_instances:
            print("skipping", instance_id)
            continue

        result = process_one_instance_func(dataset[instance_id], model)
        if threads == 1:
            # Single-threaded: the result is available immediately.
            with open(out_dname / f"{instance_id}.json", "w") as f:
                json.dump(result, f)

        print("#" * 60)

    if threads > 1:
        # scatter() only queues the work; the actual result dicts come
        # back from gather().
        for result in gather():
            with open(out_dname / f"{result['instance_id']}.json", "w") as f:
                json.dump(result, f)


def main():
    # Load the SWE Bench dataset, keyed by instance_id.
    dataset = {example.instance_id: example for example in get_swe_bench_examples()}
    # NOTE: the model id below is only an example; substitute whichever
    # model your CodeAgent setup supports.
    process_instances(dataset, threads=10, model="claude-3-5-sonnet-latest")


if __name__ == "__main__":
    status = main()
    sys.exit(status)