Skip to content

Commit c5ba736

Browse files
chore: [CG-10339] support codebase create_pr (#420)
1 parent 52396cc commit c5ba736

File tree

4 files changed

+61
-50
lines changed

4 files changed

+61
-50
lines changed

src/codegen/git/repo_operator/local_repo_operator.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
from typing import Self, override
55

66
from codeowners import CodeOwners as CodeOwnersParser
7-
from git import Remote
87
from git import Repo as GitCLI
9-
from git.remote import PushInfoList
108
from github import Github
119
from github.PullRequest import PullRequest
10+
from github.Repository import Repository
1211

1312
from codegen.git.clients.git_repo_client import GitRepoClient
1413
from codegen.git.repo_operator.repo_operator import RepoOperator
1514
from codegen.git.schemas.enums import FetchResult
1615
from codegen.git.schemas.repo_config import RepoConfig
1716
from codegen.git.utils.clone_url import url_to_github
1817
from codegen.git.utils.file_utils import create_files
18+
from codegen.shared.configs.config import config
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -41,7 +41,7 @@ def __init__(
4141
github_api_key: str | None = None,
4242
bot_commit: bool = False,
4343
) -> None:
44-
self._github_api_key = github_api_key
44+
self._github_api_key = github_api_key or config.secrets.github_token
4545
self._remote_git_repo = None
4646
super().__init__(repo_config, bot_commit)
4747
os.makedirs(self.repo_path, exist_ok=True)
@@ -52,7 +52,7 @@ def __init__(
5252
####################################################################################################################
5353

5454
@property
55-
def remote_git_repo(self) -> GitRepoClient:
55+
def remote_git_repo(self) -> Repository:
5656
if self._remote_git_repo is None:
5757
if not self._github_api_key:
5858
return None
@@ -173,10 +173,6 @@ def base_url(self) -> str | None:
173173
if remote := next(iter(self.git_cli.remotes), None):
174174
return url_to_github(remote.url, self.get_active_branch_or_commit())
175175

176-
@override
177-
def push_changes(self, remote: Remote | None = None, refspec: str | None = None, force: bool = False) -> PushInfoList:
178-
raise OperatorIsLocal()
179-
180176
@override
181177
def pull_repo(self) -> None:
182178
"""Pull the latest commit down to an existing local repo"""

src/codegen/git/repo_operator/remote_repo_operator.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from typing import override
55

66
from codeowners import CodeOwners as CodeOwnersParser
7-
from git import GitCommandError, Remote
8-
from git.remote import PushInfoList
7+
from git import GitCommandError
98

109
from codegen.git.clients.git_repo_client import GitRepoClient
1110
from codegen.git.repo_operator.repo_operator import RepoOperator
@@ -165,43 +164,6 @@ def checkout_remote_branch(self, branch_name: str | None = None, remote_name: st
165164
"""
166165
return self.checkout_branch(branch_name, remote_name=remote_name, remote=True, create_if_missing=False)
167166

168-
@stopwatch
169-
def push_changes(self, remote: Remote | None = None, refspec: str | None = None, force: bool = False) -> PushInfoList:
170-
"""Push the changes to the given refspec of the remote.
171-
172-
Args:
173-
refspec (str | None): refspec to push. If None, the current active branch is used.
174-
remote (Remote | None): Remote to push too. Defaults to 'origin'.
175-
force (bool): If True, force push the changes. Defaults to False.
176-
"""
177-
# Use default remote if not provided
178-
if not remote:
179-
remote = self.git_cli.remote(name="origin")
180-
181-
# Use the current active branch if no branch is specified
182-
if not refspec:
183-
# TODO: doesn't work with detached HEAD state
184-
refspec = self.git_cli.active_branch.name
185-
186-
res = remote.push(refspec=refspec, force=force, progress=CustomRemoteProgress())
187-
for push_info in res:
188-
if push_info.flags & push_info.ERROR:
189-
# Handle the error case
190-
logger.warning(f"Error pushing {refspec}: {push_info.summary}")
191-
elif push_info.flags & push_info.FAST_FORWARD:
192-
# Successful fast-forward push
193-
logger.info(f"{refspec} pushed successfully (fast-forward).")
194-
elif push_info.flags & push_info.NEW_HEAD:
195-
# Successful push of a new branch
196-
logger.info(f"{refspec} pushed successfully as a new branch.")
197-
elif push_info.flags & push_info.NEW_TAG:
198-
# Successful push of a new tag (if relevant)
199-
logger.info("New tag pushed successfully.")
200-
else:
201-
# Successful push, general case
202-
logger.info(f"{refspec} pushed successfully.")
203-
return res
204-
205167
@cached_property
206168
def base_url(self) -> str | None:
207169
repo_config = self.repo_config

src/codegen/git/repo_operator/repo_operator.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from codegen.git.configs.constants import CODEGEN_BOT_EMAIL, CODEGEN_BOT_NAME
1818
from codegen.git.schemas.enums import CheckoutResult, FetchResult
1919
from codegen.git.schemas.repo_config import RepoConfig
20+
from codegen.git.utils.remote_progress import CustomRemoteProgress
2021
from codegen.shared.performance.stopwatch_utils import stopwatch
2122
from codegen.shared.performance.time_utils import humanize_duration
2223

@@ -137,7 +138,17 @@ def git_diff(self) -> str:
137138

138139
@property
139140
def default_branch(self) -> str:
140-
return self._default_branch or self.git_cli.active_branch.name
141+
# Priority 1: If default branch has been set
142+
if self._default_branch:
143+
return self._default_branch
144+
145+
# Priority 2: If origin/HEAD ref exists
146+
origin_prefix = "origin"
147+
if f"{origin_prefix}/HEAD" in self.git_cli.refs:
148+
return self.git_cli.refs[f"{origin_prefix}/HEAD"].reference.name.removeprefix(f"{origin_prefix}/")
149+
150+
# Priority 3: Fallback to the active branch
151+
return self.git_cli.active_branch.name
141152

142153
@abstractmethod
143154
def codeowners_parser(self) -> CodeOwnersParser | None: ...
@@ -372,14 +383,42 @@ def commit_changes(self, message: str, verify: bool = False) -> bool:
372383
logger.info("No changes to commit. Do nothing.")
373384
return False
374385

375-
@abstractmethod
386+
@stopwatch
376387
def push_changes(self, remote: Remote | None = None, refspec: str | None = None, force: bool = False) -> PushInfoList:
377-
"""Push the changes to the given refspec of the remote repository.
388+
"""Push the changes to the given refspec of the remote.
378389
379390
Args:
380391
refspec (str | None): refspec to push. If None, the current active branch is used.
381392
remote (Remote | None): Remote to push too. Defaults to 'origin'.
393+
force (bool): If True, force push the changes. Defaults to False.
382394
"""
395+
# Use default remote if not provided
396+
if not remote:
397+
remote = self.git_cli.remote(name="origin")
398+
399+
# Use the current active branch if no branch is specified
400+
if not refspec:
401+
# TODO: doesn't work with detached HEAD state
402+
refspec = self.git_cli.active_branch.name
403+
404+
res = remote.push(refspec=refspec, force=force, progress=CustomRemoteProgress())
405+
for push_info in res:
406+
if push_info.flags & push_info.ERROR:
407+
# Handle the error case
408+
logger.warning(f"Error pushing {refspec}: {push_info.summary}")
409+
elif push_info.flags & push_info.FAST_FORWARD:
410+
# Successful fast-forward push
411+
logger.info(f"{refspec} pushed successfully (fast-forward).")
412+
elif push_info.flags & push_info.NEW_HEAD:
413+
# Successful push of a new branch
414+
logger.info(f"{refspec} pushed successfully as a new branch.")
415+
elif push_info.flags & push_info.NEW_TAG:
416+
# Successful push of a new tag (if relevant)
417+
logger.info("New tag pushed successfully.")
418+
else:
419+
# Successful push, general case
420+
logger.info(f"{refspec} pushed successfully.")
421+
return res
383422

384423
def relpath(self, abspath) -> str:
385424
# TODO: check if the path is an abspath (i.e. contains self.repo_path)

src/codegen/sdk/core/codebase.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from git import Commit as GitCommit
1616
from git import Diff
1717
from git.remote import PushInfoList
18+
from github.PullRequest import PullRequest
1819
from networkx import Graph
1920
from rich.console import Console
2021
from typing_extensions import deprecated
@@ -872,6 +873,19 @@ def restore_stashed_changes(self):
872873
"""Restore the most recent stash in the codebase."""
873874
self._op.stash_pop()
874875

876+
####################################################################################################################
877+
# GITHUB
878+
####################################################################################################################
879+
880+
def create_pr(self, title: str, body: str) -> PullRequest:
881+
"""Creates a PR from the current branch."""
882+
if self._op.git_cli.head.is_detached:
883+
msg = "Cannot make a PR from a detached HEAD"
884+
raise ValueError(msg)
885+
self._op.stage_and_commit_all_changes(message=title)
886+
self._op.push_changes()
887+
return self._op.remote_git_repo.create_pull(head=self._op.git_cli.active_branch.name, base=self._op.default_branch, title=title, body=body)
888+
875889
####################################################################################################################
876890
# GRAPH VISUALIZATION
877891
####################################################################################################################

0 commit comments

Comments
 (0)