1
1
import git
2
2
import os
3
3
import re
4
- import subprocess
5
4
from dataclasses import asdict
6
5
from pathlib import Path
7
6
from typing import List
7
+ import fitz
8
8
9
9
from baselines .class_types import AgentConfig
10
10
13
13
REPO_INFO_HEADER = "\n \n >>> Here is the Repository Information:\n "
14
14
UNIT_TESTS_INFO_HEADER = "\n \n >>> Here are the Unit Tests Information:\n "
15
15
LINT_INFO_HEADER = "\n \n >>> Here is the Lint Information:\n "
16
-
16
+ SPEC_INFO_HEADER = " \n \n >>> Here is the Specification Information: \n "
17
17
# prefix components:
18
18
space = " "
19
19
branch = "│ "
@@ -122,14 +122,14 @@ def get_target_edit_files(target_dir: str) -> list[str]:
122
122
"""Find the files with the error 'NotImplementedError('IMPLEMENT ME
123
123
HERE')'.
124
124
"""
125
- # The grep command
126
- command = f"grep -R -l \" NotImplementedError('IMPLEMENT ME HERE') \" { target_dir } "
127
-
128
- # Run the command and capture the output
129
- result = subprocess . run ( command , shell = True , capture_output = True , text = True )
130
-
131
- # Split the output into lines and remove the base_dir prefix
132
- files = result . stdout . strip (). split ( " \n " )
125
+ files = []
126
+ for root , _ , filenames in os . walk ( target_dir ):
127
+ for filename in filenames :
128
+ if filename . endswith ( ".py" ):
129
+ file_path = os . path . join ( root , filename )
130
+ with open ( file_path , "r" ) as file :
131
+ if "NotImplementedError('IMPLEMENT ME HERE')" in file . read ():
132
+ files . append ( file_path )
133
133
134
134
# Remove the base_dir prefix
135
135
files = [file .replace (target_dir , "" ).lstrip ("/" ) for file in files ]
@@ -143,7 +143,8 @@ def get_target_edit_files(target_dir: str) -> list[str]:
143
143
def get_message (
144
144
agent_config : AgentConfig ,
145
145
repo_path : str ,
146
- test_dir : str ,
146
+ test_dir : str | None = None ,
147
+ test_file : str | None = None ,
147
148
) -> str :
148
149
"""Get the message to Aider."""
149
150
prompt = f"{ PROMPT_HEADER } " + agent_config .user_prompt
@@ -157,6 +158,13 @@ def get_message(
157
158
include_stubs = True ,
158
159
)[: agent_config .max_unit_tests_info_length ]
159
160
)
161
+ elif agent_config .use_unit_tests_info and test_file :
162
+ unit_tests_info = (
163
+ f"\n { UNIT_TESTS_INFO_HEADER } "
164
+ + get_file_info (
165
+ file_path = Path (os .path .join (repo_path , test_file )), prefix = ""
166
+ )[: agent_config .max_unit_tests_info_length ]
167
+ )
160
168
else :
161
169
unit_tests_info = ""
162
170
@@ -171,15 +179,34 @@ def get_message(
171
179
else :
172
180
repo_info = ""
173
181
174
- message_to_agent = prompt + repo_info + unit_tests_info
182
+ if agent_config .use_spec_info :
183
+ spec_info = (
184
+ f"\n { SPEC_INFO_HEADER } "
185
+ + get_specification (specification_pdf_path = Path (repo_path , "spec.pdf" ))[
186
+ : agent_config .max_spec_info_length
187
+ ]
188
+ )
189
+ else :
190
+ spec_info = ""
191
+
192
+ message_to_agent = prompt + repo_info + unit_tests_info + spec_info
175
193
176
194
return message_to_agent
177
195
178
196
179
- def get_reference (specification_pdf_path : str ) -> str :
197
+ def get_specification (specification_pdf_path : Path ) -> str :
180
198
"""Get the reference for a given specification PDF path."""
181
199
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
182
- return f"/pdf { specification_pdf_path } "
200
+ # Open the specified PDF file
201
+ document = fitz .open (specification_pdf_path )
202
+ text = ""
203
+
204
+ # Iterate through the pages
205
+ for page_num in range (len (document )):
206
+ page = document .load_page (page_num ) # loads the specified page
207
+ text += page .get_text () # type: ignore
208
+
209
+ return text
183
210
184
211
185
212
def create_branch (repo : git .Repo , branch : str , from_commit : str ) -> None :
0 commit comments