Skip to content

Commit 572d9ab

Browse files
authored
chore: pr-review-tooling (#434)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --------- Co-authored-by: kopekC <[email protected]>
1 parent e451605 commit 572d9ab

File tree

5 files changed

+32
-10
lines changed

5 files changed

+32
-10
lines changed

codegen-examples/examples/pr_review_bot/run.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import codegen
22
from codegen import Codebase
3-
from codegen.sdk.enums import ProgrammingLanguage
4-
from codegen.sdk.codebase.config import CodebaseConfig
3+
from codegen.shared.enums.programming_language import ProgrammingLanguage
4+
from codegen.sdk.codebase.config import CodebaseConfig, Secrets
55
import json
66

77
from codegen.sdk.secrets import Secrets
@@ -20,13 +20,9 @@ def run(codebase: Codebase):
2020
modified_symbols = codebase.get_modified_symbols_in_pr(pr_number)
2121
for symbol in modified_symbols:
2222
# Get direct dependencies
23-
deps = codebase.get_symbol_dependencies(symbol, max_depth=2)
23+
deps = symbol.dependencies(max_depth=2)
2424
context_symbols.update(deps)
2525

26-
# Get reverse dependencies (symbols that depend on this one)
27-
rev_deps = codebase.get_symbol_dependents(symbol, max_depth=2)
28-
context_symbols.update(rev_deps)
29-
3026
# Prepare context for LLM
3127
context = {
3228
"modified_symbols": [

src/codegen/git/repo_operator/local_repo_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def get_pull_request(self, pr_number: int) -> PullRequest | None:
183183
if repo is None:
184184
logger.warning("GitHub API key is required to fetch pull requests")
185185
return None
186-
return repo.get_pull(pr_number)
186+
return repo.get_pull_safe(pr_number)
187187
except Exception as e:
188188
logger.warning(f"Failed to get PR {pr_number}: {e!s}")
189189
return None

src/codegen/git/repo_operator/repo_operator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,3 +600,11 @@ def stash_push(self) -> None:
600600

601601
def stash_pop(self) -> None:
602602
self.git_cli.git.stash("pop")
603+
604+
####################################################################################################################
605+
# PR UTILITIES
606+
####################################################################################################################
607+
608+
def get_pr_data(self, pr_number: int) -> dict:
609+
"""Returns the data associated with a PR"""
610+
return self.remote_git_repo.get_pr_data(pr_number)

src/codegen/git/utils/pr_review.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,20 @@ def modified_symbols(self) -> list["Symbol"]:
127127
if self.is_modified(symbol):
128128
all_modified.append(symbol)
129129
return all_modified
130+
131+
def get_pr_diff(self) -> str:
132+
"""Get the full diff of the PR"""
133+
if not self._op.remote_git_repo:
134+
msg = "GitHub API client is required to get PR diffs"
135+
raise ValueError(msg)
136+
137+
# Get the diff directly from the PR
138+
diff_url = self._gh_pr.raw_data.get("diff_url")
139+
if diff_url:
140+
# Fetch the diff content from the URL
141+
response = requests.get(diff_url)
142+
response.raise_for_status()
143+
return response.text
144+
else:
145+
# If diff_url not available, get the patch directly
146+
return self._gh_pr.get_patch()

src/codegen/sdk/core/codebase.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,11 +1259,12 @@ def from_repo(
12591259
logger.exception(f"Failed to initialize codebase: {e}")
12601260
raise
12611261

1262-
def get_modified_symbols_in_pr(self, pr_id: int) -> list[Symbol]:
1262+
def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[list[Symbol], str]:
12631263
"""Get all modified symbols in a pull request"""
12641264
pr = self._op.get_pull_request(pr_id)
12651265
cg_pr = CodegenPR(self._op, self, pr)
1266-
return cg_pr.modified_symbols
1266+
patch = cg_pr.get_pr_diff()
1267+
return cg_pr.modified_symbols, patch
12671268

12681269

12691270
# The last 2 lines of code are added to the runner. See codegen-backend/cli/generate/utils.py

0 commit comments

Comments
 (0)