Skip to content

feat: final set of upgrades for tools #604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/codegen/extensions/tools/github/view_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class ViewPRObservation(Observation):
file_commit_sha: dict[str, str] = Field(
description="Commit SHAs for each file in the PR",
)
modified_symbols: list[str] = Field(
description="Names of modified symbols in the PR",
)

str_template: ClassVar[str] = "PR #{pr_id}"

Expand All @@ -33,13 +36,14 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation:
pr_id: Number of the PR to get the contents for
"""
try:
patch, file_commit_sha = codebase.get_modified_symbols_in_pr(pr_id)
patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id)

return ViewPRObservation(
status="success",
pr_id=pr_id,
patch=patch,
file_commit_sha=file_commit_sha,
modified_symbols=moddified_symbols,
)

except Exception as e:
Expand All @@ -49,4 +53,5 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation:
pr_id=pr_id,
patch="",
file_commit_sha={},
modified_symbols=[],
)
7 changes: 4 additions & 3 deletions src/codegen/git/utils/pr_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from codegen.git.repo_operator.repo_operator import RepoOperator

if TYPE_CHECKING:
from codegen.sdk.core.codebase import Codebase, Editable, File, Symbol
from codegen.sdk.core.codebase import Codebase, Editable, File

Check failure on line 12 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: Module "codegen.sdk.core.codebase" has no attribute "Codebase" [attr-defined]


def get_merge_base(git_repo_client: Repository, pull: PullRequest | PullRequestContext) -> str:

Check failure on line 15 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: Module "github.Repository" is not valid as a type [valid-type]
"""Gets the merge base of a pull request using a remote GitHub API client.

Args:
Expand All @@ -22,7 +22,7 @@
Returns:
str: The SHA of the merge base commit.
"""
comparison = git_repo_client.compare(pull.base.sha, pull.head.sha)

Check failure on line 25 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: Repository? has no attribute "compare" [attr-defined]
return comparison.merge_base_commit.sha


Expand All @@ -46,7 +46,7 @@
raise ValueError(msg)

# Get the diff directly from the PR
diff_url = pull.raw_data.get("diff_url")

Check failure on line 49 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: "PullRequestContext" has no attribute "raw_data" [attr-defined]
if diff_url:
# Fetch the diff content from the URL
response = requests.get(diff_url)
Expand All @@ -54,7 +54,7 @@
diff = response.text
else:
# If diff_url not available, get the patch directly
diff = pull.get_patch()

Check failure on line 57 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: "PullRequestContext" has no attribute "get_patch" [attr-defined]

# Parse the diff into a PatchSet
pull_patch_set = PatchSet(diff)
Expand Down Expand Up @@ -120,7 +120,7 @@
_op: RepoOperator

# =====[ Computed ]=====
_modified_file_ranges: dict[str, list[tuple[int, int]]] = None

Check failure on line 123 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible types in assignment (expression has type "None", variable has type "dict[str, list[tuple[int, int]]]") [assignment]

def __init__(self, op: RepoOperator, codebase: "Codebase", pr: PullRequest):
self._op = op
Expand All @@ -131,7 +131,7 @@
def modified_file_ranges(self) -> dict[str, list[tuple[int, int]]]:
"""Files and the ranges within that are modified"""
if not self._modified_file_ranges:
pull_patch_set = get_pull_patch_set(op=self._op, pull=self._gh_pr)

Check failure on line 134 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument "pull" to "get_pull_patch_set" has incompatible type "PullRequest"; expected "PullRequestContext" [arg-type]
self._modified_file_ranges = get_file_to_changed_ranges(pull_patch_set)
return self._modified_file_ranges

Expand All @@ -145,12 +145,12 @@
filepath = editable.filepath
changed_ranges = self._modified_file_ranges.get(filepath, [])
symbol_range = to_1_indexed(editable.line_range)
if any(overlaps(symbol_range, changed_range) for changed_range in changed_ranges):

Check failure on line 148 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 2 to "overlaps" has incompatible type "tuple[int, int]"; expected "range" [arg-type]
return True
return False

@property
def modified_symbols(self) -> list["Symbol"]:
def modified_symbols(self) -> list[str]:
# Import SourceFile locally to avoid circular dependencies
from codegen.sdk.core.file import SourceFile

Expand All @@ -163,7 +163,8 @@
continue
for symbol in file.symbols:
if self.is_modified(symbol):
all_modified.append(symbol)
all_modified.append(symbol.name)

return all_modified

def get_pr_diff(self) -> str:
Expand All @@ -181,7 +182,7 @@
return response.text
else:
# If diff_url not available, get the patch directly
return self._gh_pr.get_patch()

Check failure on line 185 in src/codegen/git/utils/pr_review.py

View workflow job for this annotation

GitHub Actions / mypy

error: "PullRequest" has no attribute "get_patch" [attr-defined]

def get_commit_sha(self) -> str:
"""Get the commit SHA of the PR"""
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
TExport = TypeVar("TExport", bound="Export", default=Export)
TSGlobalVar = TypeVar("TSGlobalVar", bound="Assignment", default=Assignment)
PyGlobalVar = TypeVar("PyGlobalVar", bound="Assignment", default=Assignment)
TSDirectory = Directory[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]

Check failure on line 107 in src/codegen/sdk/core/codebase.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot resolve name "TSDirectory" (possible cyclic definition) [misc]
PyDirectory = Directory[PyFile, PySymbol, PyImportStatement, PyGlobalVar, PyClass, PyFunction, PyImport]


Expand Down Expand Up @@ -1311,13 +1311,13 @@
logger.exception(f"Failed to initialize codebase: {e}")
raise

def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str]]:
def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str]]:
"""Get all modified symbols in a pull request"""
pr = self._op.get_pull_request(pr_id)
cg_pr = CodegenPR(self._op, self, pr)
patch = cg_pr.get_pr_diff()
commit_sha = cg_pr.get_file_commit_shas()
return patch, commit_sha
return patch, commit_sha, cg_pr.modified_symbols

def create_pr_comment(self, pr_number: int, body: str) -> None:
"""Create a comment on a pull request"""
Expand Down