Skip to content

Commit 82a9ea4

Browse files
authored
Merge pull request #72 from commit-0/topo-sort
add Topo sort
2 parents 50a9bfc + c34bde7 commit 82a9ea4

File tree

7 files changed

+169
-24
lines changed

7 files changed

+169
-24
lines changed

agent/agent_utils.py

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from pathlib import Path
77
from typing import List
88
import fitz
9+
from import_deps import ModuleSet
10+
from graphlib import TopologicalSorter, CycleError
911
import yaml
1012

1113
from agent.class_types import AgentConfig
@@ -16,6 +18,7 @@
1618
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
1719
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"
1820
SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n"
21+
IMPORT_DEPENDENCIES_HEADER = "\n\n>>> Here are the Import Dependencies:\n"
1922
# prefix components:
2023
space = " "
2124
branch = "│ "
@@ -190,25 +193,97 @@ def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str]
190193
return files
191194

192195

193-
def get_target_edit_files(target_dir: str, src_dir: str, test_dir: str) -> list[str]:
196+
def ignore_cycles(graph: dict) -> list[str]:
    """Return a topological order of ``graph``, breaking any cycles found.

    ``graph`` maps each node to the collection of nodes it depends on, as
    expected by ``graphlib.TopologicalSorter``. Whenever the sorter reports a
    cycle, the entry of the first node in the reported cycle is dropped (i.e.
    that node's dependencies are forgotten) and the sort is retried.

    The caller's ``graph`` is NOT modified: cycle breaking happens on a
    shallow copy, so every key is still present afterwards.

    Args:
        graph: Mapping of node -> iterable of predecessor nodes.

    Returns:
        The nodes in dependency order (dependencies first).

    Raises:
        CycleError: If a cycle is reported but its first node has no entry
            left to remove, so no progress can be made.
    """
    # Shallow copy is enough: only top-level keys are removed, the
    # dependency collections themselves are never mutated.
    remaining = dict(graph)
    # Iterate rather than recurse so a graph with many cycles cannot hit the
    # interpreter's recursion limit.
    while True:
        try:
            return list(TopologicalSorter(remaining).static_order())
        except CycleError as err:
            # The second exception argument is the list of nodes in the cycle.
            cycle_nodes = err.args[1]
            node_to_remove = cycle_nodes[0]
            if node_to_remove not in remaining:
                # Nothing to remove means retrying would loop forever.
                raise
            remaining.pop(node_to_remove)
210+
211+
212+
def topological_sort_based_on_dependencies(
    pkg_paths: list[str],
) -> tuple[list[str], dict]:
    """Order the given module files by their import dependencies.

    Builds a ``ModuleSet`` from ``pkg_paths``, records for every module path
    the set of imports reported by ``ModuleSet.get_imports`` (as strings),
    and passes that mapping to ``ignore_cycles`` to obtain an ordering.

    Args:
        pkg_paths: Paths of the modules to analyse.

    Returns:
        A tuple ``(ordered_files, dependency_graph)`` where
        ``dependency_graph`` maps each module path to a set of the paths it
        imports.
    """
    modules = ModuleSet([str(pkg_path) for pkg_path in pkg_paths])

    dependency_graph = {}
    for module_path in sorted(modules.by_path.keys()):
        dotted_name = ".".join(modules.by_path[module_path].fqn)
        module = modules.by_name[dotted_name]
        try:
            dependency_graph[module_path] = {
                str(imported) for imported in modules.get_imports(module)
            }
        except Exception:
            # Best effort: a module whose imports cannot be resolved is
            # treated as having no dependencies.
            dependency_graph[module_path] = set()

    ordered_files = ignore_cycles(dependency_graph)
    return ordered_files, dependency_graph
231+
232+
233+
def get_target_edit_files(
    local_repo: git.Repo,
    src_dir: str,
    test_dir: str,
    latest_commit: str,
    reference_commit: str,
) -> tuple[list[str], dict]:
    """Find the files containing a ``pass`` stub and order them by imports.

    Candidate files come from ``_find_files_to_edit``; files longer than 1500
    lines or without the `` pass`` marker are skipped. The repository is then
    checked out at ``reference_commit`` to compute import dependencies and is
    always switched back to ``latest_commit`` afterwards, even on failure.

    Args:
        local_repo: Repository whose working directory is scanned.
        src_dir: Source directory, forwarded to ``_find_files_to_edit``.
        test_dir: Test directory, forwarded to ``_find_files_to_edit``.
        latest_commit: Commit to check out again once analysis is done.
        reference_commit: Commit checked out while computing dependencies.

    Returns:
        A tuple ``(files, dependencies)``. ``files`` are paths relative to
        the repository working directory in dependency order; files the sort
        did not return are appended at the end. ``dependencies`` maps each
        relative path to a sorted list of the paths it imports, and has an
        entry (possibly empty) for every path in ``files``.

    Raises:
        ValueError: If the sort returns more files than were selected.
        AssertionError: If the final list length does not match the number
            of selected files.
    """
    target_dir = str(local_repo.working_dir)

    def _strip_base(path: str) -> str:
        # Remove only the leading working-dir prefix; str.replace would also
        # delete later occurrences of the same text inside the path.
        return path.removeprefix(target_dir).lstrip("/")

    files = _find_files_to_edit(target_dir, src_dir, test_dir)
    filtered_files = []
    for file_path in files:
        with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
            content = file.read()
        if len(content.splitlines()) > 1500:
            continue
        if " pass" in content:
            filtered_files.append(file_path)

    # Change to reference commit to get the correct dependencies
    local_repo.git.checkout(reference_commit)
    try:
        topological_sort_files, import_dependencies = (
            topological_sort_based_on_dependencies(filtered_files)
        )
    finally:
        # Change back to latest commit, even if dependency analysis failed,
        # so the working tree is never left on the reference commit.
        local_repo.git.checkout(latest_commit)

    if len(topological_sort_files) > len(filtered_files):
        raise ValueError(
            "topological_sort_files should not be longer than filtered_files"
        )
    if len(topological_sort_files) < len(filtered_files):
        # Append the missing files in their original discovery order so the
        # result is deterministic (set iteration order is not).
        already_sorted = set(topological_sort_files)
        topological_sort_files = topological_sort_files + [
            path
            for path in dict.fromkeys(filtered_files)
            if path not in already_sorted
        ]
    assert len(topological_sort_files) == len(
        filtered_files
    ), "all files should be included"

    # Remove the base_dir prefix
    topological_sort_files = [_strip_base(path) for path in topological_sort_files]

    # Remove the base_dir prefix from import dependencies
    import_dependencies_without_prefix = {}
    for key, value in import_dependencies.items():
        import_dependencies_without_prefix[_strip_base(key)] = sorted(
            _strip_base(dep) for dep in value
        )
    # Callers look up dependencies by file name; make sure every returned
    # file has an entry so that lookup cannot raise KeyError.
    for path in topological_sort_files:
        import_dependencies_without_prefix.setdefault(path, [])

    return topological_sort_files, import_dependencies_without_prefix
212287

213288

214289
def get_message(
@@ -268,6 +343,20 @@ def get_message(
268343
return message_to_agent
269344

270345

346+
def update_message_with_dependencies(message: str, dependencies: list[str]) -> str:
    """Append the contents of the given dependency files to ``message``.

    Args:
        message: The message to extend.
        dependencies: Paths of files whose contents should be appended. Each
            path is passed to ``open`` unchanged, so relative paths resolve
            against the current working directory.
            NOTE(review): callers appear to pass repo-relative paths --
            confirm the process runs from the repository root.

    Returns:
        ``message`` unchanged when there are no dependencies; otherwise
        ``message`` followed by the import-dependencies header and, for each
        dependency, a short banner plus the file's text.

    Raises:
        OSError: If a dependency file cannot be opened.
    """
    if not dependencies:
        return message

    sections = [f"\n{IMPORT_DEPENDENCIES_HEADER}"]
    for dependency in dependencies:
        # Decode explicitly instead of using the platform default, matching
        # how source files are read elsewhere in this module: strip a BOM if
        # present and ignore undecodable bytes.
        with open(dependency, "r", encoding="utf-8-sig", errors="ignore") as file:
            sections.append(
                f"\nHere is the content of the file {dependency}:\n{file.read()}"
            )
    return message + "".join(sections)
358+
359+
271360
def get_specification(specification_pdf_path: Path) -> str:
272361
"""Get the reference for a given specification PDF path."""
273362
# TODO: after pdf_to_text is available, use it to extract the text from the PDF

agent/agents.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def run(
9090
sys.stdout = open(log_file, "a")
9191
sys.stderr = open(log_file, "a")
9292

93+
# Log the message
94+
agent_message_log_file = log_dir / "agent_message.log"
95+
with open(agent_message_log_file, "a") as f:
96+
f.write(f"Message Sent: {message}\n\n")
97+
9398
# Configure httpx and backoff logging
9499
handle_logging("httpx", log_file)
95100
handle_logging("backoff", log_file)

agent/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ def run(
178178
".agent.yaml",
179179
help="Path to the agent config file",
180180
),
181+
commit0_config_file: str = typer.Option(
182+
".commit0.yaml",
183+
help="Path to the commit0 config file",
184+
),
181185
log_dir: str = typer.Option(
182186
str(RUN_AGENT_LOG_DIR.resolve()),
183187
help="Log directory to store the logs",
@@ -202,6 +206,7 @@ def run(
202206
override_previous_changes,
203207
backend,
204208
agent_config_file,
209+
commit0_config_file,
205210
log_dir,
206211
max_parallel_repos,
207212
display_repo_progress_num,
@@ -212,6 +217,7 @@ def run(
212217
override_previous_changes,
213218
backend,
214219
agent_config_file,
220+
commit0_config_file,
215221
log_dir,
216222
max_parallel_repos,
217223
)

agent/display.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from rich.align import Align
1818
from collections import OrderedDict
1919
from types import TracebackType
20+
import json
21+
from datetime import datetime
2022

2123

2224
class RepoBox:
@@ -404,3 +406,29 @@ def __exit__(
404406
f"{'Total':<30} {self.total_time_spent:>13.2f}s {total_files:>18} {total_money:>13.2f}$"
405407
)
406408
print("-" * 80)
409+
410+
# Write summary to JSON file
411+
412+
summary_data = {
413+
"timestamp": datetime.now().isoformat(),
414+
"total_time_spent": self.total_time_spent,
415+
"total_files_processed": total_files,
416+
"total_money_spent": total_money,
417+
"repositories": [
418+
{
419+
"name": repo_name,
420+
"time_spent": self.end_time_per_repo[repo_name]
421+
- self.start_time_per_repo[repo_name],
422+
"files_processed": self.total_files_per_repo[repo_name],
423+
"money_spent": sum(
424+
self.repo_money_spent.get(repo_name, {}).values()
425+
),
426+
}
427+
for repo_name in self.end_time_per_repo
428+
],
429+
}
430+
431+
with open("processing_summary.json", "w") as json_file:
432+
json.dump(summary_data, json_file, indent=4)
433+
434+
print("\nSummary has been written to processing_summary.json")

agent/run_agent.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
create_branch,
88
get_message,
99
get_target_edit_files,
10+
update_message_with_dependencies,
1011
get_lint_cmd,
1112
read_yaml_config,
1213
)
@@ -66,13 +67,6 @@ def run_agent_for_repo(
6667
repo_path = os.path.join(repo_base_dir, repo_name)
6768
repo_path = os.path.abspath(repo_path)
6869

69-
target_edit_files = get_target_edit_files(
70-
repo_path, example["src_dir"], example["test"]["test_dir"]
71-
)
72-
# Call the commit0 get-tests command to retrieve test files
73-
test_files_str = get_tests(repo_name, verbose=0)
74-
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
75-
7670
try:
7771
local_repo = Repo(repo_path)
7872
except Exception:
@@ -90,7 +84,6 @@ def run_agent_for_repo(
9084
# # if branch_name is not provided, create a new branch name based on agent_config
9185
# if branch is None:
9286
# branch = args2string(agent_config)
93-
9487
create_branch(local_repo, branch, example["base_commit"])
9588

9689
# in cases where the latest commit of branch is not commit 0
@@ -99,6 +92,17 @@ def run_agent_for_repo(
9992
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
10093
local_repo.git.reset("--hard", example["base_commit"])
10194

95+
target_edit_files, import_dependencies = get_target_edit_files(
96+
local_repo,
97+
example["src_dir"],
98+
example["test"]["test_dir"],
99+
str(latest_commit),
100+
example["reference_commit"],
101+
)
102+
# Call the commit0 get-tests command to retrieve test files
103+
test_files_str = get_tests(repo_name, verbose=0)
104+
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
105+
102106
# prepare the log dir
103107
experiment_log_dir = (
104108
Path(log_dir)
@@ -158,6 +162,8 @@ def run_agent_for_repo(
158162
)
159163
for f in target_edit_files:
160164
update_queue.put(("set_current_file", (repo_name, f)))
165+
dependencies = import_dependencies[f]
166+
message = update_message_with_dependencies(message, dependencies)
161167
file_name = f.replace(".py", "").replace("/", "__")
162168
file_log_dir = experiment_log_dir / file_name
163169
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
@@ -176,6 +182,7 @@ def run_agent(
176182
override_previous_changes: bool,
177183
backend: str,
178184
agent_config_file: str,
185+
commit0_config_file: str,
179186
log_dir: str,
180187
max_parallel_repos: int,
181188
display_repo_progress_num: int,
@@ -185,7 +192,7 @@ def run_agent(
185192

186193
agent_config = AgentConfig(**config)
187194

188-
commit0_config = read_commit0_dot_file(".commit0.yaml")
195+
commit0_config = read_commit0_dot_file(commit0_config_file)
189196

190197
dataset = load_dataset(
191198
commit0_config["dataset_name"], split=commit0_config["dataset_split"]

agent/run_agent_no_rich.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
create_branch,
1010
get_message,
1111
get_target_edit_files,
12+
update_message_with_dependencies,
1213
get_lint_cmd,
1314
read_yaml_config,
1415
)
@@ -61,14 +62,6 @@ def run_agent_for_repo(
6162
repo_path = os.path.join(repo_base_dir, repo_name)
6263
repo_path = os.path.abspath(repo_path)
6364

64-
# get target files to edit and test files to run
65-
target_edit_files = get_target_edit_files(
66-
repo_path, example["src_dir"], example["test"]["test_dir"]
67-
)
68-
# Call the commit0 get-tests command to retrieve test files
69-
test_files_str = get_tests(repo_name, verbose=0)
70-
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
71-
7265
try:
7366
local_repo = Repo(repo_path)
7467
except Exception:
@@ -95,6 +88,19 @@ def run_agent_for_repo(
9588
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
9689
local_repo.git.reset("--hard", example["base_commit"])
9790

91+
# get target files to edit and test files to run
92+
target_edit_files, import_dependencies = get_target_edit_files(
93+
local_repo,
94+
example["src_dir"],
95+
example["test"]["test_dir"],
96+
str(latest_commit),
97+
str(example["reference_commit"]),
98+
)
99+
100+
# Call the commit0 get-tests command to retrieve test files
101+
test_files_str = get_tests(repo_name, verbose=0)
102+
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
103+
98104
# prepare the log dir
99105
experiment_log_dir = (
100106
Path(log_dir)
@@ -139,6 +145,8 @@ def run_agent_for_repo(
139145
agent_config, repo_path, test_dir=example["test"]["test_dir"]
140146
)
141147
for f in target_edit_files:
148+
dependencies = import_dependencies[f]
149+
message = update_message_with_dependencies(message, dependencies)
142150
file_name = f.replace(".py", "").replace("/", "__")
143151
file_log_dir = experiment_log_dir / file_name
144152
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
@@ -151,6 +159,7 @@ def run_agent(
151159
override_previous_changes: bool,
152160
backend: str,
153161
agent_config_file: str,
162+
commit0_config_file: str,
154163
log_dir: str,
155164
max_parallel_repos: int,
156165
) -> None:
@@ -162,7 +171,7 @@ def run_agent(
162171

163172
agent_config = AgentConfig(**config)
164173

165-
commit0_config = read_commit0_dot_file(".commit0.yaml")
174+
commit0_config = read_commit0_dot_file(commit0_config_file)
166175

167176
dataset = load_dataset(
168177
commit0_config["dataset_name"], split=commit0_config["dataset_split"]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ requires-python = ">=3.11"
1111
dependencies = [
1212
"ruff>=0.6.4",
1313
"pre-commit>=3.8.0",
14+
"import-deps>=0.3.0",
1415
"PyMuPDF>=1.24.5",
1516
"modal==0.64.95",
1617
"typer>=0.12.0",

0 commit comments

Comments
 (0)