Skip to content

Commit 819b7b8

Browse files
authored
Merge branch 'develop' into eagarwal-cg-10796-lsp-progress-reporting-support
2 parents 3d97ee5 + ca8a42e commit 819b7b8

File tree

13 files changed

+119
-72
lines changed

13 files changed

+119
-72
lines changed

.github/actions/run-ats/action.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,10 @@ runs:
4848
echo "No tests to run, skipping..."
4949
exit 0
5050
fi
51-
echo $TESTS_TO_RUN | xargs uv run pytest --cov \
51+
echo $TESTS_TO_RUN | xargs uv run --frozen pytest --cov \
5252
-o junit_suite_name="${{ github.job }}" \
5353
-n auto \
5454
-vv \
55-
--cov \
5655
--cov-append \
5756
${{ inputs.collect_args }}
5857

src/codegen/extensions/lsp/io.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@dataclass
1818
class File:
19-
doc: TextDocument
19+
doc: TextDocument | None
2020
path: Path
2121
change: TextEdit | None = None
2222
other_change: CreateFile | RenameFile | DeleteFile | None = None
@@ -65,6 +65,8 @@ def read_text(self, path: Path) -> str:
6565
return file.change.new_text
6666
if file.created:
6767
return ""
68+
if file.doc is None:
69+
return self.base_io.read_text(path)
6870
return file.doc.source
6971

7072
def read_bytes(self, path: Path) -> bytes:
@@ -76,6 +78,8 @@ def read_bytes(self, path: Path) -> bytes:
7678
return file.change.new_text.encode("utf-8")
7779
if file.created:
7880
return b""
81+
if file.doc is None:
82+
return self.base_io.read_bytes(path)
7983
return file.doc.source.encode("utf-8")
8084

8185
def write_bytes(self, path: Path, content: bytes) -> None:
@@ -112,6 +116,8 @@ def file_exists(self, path: Path) -> bool:
112116
return True
113117
if file.created:
114118
return True
119+
if file.doc is None:
120+
return self.base_io.file_exists(path)
115121
try:
116122
file.doc.source
117123
return True
@@ -134,3 +140,13 @@ def get_workspace_edit(self) -> types.WorkspaceEdit:
134140
file.change = None
135141
logger.info(f"Workspace edit: {pprint.pformat(list(map(asdict, document_changes)))}")
136142
return types.WorkspaceEdit(document_changes=document_changes)
143+
144+
def update_file(self, path: Path, version: int | None = None) -> None:
145+
file = self._get_file(path)
146+
file.doc = self.workspace.get_text_document(path.as_uri())
147+
if version is not None:
148+
file.version = version
149+
150+
def close_file(self, path: Path) -> None:
151+
file = self._get_file(path)
152+
file.doc = None

src/codegen/extensions/lsp/lsp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def did_open(server: CodegenLanguageServer, params: types.DidOpenTextDocumentPar
2424
# The document is automatically added to the workspace by pygls
2525
# We can perform any additional processing here if needed
2626
path = get_path(params.text_document.uri)
27+
server.io.update_file(path, params.text_document.version)
2728
file = server.codebase.get_file(str(path), optional=True)
2829
if not isinstance(file, SourceFile) and path.suffix in server.codebase.ctx.extensions:
2930
sync = DiffLite(change_type=ChangeType.Added, path=path)
@@ -37,6 +38,7 @@ def did_change(server: CodegenLanguageServer, params: types.DidChangeTextDocumen
3738
# The document is automatically updated in the workspace by pygls
3839
# We can perform any additional processing here if needed
3940
path = get_path(params.text_document.uri)
41+
server.io.update_file(path, params.text_document.version)
4042
sync = DiffLite(change_type=ChangeType.Modified, path=path)
4143
server.codebase.ctx.apply_diffs([sync])
4244

@@ -63,6 +65,8 @@ def did_close(server: CodegenLanguageServer, params: types.DidCloseTextDocumentP
6365
logger.info(f"Document closed: {params.text_document.uri}")
6466
# The document is automatically removed from the workspace by pygls
6567
# We can perform any additional cleanup here if needed
68+
path = get_path(params.text_document.uri)
69+
server.io.close_file(path)
6670

6771

6872
@server.feature(

src/codegen/git/repo_operator/repo_operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from codegen.git.schemas.enums import CheckoutResult, FetchResult
2020
from codegen.git.schemas.repo_config import RepoConfig
2121
from codegen.git.utils.remote_progress import CustomRemoteProgress
22+
from codegen.shared.configs.session_configs import config
2223
from codegen.shared.performance.stopwatch_utils import stopwatch
2324
from codegen.shared.performance.time_utils import humanize_duration
2425

@@ -46,7 +47,7 @@ def __init__(
4647
) -> None:
4748
assert repo_config is not None
4849
self.repo_config = repo_config
49-
self.access_token = access_token
50+
self.access_token = access_token or config.secrets.github_token
5051
self.base_dir = repo_config.base_dir
5152
self.bot_commit = bot_commit
5253

src/codegen/sdk/core/codebase.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def get_file_from_path(path: Path) -> File | None:
514514
if file is not None:
515515
return file
516516
absolute_path = self.ctx.to_absolute(filepath)
517-
if absolute_path.suffix in self.ctx.extensions:
517+
if absolute_path.suffix in self.ctx.extensions and not self.ctx.io.file_exists(absolute_path):
518518
return None
519519
if self.ctx.io.file_exists(absolute_path):
520520
return get_file_from_path(absolute_path)
@@ -902,10 +902,30 @@ def restore_stashed_changes(self):
902902
####################################################################################################################
903903

904904
def create_pr(self, title: str, body: str) -> PullRequest:
905-
"""Creates a PR from the current branch."""
905+
"""Creates a pull request from the current branch to the repository's default branch.
906+
907+
This method will:
908+
1. Stage and commit any pending changes with the PR title as the commit message
909+
2. Push the current branch to the remote repository
910+
3. Create a pull request targeting the default branch
911+
912+
Args:
913+
title (str): The title for the pull request
914+
body (str): The description/body text for the pull request
915+
916+
Returns:
917+
PullRequest: The created GitHub pull request object
918+
919+
Raises:
920+
ValueError: If attempting to create a PR while in a detached HEAD state
921+
ValueError: If the current branch is the default branch
922+
"""
906923
if self._op.git_cli.head.is_detached:
907924
msg = "Cannot make a PR from a detached HEAD"
908925
raise ValueError(msg)
926+
if self._op.git_cli.active_branch.name == self._op.default_branch:
927+
msg = "Cannot make a PR from the default branch"
928+
raise ValueError(msg)
909929
self._op.stage_and_commit_all_changes(message=title)
910930
self._op.push_changes()
911931
return self._op.remote_git_repo.create_pull(head_branch_name=self._op.git_cli.active_branch.name, base_branch_name=self._op.default_branch, title=title, body=body)

tests/integration/codegen/git/codebase/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def repo_config(tmpdir):
2525
def op(repo_config):
2626
os.chdir(repo_config.base_dir)
2727
GitRepo.clone_from(url=get_authenticated_clone_url_for_repo_config(repo_config, token=config.secrets.github_token), to_path=os.path.join(repo_config.base_dir, repo_config.name), depth=1)
28-
op = LocalRepoOperator(repo_config=repo_config, access_token=config.secrets.github_token)
28+
op = LocalRepoOperator(repo_config=repo_config)
2929
yield op
3030

3131

tests/integration/codegen/git/codebase/test_codebase_create_pr.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
def test_codebase_create_pr_active_branch(codebase: Codebase):
99
head = f"test-create-pr-{uuid.uuid4()}"
1010
codebase.checkout(branch=head, create_if_missing=True)
11-
codebase.files[0].remove()
11+
file = codebase.files[0]
12+
file.remove()
1213
codebase.commit()
1314
pr = codebase.create_pr(title="test-create-pr title", body="test-create-pr body")
1415
assert pr.title == "test-create-pr title"
@@ -17,10 +18,21 @@ def test_codebase_create_pr_active_branch(codebase: Codebase):
1718
assert pr.state == "open"
1819
assert pr.head.ref == head
1920
assert pr.base.ref == "main"
21+
assert pr.get_files().totalCount == 1
22+
assert pr.get_files()[0].filename == file.file_path
2023

2124

2225
def test_codebase_create_pr_detached_head(codebase: Codebase):
2326
codebase.checkout(commit=codebase._op.git_cli.head.commit) # move to detached head state
2427
with pytest.raises(ValueError) as exc_info:
2528
codebase.create_pr(title="test-create-pr title", body="test-create-pr body")
2629
assert "Cannot make a PR from a detached HEAD" in str(exc_info.value)
30+
31+
32+
def test_codebase_create_pr_active_branch_is_default_branch(codebase: Codebase):
33+
codebase.checkout(branch=codebase._op.default_branch)
34+
codebase.files[0].remove()
35+
codebase.commit()
36+
with pytest.raises(ValueError) as exc_info:
37+
codebase.create_pr(title="test-create-pr title", body="test-create-pr body")
38+
assert "Cannot make a PR from the default branch" in str(exc_info.value)

tests/unit/codegen/extensions/lsp/test_workspace_sync.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Callable
2+
13
import pytest
24
from lsprotocol.types import (
35
DidChangeTextDocumentParams,
@@ -16,6 +18,7 @@
1618
from pytest_lsp import LanguageClient
1719

1820
from codegen.sdk.core.codebase import Codebase
21+
from tests.unit.codegen.extensions.lsp.utils import apply_edit
1922

2023

2124
@pytest.fixture()
@@ -88,6 +91,7 @@ def example_function():
8891
""".strip(),
8992
),
9093
],
94+
ids=["example_function"],
9195
indirect=["document_uri", "original"],
9296
)
9397
async def test_did_change(
@@ -137,7 +141,7 @@ def example_function():
137141
pass
138142
""".strip(),
139143
},
140-
"file://{worskpaceFolder}test.py",
144+
"file://{worskpaceFolder}/test.py",
141145
),
142146
],
143147
)
@@ -165,11 +169,11 @@ async def test_did_close(
165169

166170
# Verify the document is removed from the workspace
167171
document = await client.workspace_text_document_content_async(TextDocumentContentParams(uri=document_uri))
168-
assert document.text == ""
172+
assert document.text == original["test.py"]
169173

170174

171175
@pytest.mark.parametrize(
172-
"original, document_uri, position, new_name, expected_text",
176+
"original, document_uri, position, new_name, expected",
173177
[
174178
(
175179
{
@@ -182,15 +186,17 @@ def main():
182186
""".strip(),
183187
},
184188
"file://{workspaceFolder}/test.py",
185-
Position(line=0, character=0), # Position of 'example_function'
189+
Position(line=0, character=5), # Position of 'example_function'
186190
"renamed_function",
187-
"""
191+
{
192+
"test.py": """
188193
def renamed_function():
189194
pass # modified
190195
191196
def main():
192197
renamed_function()
193198
""".strip(),
199+
},
194200
),
195201
],
196202
indirect=["document_uri", "original"],
@@ -202,7 +208,7 @@ async def test_rename_after_sync(
202208
document_uri: str,
203209
position: Position,
204210
new_name: str,
205-
expected_text: str,
211+
assert_expected: Callable,
206212
):
207213
# First open the document
208214
client.text_document_did_open(
@@ -243,8 +249,6 @@ async def test_rename_after_sync(
243249
new_name=new_name,
244250
)
245251
)
246-
247-
# Verify the rename was successful
248-
document = await client.workspace_text_document_content_async(TextDocumentContentParams(uri=document_uri))
249-
assert document is not None
250-
assert document.text == expected_text
252+
if result:
253+
apply_edit(codebase, result)
254+
assert_expected(codebase)

tests/unit/codegen/extensions/lsp/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ def apply_edit(codebase: Codebase, edit: WorkspaceEdit):
1010
path = get_path(change.text_document.uri)
1111
file = codebase.get_file(str(path.relative_to(codebase.repo_path)))
1212
for edit in change.edits:
13-
print("BRUH")
1413
file.edit(edit.new_text)
1514
if isinstance(change, CreateFile):
1615
path = get_path(change.uri)

0 commit comments

Comments
 (0)