Skip to content

Commit ebbfc14

Browse files
jayhackrushilpatel0
authored andcommitted
chore: better tools arrangement (#509)
1 parent e23d8ac commit ebbfc14

File tree

11 files changed

+104
-30
lines changed

11 files changed

+104
-30
lines changed

src/codegen/extensions/langchain/agent.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
from .tools import (
1414
CommitTool,
1515
CreateFileTool,
16-
CreatePRCommentTool,
17-
CreatePRReviewCommentTool,
18-
CreatePRTool,
1916
DeleteFileTool,
2017
EditFileTool,
21-
GetPRcontentsTool,
18+
GithubCreatePRCommentTool,
19+
GithubCreatePRReviewCommentTool,
20+
GithubCreatePRTool,
21+
GithubViewPRTool,
2222
ListDirectoryTool,
2323
MoveSymbolTool,
2424
RenameFileTool,
@@ -68,10 +68,10 @@ def create_codebase_agent(
6868
SemanticEditTool(codebase),
6969
SemanticSearchTool(codebase),
7070
CommitTool(codebase),
71-
CreatePRTool(codebase),
72-
GetPRcontentsTool(codebase),
73-
CreatePRCommentTool(codebase),
74-
CreatePRReviewCommentTool(codebase),
71+
GithubCreatePRTool(codebase),
72+
GithubViewPRTool(codebase),
73+
GithubCreatePRCommentTool(codebase),
74+
GithubCreatePRReviewCommentTool(codebase),
7575
]
7676

7777
# Get the prompt to use

src/codegen/extensions/langchain/tools.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from codegen import Codebase
1010
from codegen.extensions.linear.linear_client import LinearClient
11-
from codegen.extensions.tools.linear_tools import (
11+
from codegen.extensions.tools.linear.linear import (
1212
linear_comment_on_issue_tool,
1313
linear_create_issue_tool,
1414
linear_get_issue_comments_tool,
@@ -354,19 +354,24 @@ def _run(self, query: str, k: int = 5, preview_length: int = 200) -> str:
354354
return json.dumps(result, indent=2)
355355

356356

357-
class CreatePRInput(BaseModel):
357+
########################################################################################################################
358+
# GITHUB
359+
########################################################################################################################
360+
361+
362+
class GithubCreatePRInput(BaseModel):
358363
"""Input for creating a PR"""
359364

360365
title: str = Field(..., description="The title of the PR")
361366
body: str = Field(..., description="The body of the PR")
362367

363368

364-
class CreatePRTool(BaseTool):
369+
class GithubCreatePRTool(BaseTool):
365370
"""Tool for creating a PR."""
366371

367372
name: ClassVar[str] = "create_pr"
368373
description: ClassVar[str] = "Create a PR for the current branch"
369-
args_schema: ClassVar[type[BaseModel]] = CreatePRInput
374+
args_schema: ClassVar[type[BaseModel]] = GithubCreatePRInput
370375
codebase: Codebase = Field(exclude=True)
371376

372377
def __init__(self, codebase: Codebase) -> None:
@@ -377,18 +382,18 @@ def _run(self, title: str, body: str) -> str:
377382
return json.dumps(result, indent=2)
378383

379384

380-
class GetPRContentsInput(BaseModel):
385+
class GithubViewPRInput(BaseModel):
381386
"""Input for getting PR contents."""
382387

383388
pr_id: int = Field(..., description="Number of the PR to get the contents for")
384389

385390

386-
class GetPRcontentsTool(BaseTool):
391+
class GithubViewPRTool(BaseTool):
387392
"""Tool for getting PR data."""
388393

389-
name: ClassVar[str] = "get_pr_contents"
390-
description: ClassVar[str] = "Get the diff and modified symbols of a PR along with the dependencies of the modified symbols"
391-
args_schema: ClassVar[type[BaseModel]] = GetPRContentsInput
394+
name: ClassVar[str] = "view_pr"
395+
description: ClassVar[str] = "View the diff and associated context for a pull request"
396+
args_schema: ClassVar[type[BaseModel]] = GithubViewPRInput
392397
codebase: Codebase = Field(exclude=True)
393398

394399
def __init__(self, codebase: Codebase) -> None:
@@ -399,19 +404,19 @@ def _run(self, pr_id: int) -> str:
399404
return json.dumps(result, indent=2)
400405

401406

402-
class CreatePRCommentInput(BaseModel):
407+
class GithubCreatePRCommentInput(BaseModel):
403408
"""Input for creating a PR comment"""
404409

405410
pr_number: int = Field(..., description="The PR number to comment on")
406411
body: str = Field(..., description="The comment text")
407412

408413

409-
class CreatePRCommentTool(BaseTool):
414+
class GithubCreatePRCommentTool(BaseTool):
410415
"""Tool for creating a general PR comment."""
411416

412417
name: ClassVar[str] = "create_pr_comment"
413418
description: ClassVar[str] = "Create a general comment on a pull request"
414-
args_schema: ClassVar[type[BaseModel]] = CreatePRCommentInput
419+
args_schema: ClassVar[type[BaseModel]] = GithubCreatePRCommentInput
415420
codebase: Codebase = Field(exclude=True)
416421

417422
def __init__(self, codebase: Codebase) -> None:
@@ -422,7 +427,7 @@ def _run(self, pr_number: int, body: str) -> str:
422427
return json.dumps(result, indent=2)
423428

424429

425-
class CreatePRReviewCommentInput(BaseModel):
430+
class GithubCreatePRReviewCommentInput(BaseModel):
426431
"""Input for creating an inline PR review comment"""
427432

428433
pr_number: int = Field(..., description="The PR number to comment on")
@@ -434,12 +439,12 @@ class CreatePRReviewCommentInput(BaseModel):
434439
start_line: int | None = Field(None, description="For multi-line comments, the starting line")
435440

436441

437-
class CreatePRReviewCommentTool(BaseTool):
442+
class GithubCreatePRReviewCommentTool(BaseTool):
438443
"""Tool for creating inline PR review comments."""
439444

440445
name: ClassVar[str] = "create_pr_review_comment"
441446
description: ClassVar[str] = "Create an inline review comment on a specific line in a pull request"
442-
args_schema: ClassVar[type[BaseModel]] = CreatePRReviewCommentInput
447+
args_schema: ClassVar[type[BaseModel]] = GithubCreatePRReviewCommentInput
443448
codebase: Codebase = Field(exclude=True)
444449

445450
def __init__(self, codebase: Codebase) -> None:
@@ -468,6 +473,11 @@ def _run(
468473
return json.dumps(result, indent=2)
469474

470475

476+
########################################################################################################################
477+
# LINEAR
478+
########################################################################################################################
479+
480+
471481
class LinearGetIssueInput(BaseModel):
472482
"""Input for getting a Linear issue."""
473483

@@ -597,6 +607,11 @@ def _run(self) -> str:
597607
return json.dumps(result, indent=2)
598608

599609

610+
########################################################################################################################
611+
# EXPORT
612+
########################################################################################################################
613+
614+
600615
def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
601616
"""Get all workspace tools initialized with a codebase.
602617
@@ -609,12 +624,9 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
609624
return [
610625
CommitTool(codebase),
611626
CreateFileTool(codebase),
612-
CreatePRTool(codebase),
613-
CreatePRCommentTool(codebase),
614-
CreatePRReviewCommentTool(codebase),
615627
DeleteFileTool(codebase),
616628
EditFileTool(codebase),
617-
GetPRcontentsTool(codebase),
629+
GithubViewPRTool(codebase),
618630
ListDirectoryTool(codebase),
619631
MoveSymbolTool(codebase),
620632
RenameFileTool(codebase),
@@ -623,6 +635,12 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
623635
SemanticEditTool(codebase),
624636
SemanticSearchTool(codebase),
625637
ViewFileTool(codebase),
638+
# Github
639+
GithubCreatePRTool(codebase),
640+
GithubCreatePRCommentTool(codebase),
641+
GithubCreatePRReviewCommentTool(codebase),
642+
GithubViewPRTool(codebase),
643+
# Linear
626644
LinearGetIssueTool(codebase),
627645
LinearGetIssueCommentsTool(codebase),
628646
LinearCommentOnIssueTool(codebase),

src/codegen/extensions/tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .github.create_pr_comment import create_pr_comment
99
from .github.create_pr_review_comment import create_pr_review_comment
1010
from .github.view_pr import view_pr
11-
from .linear_tools import (
11+
from .linear import (
1212
linear_comment_on_issue_tool,
1313
linear_get_issue_comments_tool,
1414
linear_get_issue_tool,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .create_pr import create_pr
2+
from .create_pr_comment import create_pr_comment
3+
from .create_pr_review_comment import create_pr_review_comment
4+
from .view_pr import view_pr
5+
6+
__all__ = [
7+
"create_pr",
8+
"create_pr_comment",
9+
"create_pr_review_comment",
10+
"view_pr",
11+
]

src/codegen/extensions/tools/github/view_pr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ def view_pr(codebase: Codebase, pr_id: int) -> dict[str, Any]:
1818
modified_symbols, patch = codebase.get_modified_symbols_in_pr(pr_id)
1919

2020
# Convert modified_symbols set to list for JSON serialization
21-
return {"status": "success", "modified_symbols": list(modified_symbols), "patch": patch}
21+
return {"status": "success", "patch": patch}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from .linear import (
2+
linear_comment_on_issue_tool,
3+
linear_create_issue_tool,
4+
linear_get_issue_comments_tool,
5+
linear_get_issue_tool,
6+
linear_get_teams_tool,
7+
linear_register_webhook_tool,
8+
linear_search_issues_tool,
9+
)
10+
11+
__all__ = [
12+
"linear_comment_on_issue_tool",
13+
"linear_create_issue_tool",
14+
"linear_get_issue_comments_tool",
15+
"linear_get_issue_tool",
16+
"linear_get_teams_tool",
17+
"linear_register_webhook_tool",
18+
"linear_search_issues_tool",
19+
]

tests/integration/__init__.py

Whitespace-only changes.

tests/integration/extension/__init__.py

Whitespace-only changes.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Tests for Linear tools."""
2+
3+
import os
4+
5+
import pytest
6+
7+
from codegen import Codebase
8+
from codegen.extensions.linear.linear_client import LinearClient
9+
from codegen.extensions.tools.github import view_pr
10+
11+
12+
@pytest.fixture
13+
def client() -> LinearClient:
14+
"""Create a Linear client for testing."""
15+
token = os.getenv("CODEGEN_SECRETS__GITHUB_TOKEN")
16+
if not token:
17+
pytest.skip("CODEGEN_SECRETS__GITHUB_TOKEN environment variable not set")
18+
codebase = Codebase.from_repo("codegen-sh/Kevin-s-Adventure-Game")
19+
return codebase
20+
21+
22+
def test_github_view_pr(client: LinearClient) -> None:
23+
"""Test getting an issue from Linear."""
24+
# Link to PR: https://github.com/codegen-sh/Kevin-s-Adventure-Game/pull/419
25+
pr = view_pr(client, 419)
26+
print(pr)

tests/integration/extension/test_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from codegen.extensions.linear.linear_client import LinearClient
8-
from codegen.extensions.tools.linear_tools import (
8+
from codegen.extensions.tools.linear.linear import (
99
linear_comment_on_issue_tool,
1010
linear_create_issue_tool,
1111
linear_get_issue_comments_tool,

0 commit comments

Comments
 (0)