Skip to content

Commit 401204a

Browse files
authored
Foundations for PR BOT static analysis (#343)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [x] I have added tests for my changes - [x] I have updated the documentation or added new documentation as needed --------- Co-authored-by: kopekC <[email protected]>
1 parent e190971 commit 401204a

File tree

5 files changed

+232
-17
lines changed

5 files changed

+232
-17
lines changed

docs/mint.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"og:locale": "en_US",
1717
"og:logo": "https://i.imgur.com/f4OVOqI.png",
1818
"article:publisher": "Codegen, Inc.",
19-
"twitter:site": "@codegen",
19+
"twitter:site": "@codegen"
2020
},
2121
"favicon": "/favicon.svg",
2222
"colors": {

src/codegen/git/repo_operator/local_repo_operator.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23
from functools import cached_property
34
from typing import Self, override
@@ -6,13 +7,19 @@
67
from git import Remote
78
from git import Repo as GitCLI
89
from git.remote import PushInfoList
10+
from github import Github
11+
from github.PullRequest import PullRequest
912

13+
from codegen.git.clients.git_repo_client import GitRepoClient
1014
from codegen.git.repo_operator.repo_operator import RepoOperator
1115
from codegen.git.schemas.enums import FetchResult
16+
from codegen.git.schemas.github import GithubType
1217
from codegen.git.schemas.repo_config import BaseRepoConfig
1318
from codegen.git.utils.clone_url import url_to_github
1419
from codegen.git.utils.file_utils import create_files
1520

21+
logger = logging.getLogger(__name__)
22+
1623

1724
class OperatorIsLocal(Exception):
1825
"""Error raised while trying to do a remote operation on a local operator"""
@@ -29,20 +36,54 @@ class LocalRepoOperator(RepoOperator):
2936
_repo_name: str
3037
_git_cli: GitCLI
3138
repo_config: BaseRepoConfig
39+
_github_api_key: str | None
40+
_remote_git_repo: GitRepoClient | None = None
3241

3342
def __init__(
    self,
    repo_path: str,  # full path to the repo
    github_api_key: str | None = None,
    repo_config: BaseRepoConfig | None = None,
    bot_commit: bool = False,
) -> None:
    """Set up an operator rooted at ``repo_path``.

    Ensures the directory exists, runs ``GitCLI.init`` on it and then hands
    over to ``RepoOperator.__init__``.

    Args:
        repo_path: Full path to the repository on disk.
        github_api_key: Optional GitHub API key, kept for the lazily built
            ``remote_git_repo`` client.
        repo_config: Repository configuration; a fresh ``BaseRepoConfig()``
            is used when none is given.
        bot_commit: Passed through unchanged to ``RepoOperator.__init__``.
    """
    # Identity and credentials are recorded before anything touches the disk,
    # because ``self.repo_path`` is read below.
    # NOTE(review): presumably ``repo_path`` is backed by ``_repo_path`` - confirm.
    self._repo_path, self._repo_name = repo_path, os.path.basename(repo_path)
    self._github_api_key = github_api_key
    self.github_type = GithubType.Github
    self._remote_git_repo = None  # built on first access to ``remote_git_repo``

    os.makedirs(self.repo_path, exist_ok=True)
    GitCLI.init(self.repo_path)

    effective_config = repo_config or BaseRepoConfig()
    super().__init__(effective_config, self.repo_path, bot_commit)
4558

59+
####################################################################################################################
60+
# PROPERTIES
61+
####################################################################################################################
62+
63+
@property
def remote_git_repo(self) -> GitRepoClient | None:
    """Lazily build and cache the GitHub API handle for this repository.

    Returns:
        The object returned by ``Github.get_repo`` (cached after the first
        successful call), or ``None`` when no GitHub API key was configured.

    Raises:
        ValueError: If ``self.base_url`` is empty, or if it has too few
            ``/``-separated segments to read an owner and a repo name from.
    """
    if self._remote_git_repo is not None:
        return self._remote_git_repo

    if not self._github_api_key:
        return None

    if not (base_url := self.base_url):
        msg = "Could not determine GitHub URL from remotes"
        raise ValueError(msg)

    # Owner and repo are read as the 4th- and 3rd-from-last segments.
    # NOTE(review): that matches a URL shaped like
    # https://github.com/owner/repo/<x>/<y>, not a bare .../owner/repo -
    # confirm against what ``base_url`` actually returns.
    parts = base_url.split("/")
    if len(parts) < 4:
        # The guard must cover the deepest index used below (-4); otherwise a
        # short URL surfaces as an IndexError instead of this ValueError.
        msg = f"Invalid GitHub URL format: {base_url}"
        raise ValueError(msg)

    owner, repo = parts[-4], parts[-3]
    self._remote_git_repo = Github(self._github_api_key).get_repo(f"{owner}/{repo}")
    return self._remote_git_repo
86+
4687
####################################################################################################################
4788
# CLASS METHODS
4889
####################################################################################################################
@@ -70,9 +111,16 @@ def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bo
70111
return op
71112

72113
@classmethod
73-
def create_from_commit(cls, repo_path: str, commit: str, url: str) -> Self:
74-
"""Do a shallow checkout of a particular commit to get a repository from a given remote URL."""
75-
op = cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False)
114+
def create_from_commit(cls, repo_path: str, commit: str, url: str, github_api_key: str | None = None) -> Self:
115+
"""Do a shallow checkout of a particular commit to get a repository from a given remote URL.
116+
117+
Args:
118+
repo_path (str): Path where the repo should be cloned
119+
commit (str): The commit hash to checkout
120+
url (str): Git URL of the repository
121+
github_api_key (str | None): Optional GitHub API key for operations that need GitHub access
122+
"""
123+
op = cls(repo_path=repo_path, bot_commit=False, github_api_key=github_api_key)
76124
op.discard_changes()
77125
if op.get_active_branch_or_commit() != commit:
78126
op.create_remote("origin", url)
@@ -81,12 +129,13 @@ def create_from_commit(cls, repo_path: str, commit: str, url: str) -> Self:
81129
return op
82130

83131
@classmethod
84-
def create_from_repo(cls, repo_path: str, url: str) -> Self:
132+
def create_from_repo(cls, repo_path: str, url: str, github_api_key: str | None = None) -> Self:
85133
"""Create a fresh clone of a repository or use existing one if up to date.
86134
87135
Args:
88136
repo_path (str): Path where the repo should be cloned
89137
url (str): Git URL of the repository
138+
github_api_key (str | None): Optional GitHub API key for operations that need GitHub access
90139
"""
91140
# Check if repo already exists
92141
if os.path.exists(repo_path):
@@ -102,7 +151,7 @@ def create_from_repo(cls, repo_path: str, url: str) -> Self:
102151
remote_head = git_cli.remotes.origin.refs[git_cli.active_branch.name].commit
103152
# If up to date, use existing repo
104153
if local_head.hexsha == remote_head.hexsha:
105-
return cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False)
154+
return cls(repo_path=repo_path, bot_commit=False, github_api_key=github_api_key)
106155
except Exception:
107156
# If any git operations fail, fallback to fresh clone
108157
pass
@@ -113,13 +162,13 @@ def create_from_repo(cls, repo_path: str, url: str) -> Self:
113162

114163
shutil.rmtree(repo_path)
115164

116-
# Do a fresh clone with depth=1 to get latest commit
165+
# Clone the repository
117166
GitCLI.clone_from(url=url, to_path=repo_path, depth=1)
118167

119168
# Initialize with the cloned repo
120169
git_cli = GitCLI(repo_path)
121170

122-
return cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False)
171+
return cls(repo_path=repo_path, bot_commit=False, github_api_key=github_api_key)
123172

124173
####################################################################################################################
125174
# PROPERTIES
@@ -153,3 +202,26 @@ def pull_repo(self) -> None:
153202

154203
def fetch_remote(self, remote_name: str = "origin", refspec: str | None = None, force: bool = True) -> FetchResult:
    """Reject the call: this operator does not fetch from remotes.

    Raises:
        OperatorIsLocal: Unconditionally; none of the arguments are used.
    """
    raise OperatorIsLocal
205+
206+
def get_pull_request(self, pr_number: int) -> PullRequest | None:
    """Look up a pull request by number through the GitHub API client.

    This is a best-effort lookup: any failure is logged as a warning and
    reported to the caller as ``None`` rather than raised.

    Args:
        pr_number (int): The PR number to fetch

    Returns:
        PullRequest | None: Whatever ``remote_git_repo.get_pull`` returns, or
        ``None`` when no client is available or the lookup raised.

    Note:
        This requires a GitHub API key to be set when creating the LocalRepoOperator
    """
    try:
        # Building the client can itself raise, so it stays inside the try.
        client = self.remote_git_repo
        if client is not None:
            return client.get_pull(pr_number)
        logger.warning("GitHub API key is required to fetch pull requests")
    except Exception as err:
        logger.warning(f"Failed to get PR {pr_number}: {err!s}")
    return None

src/codegen/git/utils/pr_review.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from typing import TYPE_CHECKING
2+
3+
import requests
4+
from github import Repository
5+
from github.PullRequest import PullRequest
6+
from unidiff import PatchSet
7+
8+
from codegen.git.models.pull_request_context import PullRequestContext
9+
from codegen.git.repo_operator.local_repo_operator import LocalRepoOperator
10+
from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator
11+
12+
if TYPE_CHECKING:
13+
from codegen.sdk.core.codebase import Codebase, Editable, File, Symbol
14+
15+
16+
def get_merge_base(git_repo_client: Repository, pull: PullRequest | PullRequestContext) -> str:
    """Return the merge-base commit SHA of a pull request.

    Asks ``git_repo_client.compare`` to compare the pull request's base SHA
    with its head SHA and reads the merge base off the comparison result.

    Args:
        git_repo_client: Repository client exposing ``compare(base, head)``.
        pull: Pull request exposing ``base.sha`` and ``head.sha``.

    Returns:
        str: The SHA of the merge base commit.
    """
    merge_base_commit = git_repo_client.compare(pull.base.sha, pull.head.sha).merge_base_commit
    return merge_base_commit.sha
28+
29+
30+
def get_file_to_changed_ranges(pull_patch_set: PatchSet) -> dict[str, list]:
    """Map each patched file's path to the target-side line ranges of its hunks.

    Files flagged ``is_removed_file`` are left out. Every hunk contributes
    ``range(target_start, target_start + target_length)``.

    Args:
        pull_patch_set: Iterable of patched files, each iterable over its hunks.

    Returns:
        dict[str, list]: ``{path: [range, ...]}`` with one range per hunk.
    """
    return {
        patched_file.path: [range(hunk.target_start, hunk.target_start + hunk.target_length) for hunk in patched_file]
        for patched_file in pull_patch_set
        if not patched_file.is_removed_file
    }
41+
42+
43+
def get_pull_patch_set(op: LocalRepoOperator | RemoteRepoOperator, pull: PullRequest | PullRequestContext, timeout: float = 30.0) -> PatchSet:
    """Fetch a pull request's diff and parse it into a ``PatchSet``.

    Args:
        op: Repo operator. Only ``op.remote_git_repo`` is consulted, as a
            precondition; the diff itself is not fetched through it.
        pull: Pull request whose ``raw_data`` may carry a ``diff_url``.
        timeout: Seconds to wait for the diff download before giving up.

    Returns:
        PatchSet: The parsed diff.

    Raises:
        ValueError: If ``op.remote_git_repo`` is falsy.
        requests.RequestException: If the download fails, times out or comes
            back with an error status.
    """
    if not op.remote_git_repo:
        msg = "GitHub API client is required to get PR diffs"
        raise ValueError(msg)

    diff_url = pull.raw_data.get("diff_url")
    if diff_url:
        # Without a timeout ``requests`` waits forever on a stalled connection.
        # NOTE(review): the request carries no credentials - presumably this
        # only works for publicly readable diffs; confirm for private repos.
        response = requests.get(diff_url, timeout=timeout)
        response.raise_for_status()
        diff = response.text
    else:
        # No diff_url available, so ask the pull object for the patch directly.
        # NOTE(review): confirm that the pull object actually provides get_patch().
        diff = pull.get_patch()

    return PatchSet(diff)
63+
64+
65+
def to_1_indexed(zero_indexed_range: range) -> range:
    """Shift a range up by one position.

    Primarily used to convert 0-indexed ranges to 1-indexed ones, but any
    n-indexed range becomes (n+1)-indexed.

    Args:
        zero_indexed_range: The range to shift.

    Returns:
        range: A range with ``start`` and ``stop`` each increased by one and
        the original ``step`` kept, so it holds exactly the shifted members.
    """
    return range(
        zero_indexed_range.start + 1,
        zero_indexed_range.stop + 1,
        # Carry the step over: dropping it would turn a stepped range into a
        # dense one and add members the input never had.
        zero_indexed_range.step,
    )
70+
71+
72+
def overlaps(range1: range, range2: range) -> bool:
    """Tell whether two ranges overlap, judged by their start/stop bounds.

    The ranges overlap when the later of the two starts lies strictly before
    the earlier of the two stops. ``step`` is not taken into account.
    """
    latest_start = max(range1.start, range2.start)
    earliest_stop = min(range1.stop, range2.stop)
    return latest_start < earliest_stop
75+
76+
77+
class CodegenPR:
    """Wrapper around PRs - enables codemods to interact with them"""

    _gh_pr: PullRequest
    _codebase: "Codebase"
    _op: LocalRepoOperator | RemoteRepoOperator

    # =====[ Computed ]=====
    # Cache behind ``modified_file_ranges``; ``None`` means "not computed yet".
    _modified_file_ranges: dict[str, list[range]] | None = None

    def __init__(self, op: LocalRepoOperator | RemoteRepoOperator, codebase: "Codebase", pr: PullRequest):
        """Bind the wrapper to a repo operator, a codebase and a pull request.

        Args:
            op: Repo operator handed to ``get_pull_patch_set``.
            codebase: Codebase used to resolve file paths into file objects.
            pr: The pull request being wrapped.
        """
        self._op = op
        self._gh_pr = pr
        self._codebase = codebase

    @property
    def modified_file_ranges(self) -> dict[str, list[range]]:
        """Files and the ranges within that are modified.

        Computed on first access and cached afterwards.
        """
        # Compare against None rather than truthiness: an empty mapping is a
        # valid result and must not trigger another diff download.
        if self._modified_file_ranges is None:
            pull_patch_set = get_pull_patch_set(op=self._op, pull=self._gh_pr)
            self._modified_file_ranges = get_file_to_changed_ranges(pull_patch_set)
        return self._modified_file_ranges

    @property
    def modified_files(self) -> list["File | None"]:
        """Codebase files for every modified path.

        Lookups use ``optional=True``, so an entry may be ``None``.
        """
        return [self._codebase.get_file(path, optional=True) for path in self.modified_file_ranges]

    def is_modified(self, editable: "Editable") -> bool:
        """Returns True if the Editable's range contains any modified lines"""
        # Go through the property, not the raw cache attribute, so this also
        # works when it is the first thing called on the instance.
        changed_ranges = self.modified_file_ranges.get(editable.filepath, [])
        # ``line_range`` is shifted by one before comparing with the hunk ranges.
        # NOTE(review): assumes line_range is 0-indexed and hunk ranges 1-indexed - confirm.
        symbol_range = to_1_indexed(editable.line_range)
        return any(overlaps(symbol_range, changed_range) for changed_range in changed_ranges)

    @property
    def modified_symbols(self) -> list["Symbol"]:
        """Symbols of modified source files whose line range overlaps a changed range."""
        # Import SourceFile locally to avoid circular dependencies
        from codegen.sdk.core.file import SourceFile

        all_modified = []
        for file in self.modified_files:
            if file is None:
                print("Warning: File is None")
                continue
            if not isinstance(file, SourceFile):
                continue
            all_modified.extend(symbol for symbol in file.symbols if self.is_modified(symbol))
        return all_modified

src/codegen/sdk/core/codebase.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator
2424
from codegen.git.repo_operator.repo_operator import RepoOperator
2525
from codegen.git.schemas.enums import CheckoutResult
26+
from codegen.git.utils.pr_review import CodegenPR
2627
from codegen.sdk._proxy import proxy_property
2728
from codegen.sdk.ai.helpers import AbstractAIHelper, MultiProviderAIHelper
2829
from codegen.sdk.codebase.codebase_ai import generate_system_prompt, generate_tools
@@ -112,7 +113,7 @@ class Codebase(Generic[TSourceFile, TDirectory, TSymbol, TClass, TFunction, TImp
112113
console: Manages console output for the codebase.
113114
"""
114115

115-
_op: RepoOperator | RemoteRepoOperator
116+
_op: RepoOperator | RemoteRepoOperator | LocalRepoOperator
116117
viz: VisualizationManager
117118
repo_path: Path
118119
console: Console
@@ -1162,7 +1163,16 @@ def set_session_options(self, **kwargs: Unpack[SessionOptions]) -> None:
11621163
self.G.transaction_manager.reset_stopwatch(self.G.session_options.max_seconds)
11631164

11641165
@classmethod
1165-
def from_repo(cls, repo_name: str, *, tmp_dir: str | None = None, commit: str | None = None, shallow: bool = True, programming_language: ProgrammingLanguage | None = None) -> "Codebase":
1166+
def from_repo(
1167+
cls,
1168+
repo_name: str,
1169+
*,
1170+
tmp_dir: str | None = None,
1171+
commit: str | None = None,
1172+
shallow: bool = True,
1173+
programming_language: ProgrammingLanguage | None = None,
1174+
config: CodebaseConfig = DefaultConfig,
1175+
) -> "Codebase":
11661176
"""Fetches a codebase from GitHub and returns a Codebase instance.
11671177
11681178
Args:
@@ -1171,6 +1181,7 @@ def from_repo(cls, repo_name: str, *, tmp_dir: str | None = None, commit: str |
11711181
commit (Optional[str]): The specific commit hash to clone. Defaults to HEAD
11721182
shallow (bool): Whether to do a shallow clone. Defaults to True
11731183
programming_language (ProgrammingLanguage | None): The programming language of the repo. Defaults to None.
1184+
config (CodebaseConfig): Configuration for the codebase. Defaults to DefaultConfig.
11741185
11751186
Returns:
11761187
Codebase: A Codebase instance initialized with the cloned repository
@@ -1198,26 +1209,28 @@ def from_repo(cls, repo_name: str, *, tmp_dir: str | None = None, commit: str |
11981209
# Use LocalRepoOperator to fetch the repository
11991210
logger.info("Cloning repository...")
12001211
if commit is None:
1201-
repo_operator = LocalRepoOperator.create_from_repo(repo_path=repo_path, url=repo_url)
1212+
repo_operator = LocalRepoOperator.create_from_repo(repo_path=repo_path, url=repo_url, github_api_key=config.secrets.github_api_key if config.secrets else None)
12021213
else:
12031214
# Ensure the operator can handle remote operations
1204-
repo_operator = LocalRepoOperator.create_from_commit(
1205-
repo_path=repo_path,
1206-
commit=commit,
1207-
url=repo_url,
1208-
)
1215+
repo_operator = LocalRepoOperator.create_from_commit(repo_path=repo_path, commit=commit, url=repo_url, github_api_key=config.secrets.github_api_key if config.secrets else None)
12091216
logger.info("Clone completed successfully")
12101217

12111218
# Initialize and return codebase with proper context
12121219
logger.info("Initializing Codebase...")
12131220
project = ProjectConfig.from_repo_operator(repo_operator=repo_operator, programming_language=programming_language)
1214-
codebase = Codebase(projects=[project], config=DefaultConfig)
1221+
codebase = Codebase(projects=[project], config=config)
12151222
logger.info("Codebase initialization complete")
12161223
return codebase
12171224
except Exception as e:
12181225
logger.exception(f"Failed to initialize codebase: {e}")
12191226
raise
12201227

1228+
def get_modified_symbols_in_pr(self, pr_id: int) -> list[Symbol]:
    """Get all modified symbols in a pull request.

    Args:
        pr_id: Number of the pull request to inspect.

    Returns:
        The symbols reported by ``CodegenPR.modified_symbols`` for that pull request.

    Raises:
        ValueError: If the repo operator returned no pull request for ``pr_id``.
    """
    # NOTE(review): only LocalRepoOperator.get_pull_request is known to exist;
    # confirm the other operator types ``_op`` may hold provide it too.
    pr = self._op.get_pull_request(pr_id)
    if pr is None:
        # Fail here with a clear message instead of letting ``None`` travel
        # into CodegenPR and blow up later on an attribute access.
        msg = f"Could not fetch pull request {pr_id}"
        raise ValueError(msg)
    return CodegenPR(self._op, self, pr).modified_symbols
1233+
12211234

12221235
# The last 2 lines of code are added to the runner. See codegen-backend/cli/generate/utils.py
12231236
# Type Aliases

src/codegen/sdk/secrets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
@dataclass
55
class Secrets:
66
openai_key: str | None = None
7+
github_api_key: str | None = None

0 commit comments

Comments
 (0)