Commit c2a4884: "Add files via upload" (parent 4eee734)

1 file changed: +200 -0
#!/usr/bin/env python
"""This is the harness for running an AI agent on the SWE Bench dataset."""

import json
import pprint
import random
import subprocess
import sys
from pathlib import Path

import lox

from codegen import Codebase
from codegen.agents.code_agent import CodeAgent
from codegen.configs.models.codebase import CodebaseConfig
from codegen.extensions.swebench.utils import (
    SweBenchExample,
    get_swe_bench_examples,
    load_predictions,
)

PARENT_DIR = Path(__file__).parent

PREDS_DNAME = PARENT_DIR / "predictions"

def diff_versus_commit(git_dname, commit):
    """Take a diff of `git_dname` current contents versus the `commit`."""
    # NOTE: .split() assumes `git_dname` and `commit` contain no whitespace.
    diff_cmd = f"git -C {git_dname} diff {commit}"
    diff_output = subprocess.check_output(diff_cmd.split()).decode()
    return diff_output

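# Illustrative use of diff_versus_commit (a sketch; the path and commit hash
# below are hypothetical):
#
#   patch = diff_versus_commit("/tmp/checkout", "0123abc")
#   print(patch)  # unified diff of the working tree versus that commit
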
def files_in_patch(patch):
    """Extract the list of modified files from a unified diff patch string."""
    files = []
    for line in patch.split("\n"):
        if line.startswith("--- a/") or line.startswith("+++ b/"):
            fname = line.split("/", 1)[1]
            if fname not in files:
                files.append(fname)
    return files

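# A quick illustration of files_in_patch on a minimal diff (the file name is
# hypothetical, for demonstration only):
#
#   patch = "--- a/src/app.py\n+++ b/src/app.py\n@@ -1 +1 @@\n-x\n+y\n"
#   files_in_patch(patch)  # -> ["src/app.py"]
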
def show_problems(dataset):
    """Print each instance_id and the first line of its problem statement."""
    for inst, entry in dataset.items():
        problem = entry.problem_statement.splitlines()[0]
        print(f"{inst}: {problem}")

def run_agent_on_entry(entry: SweBenchExample, model: str | None = None, codebase: Codebase | None = None, run_id: str | None = None):
    """Process one `entry` from SWE Bench using the LLM `model`.

    Returns the result dict containing the generated `model_patch`.
    """
    # NOTE: `model` is accepted for API compatibility but is not currently
    # forwarded to CodeAgent below; it defaults to None so callers may omit it.
    instance_id = entry.instance_id
    base_commit = entry.base_commit

    print("=" * 60)
    pprint.pprint(instance_id)
    print("=" * 60)
    problem_statement = entry.problem_statement
    print(problem_statement)

    gold_files = files_in_patch(entry.patch)

    if codebase is None:
        config = CodebaseConfig(
            disable_file_parse=True,  # Disable the graph AND disable file parsing (file.edit only)
        )
        codebase = Codebase.from_repo(repo_full_name=entry.repo, commit=base_commit, language="python", config=config)  # check out the repo

    metadata = {"run_id": run_id, "instance_id": instance_id, "difficulty": f"difficulty_{entry.difficulty}"}
    tags = [str(value) for value in metadata.values()]
    agent = CodeAgent(codebase=codebase, tags=tags, metadata=metadata)

    pprint.pprint(instance_id)
    pprint.pprint(gold_files)

    message = """Below is a real GitHub issue from a popular GitHub repository.
The issue was filed some time ago.
The repo has been checked out at the commit that existed at the moment the issue was filed.
If you are already familiar with this repo, be cautious!
You are working with an old version of the repo!
Filenames, directory names, file contents, etc may be different than what you're used to.

Propose changes to update the repo to fix the problem below.
*** IMPORTANT: *** DO NOT MODIFY ANY TESTS!
*** IMPORTANT: *** DO NOT ADD ANY TESTS!

Before committing to any modifications, double-check your work with the Reflection tool.
You can also use that tool to check your work after you think you are done.
If you ever get stuck using other tools, use the Reflection tool to reassess your situation.
After every file edit, use the Reflection tool to check your work and sanity-check yourself.
After editing a file, double-check your work and use the ViewFiles tool to make sure you didn't break anything and that your edits are indeed correct.

You should follow the advice of the Reflection tool whenever it seems reasonable.

Also DO NOT ADD OR EDIT ANY TESTS!

"""
    message += problem_statement

    try:
        # The agent's return value is unused; the patch is read from the codebase diff below.
        agent.run(prompt=message)
    except Exception as agent_error:
        pprint.pprint(f"Instance ID: {instance_id} terminated with error: {agent_error}")
        raise

    # Get the diff between the current state and the original commit
    model_patch = codebase.get_diff(base=base_commit)
    pprint.pprint(model_patch)

    # Record the results for the logs
    result = dict(
        # Required args for running eval tests
        instance_id=instance_id,
        model_patch=model_patch,
        # For computing stats
        gold_files=gold_files,
        edited_files=files_in_patch(model_patch),
    )

    # Did we get a successful patch?
    if not model_patch:
        pprint.pprint("=" * 60)
        pprint.pprint("Failed to generate a patch")
        pprint.pprint("=" * 60)

    return result

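# Single-instance smoke test (a sketch; the instance id below is hypothetical
# and fetching examples requires network access):
#
#   examples = {e.instance_id: e for e in get_swe_bench_examples()}
#   result = run_agent_on_entry(examples["django__django-11099"], run_id="debug")
#   print(result["model_patch"])
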
def process_instances(dataset: dict[str, SweBenchExample], threads: int):
    """Run the agent on every not-yet-done instance in `dataset`.

    dataset - The subset of the SWE Bench dataset to process.
    threads - How many problems to attempt concurrently.
    """
    # Create the predictions directory if it doesn't exist
    PREDS_DNAME.mkdir(exist_ok=True)
    out_dname = PREDS_DNAME / "results"
    out_dname.mkdir(exist_ok=True)

    pprint.pprint(out_dname)

    # If we are restarting this run, figure out which instances are already done.
    done_preds = load_predictions([out_dname])
    done_instances = set(done_preds.keys())
    pprint.pprint(len(done_instances))

    all_instances = set(dataset.keys())

    remaining_instances = set(all_instances)
    remaining_instances -= done_instances

    remaining_instances = list(remaining_instances)
    random.shuffle(remaining_instances)

    pprint.pprint(sorted(remaining_instances))
    pprint.pprint(len(remaining_instances))

    print()
    print("press enter...")
    input()

    if threads > 1:
        process_one_instance_lox = lox.process(threads)(run_agent_on_entry)
        process_one_instance_func = process_one_instance_lox.scatter
        gather = process_one_instance_lox.gather
    else:
        process_one_instance_func = run_agent_on_entry

    scattered_ids = []
    for instance_id in remaining_instances:
        if instance_id in done_instances:
            print("skipping", instance_id)
            continue

        result = process_one_instance_func(
            dataset[instance_id],
        )
        if threads > 1:
            # scatter() returns a job handle, not the result dict;
            # collect the real results after gather() below.
            scattered_ids.append(instance_id)
        else:
            with open(out_dname / f"{instance_id}.json", "w") as f:
                json.dump(result, f)

        print("#" * 60)
        # input()

    if threads > 1:
        # gather() returns results in the order the jobs were scattered.
        for instance_id, result in zip(scattered_ids, gather()):
            with open(out_dname / f"{instance_id}.json", "w") as f:
                json.dump(result, f)

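# The lox scatter/gather pattern used above, in isolation (a minimal sketch,
# assuming the `lox` package is installed):
#
#   @lox.process(4)
#   def square(x):
#       return x * x
#
#   for x in range(10):
#       square.scatter(x)
#   results = square.gather()  # results come back in scatter order
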
def main():
    # Load the SWE Bench dataset
    dataset = {example.instance_id: example for example in get_swe_bench_examples()}
    process_instances(dataset, threads=10)


if __name__ == "__main__":
    status = main()
    sys.exit(status)
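# To run the harness end-to-end (a sketch; the filename is hypothetical, and
# the codegen/lox dependencies plus any LLM API credentials used by CodeAgent
# must already be configured):
#
#   python swebench_harness.py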
