Skip to content

Commit 2981829

Browse files
Zeeeeparushilpatel0
andcommitted
Apply changes from commit 046b238
Original commit by Tawsif Kamal: Revert "Revert "Adding Schema for Tool Outputs"" (codegen-sh#894) Reverts codegen-sh#892 --------- Co-authored-by: Rushil Patel <[email protected]> Co-authored-by: rushilpatel0 <[email protected]>
1 parent aed3fe0 commit 2981829

File tree

12 files changed

+457
-71
lines changed

12 files changed

+457
-71
lines changed

src/codegen/agents/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class ToolMessageData(BaseMessage):
5454
tool_id: Optional[str] = None
5555
status: Optional[str] = None
5656

57+
5758
@dataclass
5859
class FunctionMessageData(BaseMessage):
5960
"""Represents a function message."""

src/codegen/agents/tracer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage
7878
tool_response=getattr(latest_message, "artifact", content),
7979
tool_id=getattr(latest_message, "tool_call_id", None),
8080
status=getattr(latest_message, "status", None),
81-
) elif message_type == "function":
81+
)
82+
elif message_type == "function":
8283
return FunctionMessageData(type=message_type, content=content)
8384
else:
8485
return UnknownMessage(type=message_type, content=content)

src/codegen/extensions/langchain/tools.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Langchain tools for workspace operations."""
22

33
from collections.abc import Callable
4-
from typing import ClassVar, Literal
4+
from typing import Annotated, ClassVar, Literal, Optional
55

6+
from langchain_core.messages import ToolMessage
7+
from langchain_core.tools import InjectedToolCallId
68
from langchain_core.tools.base import BaseTool
79
from pydantic import BaseModel, Field
810

@@ -52,10 +54,11 @@ class ViewFileInput(BaseModel):
5254
"""Input for viewing a file."""
5355

5456
filepath: str = Field(..., description="Path to the file relative to workspace root")
55-
start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)")
56-
end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)")
57-
max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 500")
58-
line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)")
57+
start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)")
58+
end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)")
59+
max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 500")
60+
line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)")
61+
tool_call_id: Annotated[str, InjectedToolCallId]
5962

6063

6164
class ViewFileTool(BaseTool):
@@ -73,12 +76,13 @@ def __init__(self, codebase: Codebase) -> None:
7376

7477
def _run(
7578
self,
79+
tool_call_id: str,
7680
filepath: str,
77-
start_line: int | None = None,
78-
end_line: int | None = None,
79-
max_lines: int | None = None,
80-
line_numbers: bool | None = True,
81-
) -> str:
81+
start_line: Optional[int] = None,
82+
end_line: Optional[int] = None,
83+
max_lines: Optional[int] = None,
84+
line_numbers: Optional[bool] = True,
85+
) -> ToolMessage:
8286
result = view_file(
8387
self.codebase,
8488
filepath,
@@ -88,14 +92,15 @@ def _run(
8892
max_lines=max_lines if max_lines is not None else 500,
8993
)
9094

91-
return result.render()
95+
return result.render(tool_call_id)
9296

9397

9498
class ListDirectoryInput(BaseModel):
9599
"""Input for listing directory contents."""
96100

97101
dirpath: str = Field(default="./", description="Path to directory relative to workspace root")
98102
depth: int = Field(default=1, description="How deep to traverse. Use -1 for unlimited depth.")
103+
tool_call_id: Annotated[str, InjectedToolCallId]
99104

100105

101106
class ListDirectoryTool(BaseTool):
@@ -109,9 +114,9 @@ class ListDirectoryTool(BaseTool):
109114
def __init__(self, codebase: Codebase) -> None:
110115
super().__init__(codebase=codebase)
111116

112-
def _run(self, dirpath: str = "./", depth: int = 1) -> str:
117+
def _run(self, tool_call_id: str, dirpath: str = "./", depth: int = 1) -> ToolMessage:
113118
result = list_directory(self.codebase, dirpath, depth)
114-
return result.render()
119+
return result.render(tool_call_id)
115120

116121

117122
class SearchInput(BaseModel):
@@ -126,6 +131,7 @@ class SearchInput(BaseModel):
126131
page: int = Field(default=1, description="Page number to return (1-based, default: 1)")
127132
files_per_page: int = Field(default=10, description="Number of files to return per page (default: 10)")
128133
use_regex: bool = Field(default=False, description="Whether to treat query as a regex pattern (default: False)")
134+
tool_call_id: Annotated[str, InjectedToolCallId]
129135

130136

131137
class SearchTool(BaseTool):
@@ -139,16 +145,17 @@ class SearchTool(BaseTool):
139145
def __init__(self, codebase: Codebase) -> None:
140146
super().__init__(codebase=codebase)
141147

142-
def _run(self, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str:
148+
def _run(self, tool_call_id: str, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage:
143149
result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex)
144-
return result.render()
150+
return result.render(tool_call_id)
145151

146152

147153
class EditFileInput(BaseModel):
148154
"""Input for editing a file."""
149155

150156
filepath: str = Field(..., description="Path to the file to edit")
151157
content: str = Field(..., description="New content for the file")
158+
tool_call_id: Annotated[str, InjectedToolCallId]
152159

153160

154161
class EditFileTool(BaseTool):
@@ -181,9 +188,9 @@ class EditFileTool(BaseTool):
181188
def __init__(self, codebase: Codebase) -> None:
182189
super().__init__(codebase=codebase)
183190

184-
def _run(self, filepath: str, content: str) -> str:
191+
def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
185192
result = edit_file(self.codebase, filepath, content)
186-
return result.render()
193+
return result.render(tool_call_id)
187194

188195

189196
class CreateFileInput(BaseModel):
@@ -340,6 +347,7 @@ class SemanticEditInput(BaseModel):
340347
edit_content: str = Field(..., description=FILE_EDIT_PROMPT)
341348
start: int = Field(default=1, description="Starting line number (1-indexed, inclusive). Default is 1.")
342349
end: int = Field(default=-1, description="Ending line number (1-indexed, inclusive). Default is -1 (end of file).")
350+
tool_call_id: Annotated[str, InjectedToolCallId]
343351

344352

345353
class SemanticEditTool(BaseTool):
@@ -353,10 +361,10 @@ class SemanticEditTool(BaseTool):
353361
def __init__(self, codebase: Codebase) -> None:
354362
super().__init__(codebase=codebase)
355363

356-
def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
364+
def _run(self, filepath: str, tool_call_id: str, edit_content: str, start: int = 1, end: int = -1) -> ToolMessage:
357365
# Create the the draft editor mini llm
358366
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
359-
return result.render()
367+
return result.render(tool_call_id)
360368

361369

362370
class RenameFileInput(BaseModel):
@@ -1033,6 +1041,7 @@ class RelaceEditInput(BaseModel):
10331041

10341042
filepath: str = Field(..., description="Path of the file relative to workspace root")
10351043
edit_snippet: str = Field(..., description=RELACE_EDIT_PROMPT)
1044+
tool_call_id: Annotated[str, InjectedToolCallId]
10361045

10371046

10381047
class RelaceEditTool(BaseTool):
@@ -1046,9 +1055,9 @@ class RelaceEditTool(BaseTool):
10461055
def __init__(self, codebase: Codebase) -> None:
10471056
super().__init__(codebase=codebase)
10481057

1049-
def _run(self, filepath: str, edit_snippet: str) -> str:
1058+
def _run(self, filepath: str, edit_snippet: str, tool_call_id: str) -> ToolMessage:
10501059
result = relace_edit(self.codebase, filepath, edit_snippet)
1051-
return result.render()
1060+
return result.render(tool_call_id=tool_call_id)
10521061

10531062

10541063
class ReflectionInput(BaseModel):

src/codegen/extensions/tools/edit_file.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,53 @@
11
"""Tool for editing file contents."""
22

3-
from typing import ClassVar
3+
from typing import TYPE_CHECKING, ClassVar, Optional
44

5+
from langchain_core.messages import ToolMessage
56
from pydantic import Field
67

78
from codegen.sdk.core.codebase import Codebase
89

910
from .observation import Observation
1011
from .replacement_edit import generate_diff
1112

13+
if TYPE_CHECKING:
14+
from .tool_output_types import EditFileArtifacts
15+
1216

1317
class EditFileObservation(Observation):
1418
"""Response from editing a file."""
1519

1620
filepath: str = Field(
1721
description="Path to the edited file",
1822
)
19-
diff: str = Field(
23+
diff: Optional[str] = Field(
24+
default=None,
2025
description="Unified diff showing the changes made",
2126
)
2227

2328
str_template: ClassVar[str] = "Edited file {filepath}"
2429

25-
def render(self) -> str:
30+
def render(self, tool_call_id: str) -> ToolMessage:
2631
"""Render edit results in a clean format."""
27-
return f"""[EDIT FILE]: {self.filepath}
28-
29-
{self.diff}"""
32+
if self.status == "error":
33+
artifacts_error: EditFileArtifacts = {"filepath": self.filepath, "error": self.error}
34+
return ToolMessage(
35+
content=f"[ERROR EDITING FILE]: {self.filepath}: {self.error}",
36+
status=self.status,
37+
name="edit_file",
38+
artifact=artifacts_error,
39+
tool_call_id=tool_call_id,
40+
)
41+
42+
artifacts_success: EditFileArtifacts = {"filepath": self.filepath, "diff": self.diff}
43+
44+
return ToolMessage(
45+
content=f"""[EDIT FILE]: {self.filepath}\n\n{self.diff}""",
46+
status=self.status,
47+
name="edit_file",
48+
artifact=artifacts_success,
49+
tool_call_id=tool_call_id,
50+
)
3051

3152

3253
def edit_file(codebase: Codebase, filepath: str, new_content: str) -> EditFileObservation:

src/codegen/extensions/tools/list_directory.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
from typing import ClassVar
44

5+
from langchain_core.messages import ToolMessage
56
from pydantic import Field
67

8+
from codegen.extensions.tools.observation import Observation
9+
from codegen.extensions.tools.tool_output_types import ListDirectoryArtifacts
710
from codegen.sdk.core.codebase import Codebase
811
from codegen.sdk.core.directory import Directory
912

10-
from .observation import Observation
11-
1213

1314
class DirectoryInfo(Observation):
1415
"""Information about a directory."""
@@ -31,6 +32,14 @@ class DirectoryInfo(Observation):
3132
default=False,
3233
description="Whether this is a leaf node (at max depth)",
3334
)
35+
depth: int = Field(
36+
default=0,
37+
description="Current depth in the tree",
38+
)
39+
max_depth: int = Field(
40+
default=1,
41+
description="Maximum depth allowed",
42+
)
3443

3544
str_template: ClassVar[str] = "Directory {path} ({file_count} files, {dir_count} subdirs)"
3645

@@ -41,7 +50,7 @@ def _get_details(self) -> dict[str, int]:
4150
"dir_count": len(self.subdirectories),
4251
}
4352

44-
def render(self) -> str:
53+
def render_as_string(self) -> str:
4554
"""Render directory listing as a file tree."""
4655
lines = [
4756
f"[LIST DIRECTORY]: {self.path}",
@@ -97,6 +106,26 @@ def build_tree(items: list[tuple[str, bool, "DirectoryInfo | None"]], prefix: st
97106

98107
return "\n".join(lines)
99108

109+
def to_artifacts(self) -> ListDirectoryArtifacts:
110+
"""Convert directory info to artifacts for UI."""
111+
artifacts: ListDirectoryArtifacts = {
112+
"dirpath": self.path,
113+
"name": self.name,
114+
"is_leaf": self.is_leaf,
115+
"depth": self.depth,
116+
"max_depth": self.max_depth,
117+
}
118+
119+
if self.files is not None:
120+
artifacts["files"] = self.files
121+
artifacts["file_paths"] = [f"{self.path}/{f}" for f in self.files]
122+
123+
if self.subdirectories:
124+
artifacts["subdirs"] = [d.name for d in self.subdirectories]
125+
artifacts["subdir_paths"] = [d.path for d in self.subdirectories]
126+
127+
return artifacts
128+
100129

101130
class ListDirectoryObservation(Observation):
102131
"""Response from listing directory contents."""
@@ -107,9 +136,29 @@ class ListDirectoryObservation(Observation):
107136

108137
str_template: ClassVar[str] = "{directory_info}"
109138

110-
def render(self) -> str:
111-
"""Render directory listing."""
112-
return self.directory_info.render()
139+
def render(self, tool_call_id: str) -> ToolMessage:
140+
"""Render directory listing with artifacts for UI."""
141+
if self.status == "error":
142+
error_artifacts: ListDirectoryArtifacts = {
143+
"dirpath": self.directory_info.path,
144+
"name": self.directory_info.name,
145+
"error": self.error,
146+
}
147+
return ToolMessage(
148+
content=f"[ERROR LISTING DIRECTORY]: {self.directory_info.path}: {self.error}",
149+
status=self.status,
150+
name="list_directory",
151+
artifact=error_artifacts,
152+
tool_call_id=tool_call_id,
153+
)
154+
155+
return ToolMessage(
156+
content=self.directory_info.render_as_string(),
157+
status=self.status,
158+
name="list_directory",
159+
artifact=self.directory_info.to_artifacts(),
160+
tool_call_id=tool_call_id,
161+
)
113162

114163

115164
def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> ListDirectoryObservation:
@@ -136,7 +185,7 @@ def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> List
136185
),
137186
)
138187

139-
def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
188+
def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -> DirectoryInfo:
140189
"""Helper function to get directory info recursively."""
141190
# Get direct files (always include files unless at max depth)
142191
all_files = []
@@ -151,7 +200,7 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
151200
if current_depth > 1 or current_depth == -1:
152201
# For deeper traversal, get full directory info
153202
new_depth = current_depth - 1 if current_depth > 1 else -1
154-
subdirs.append(get_directory_info(subdir, new_depth))
203+
subdirs.append(get_directory_info(subdir, new_depth, max_depth))
155204
else:
156205
# At max depth, return a leaf node
157206
subdirs.append(
@@ -161,6 +210,8 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
161210
path=subdir.dirpath,
162211
files=None, # Don't include files at max depth
163212
is_leaf=True,
213+
depth=current_depth,
214+
max_depth=max_depth,
164215
)
165216
)
166217

@@ -170,9 +221,11 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
170221
path=dir_obj.dirpath,
171222
files=sorted(all_files),
172223
subdirectories=subdirs,
224+
depth=current_depth,
225+
max_depth=max_depth,
173226
)
174227

175-
dir_info = get_directory_info(directory, depth)
228+
dir_info = get_directory_info(directory, depth, depth)
176229
return ListDirectoryObservation(
177230
status="success",
178231
directory_info=dir_info,

0 commit comments

Comments
 (0)