Skip to content

Commit 954d5ce

Browse files
authored
fix: arranges tools properly (#478)
1 parent ce10045 commit 954d5ce

19 files changed

+689
-53
lines changed

src/codegen/extensions/langchain/agent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
from .tools import (
1414
CommitTool,
1515
CreateFileTool,
16+
CreatePRCommentTool,
17+
CreatePRReviewCommentTool,
18+
CreatePRTool,
1619
DeleteFileTool,
1720
EditFileTool,
21+
GetPRcontentsTool,
1822
ListDirectoryTool,
1923
MoveSymbolTool,
2024
RenameFileTool,
@@ -64,6 +68,10 @@ def create_codebase_agent(
6468
SemanticEditTool(codebase),
6569
SemanticSearchTool(codebase),
6670
CommitTool(codebase),
71+
CreatePRTool(codebase),
72+
GetPRcontentsTool(codebase),
73+
CreatePRCommentTool(codebase),
74+
CreatePRReviewCommentTool(codebase),
6775
]
6876

6977
# Get the prompt to use

src/codegen/extensions/langchain/tools.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Langchain tools for workspace operations."""
22

33
import json
4-
import uuid
54
from typing import ClassVar, Literal, Optional
65

76
from langchain.tools import BaseTool
@@ -12,6 +11,9 @@
1211
from ..tools import (
1312
commit,
1413
create_file,
14+
create_pr,
15+
create_pr_comment,
16+
create_pr_review_comment,
1517
delete_file,
1618
edit_file,
1719
list_directory,
@@ -22,6 +24,7 @@
2224
semantic_edit,
2325
semantic_search,
2426
view_file,
27+
view_pr,
2528
)
2629

2730

@@ -205,12 +208,11 @@ def _run(
205208
collect_dependencies: bool = True,
206209
collect_usages: bool = True,
207210
) -> str:
208-
# Find the symbol first
209-
found_symbol = self.codebase.get_symbol(symbol_name)
210211
result = reveal_symbol(
211-
found_symbol,
212-
degree,
213-
max_tokens,
212+
codebase=self.codebase,
213+
symbol_name=symbol_name,
214+
degree=degree,
215+
max_tokens=max_tokens,
214216
collect_dependencies=collect_dependencies,
215217
collect_usages=collect_usages,
216218
)
@@ -356,11 +358,8 @@ def __init__(self, codebase: Codebase) -> None:
356358
super().__init__(codebase=codebase)
357359

358360
def _run(self, title: str, body: str) -> str:
359-
if self.codebase._op.git_cli.active_branch.name == self.codebase._op.default_branch:
360-
# If the current checked out branch is the default branch, checkout onto a new branch
361-
self.codebase.checkout(branch=f"{uuid.uuid4()}", create_if_missing=True)
362-
pr = self.codebase.create_pr(title=title, body=body)
363-
return pr.html_url
361+
result = create_pr(self.codebase, title, body)
362+
return json.dumps(result, indent=2)
364363

365364

366365
class GetPRContentsInput(BaseModel):
@@ -381,11 +380,7 @@ def __init__(self, codebase: Codebase) -> None:
381380
super().__init__(codebase=codebase)
382381

383382
def _run(self, pr_id: int) -> str:
384-
modified_symbols, patch = self.codebase.get_modified_symbols_in_pr(pr_id)
385-
386-
# Convert modified_symbols set to list for JSON serialization
387-
result = {"modified_symbols": list(modified_symbols), "patch": patch}
388-
383+
result = view_pr(self.codebase, pr_id)
389384
return json.dumps(result, indent=2)
390385

391386

@@ -408,8 +403,8 @@ def __init__(self, codebase: Codebase) -> None:
408403
super().__init__(codebase=codebase)
409404

410405
def _run(self, pr_number: int, body: str) -> str:
411-
self.codebase.create_pr_comment(pr_number=pr_number, body=body)
412-
return "Comment created successfully"
406+
result = create_pr_comment(self.codebase, pr_number, body)
407+
return json.dumps(result, indent=2)
413408

414409

415410
class CreatePRReviewCommentInput(BaseModel):
@@ -445,7 +440,8 @@ def _run(
445440
side: str | None = None,
446441
start_line: int | None = None,
447442
) -> str:
448-
self.codebase.create_pr_review_comment(
443+
result = create_pr_review_comment(
444+
self.codebase,
449445
pr_number=pr_number,
450446
body=body,
451447
commit_sha=commit_sha,
@@ -454,7 +450,7 @@ def _run(
454450
side=side,
455451
start_line=start_line,
456452
)
457-
return "Review comment created successfully"
453+
return json.dumps(result, indent=2)
458454

459455

460456
def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
@@ -476,8 +472,11 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
476472
EditFileTool(codebase),
477473
GetPRcontentsTool(codebase),
478474
ListDirectoryTool(codebase),
475+
MoveSymbolTool(codebase),
476+
RenameFileTool(codebase),
479477
RevealSymbolTool(codebase),
480478
SearchTool(codebase),
481479
SemanticEditTool(codebase),
480+
SemanticSearchTool(codebase),
482481
ViewFileTool(codebase),
483482
]

src/codegen/extensions/mcp/codebase_tools.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
mcp = FastMCP(
1212
"codebase-tools-mcp",
13-
instructions="Use this server to access any information from your codebase. This tool can provide information ranging from AST Symbol details and information from across the codebase. Use this tool for all questions, queries regarding your codebase.",
13+
instructions="""Use this server to access any information from your codebase. This tool can provide information ranging from AST Symbol details and information from across the codebase.
14+
Use this tool for all questions, queries regarding your codebase.""",
1415
)
1516

1617

@@ -20,21 +21,16 @@ def reveal_symbol_tool(
2021
target_file: Annotated[Optional[str], "The file path of the file containing the symbol to inspect"],
2122
codebase_dir: Annotated[str, "The root directory of your codebase"],
2223
codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"],
23-
degree: Annotated[Optional[int], "depth do which symbol information is retrieved"],
24+
max_depth: Annotated[Optional[int], "depth up to which symbol information is retrieved"],
2425
collect_dependencies: Annotated[Optional[bool], "includes dependencies of symbol"],
2526
collect_usages: Annotated[Optional[bool], "includes usages of symbol"],
2627
):
2728
codebase = Codebase(repo_path=codebase_dir, programming_language=codebase_language)
28-
found_symbol = None
29-
if target_file:
30-
file = codebase.get_file(target_file)
31-
found_symbol = file.get_symbol(symbol_name)
32-
else:
33-
found_symbol = codebase.get_symbol(symbol_name)
34-
3529
result = reveal_symbol(
36-
found_symbol,
37-
degree,
30+
codebase=codebase,
31+
symbol_name=symbol_name,
32+
filepath=target_file,
33+
max_depth=max_depth,
3834
collect_dependencies=collect_dependencies,
3935
collect_usages=collect_usages,
4036
)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Tools
2+
3+
- should take in a `codebase` and string args
4+
- gets "wrapped" by extensions, e.g. MCP or Langchain
Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,42 @@
11
"""Tools for workspace operations."""
22

3-
from .file_operations import (
4-
commit,
5-
create_file,
6-
delete_file,
7-
edit_file,
8-
list_directory,
9-
move_symbol,
10-
rename_file,
11-
view_file,
12-
)
3+
from .commit import commit
4+
from .create_file import create_file
5+
from .delete_file import delete_file
6+
from .edit_file import edit_file
7+
from .github.create_pr import create_pr
8+
from .github.create_pr_comment import create_pr_comment
9+
from .github.create_pr_review_comment import create_pr_review_comment
10+
from .github.view_pr import view_pr
11+
from .list_directory import list_directory
12+
from .move_symbol import move_symbol
13+
from .rename_file import rename_file
1314
from .reveal_symbol import reveal_symbol
1415
from .search import search
1516
from .semantic_edit import semantic_edit
1617
from .semantic_search import semantic_search
18+
from .view_file import view_file
1719

1820
__all__ = [
21+
# Git operations
1922
"commit",
23+
# File operations
2024
"create_file",
25+
"create_pr",
26+
"create_pr_comment",
27+
"create_pr_review_comment",
2128
"delete_file",
2229
"edit_file",
2330
"list_directory",
24-
# Symbol analysis
31+
# Symbol operations
2532
"move_symbol",
26-
# File operations
2733
"rename_file",
2834
"reveal_symbol",
29-
# Search
35+
# Search operations
3036
"search",
31-
# Semantic edit
37+
# Edit operations
3238
"semantic_edit",
3339
"semantic_search",
3440
"view_file",
41+
"view_pr",
3542
]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Tool for committing changes to disk."""
2+
3+
from typing import Any
4+
5+
from codegen import Codebase
6+
7+
8+
def commit(codebase: Codebase) -> dict[str, Any]:
9+
"""Commit any pending changes to disk.
10+
11+
Args:
12+
codebase: The codebase to operate on
13+
14+
Returns:
15+
Dict containing commit status
16+
"""
17+
codebase.commit()
18+
return {"status": "success", "message": "Changes committed to disk"}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Tool for creating new files."""
2+
3+
from typing import Any
4+
5+
from codegen import Codebase
6+
7+
from .view_file import view_file
8+
9+
10+
def create_file(codebase: Codebase, filepath: str, content: str = "") -> dict[str, Any]:
11+
"""Create a new file.
12+
13+
Args:
14+
codebase: The codebase to operate on
15+
filepath: Path where to create the file
16+
content: Initial file content
17+
18+
Returns:
19+
Dict containing new file state, or error information if file already exists
20+
"""
21+
if codebase.has_file(filepath):
22+
return {"error": f"File already exists: {filepath}"}
23+
file = codebase.create_file(filepath, content=content)
24+
codebase.commit()
25+
return view_file(codebase, filepath)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Tool for deleting files."""
2+
3+
from typing import Any
4+
5+
from codegen import Codebase
6+
7+
8+
def delete_file(codebase: Codebase, filepath: str) -> dict[str, Any]:
9+
"""Delete a file.
10+
11+
Args:
12+
codebase: The codebase to operate on
13+
filepath: Path to the file to delete
14+
15+
Returns:
16+
Dict containing deletion status, or error information if file not found
17+
"""
18+
try:
19+
file = codebase.get_file(filepath)
20+
except ValueError:
21+
return {"error": f"File not found: {filepath}"}
22+
if file is None:
23+
return {"error": f"File not found: {filepath}"}
24+
25+
file.remove()
26+
codebase.commit()
27+
return {"status": "success", "deleted_file": filepath}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Tool for editing file contents."""
2+
3+
from typing import Any
4+
5+
from codegen import Codebase
6+
7+
from .view_file import view_file
8+
9+
10+
def edit_file(codebase: Codebase, filepath: str, content: str) -> dict[str, Any]:
11+
"""Edit a file by replacing its entire content.
12+
13+
Args:
14+
codebase: The codebase to operate on
15+
filepath: Path to the file to edit
16+
content: New content for the file
17+
18+
Returns:
19+
Dict containing updated file state, or error information if file not found
20+
"""
21+
try:
22+
file = codebase.get_file(filepath)
23+
except ValueError:
24+
return {"error": f"File not found: {filepath}"}
25+
if file is None:
26+
return {"error": f"File not found: {filepath}"}
27+
28+
file.edit(content)
29+
codebase.commit()
30+
return view_file(codebase, filepath)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Tool for creating pull requests."""
2+
3+
import uuid
4+
from typing import Any
5+
6+
from codegen import Codebase
7+
8+
9+
def create_pr(codebase: Codebase, title: str, body: str) -> dict[str, Any]:
10+
"""Create a PR for the current branch.
11+
12+
Args:
13+
codebase: The codebase to operate on
14+
title: The title of the PR
15+
body: The body/description of the PR
16+
17+
Returns:
18+
Dict containing PR info, or error information if operation fails
19+
"""
20+
try:
21+
# If on default branch, create a new branch
22+
if codebase._op.git_cli.active_branch.name == codebase._op.default_branch:
23+
codebase.checkout(branch=f"{uuid.uuid4()}", create_if_missing=True)
24+
25+
# Create the PR
26+
pr = codebase.create_pr(title=title, body=body)
27+
return {
28+
"status": "success",
29+
"url": pr.html_url,
30+
"number": pr.number,
31+
"title": pr.title,
32+
}
33+
except Exception as e:
34+
return {"error": f"Failed to create PR: {e!s}"}

0 commit comments

Comments
 (0)