Skip to content

Commit 4798d01

Browse files
authored
Split tests code action (#454)
1 parent c189cf2 commit 4798d01

File tree

18 files changed

+625
-225
lines changed

18 files changed

+625
-225
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ types = [
117117
"types-requests>=2.32.0.20241016",
118118
"types-toml>=0.10.8.20240310",
119119
]
120-
lsp = ["pygls>=2.0.0a2", "lsprotocol==2024.0.0b1"]
120+
lsp = ["pygls>=2.0.0a2", "lsprotocol==2024.0.0b1", "attrs>=25.1.0"]
121121
[tool.uv]
122122
cache-keys = [{ git = { commit = true, tags = true } }]
123123
dev-dependencies = [
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from codegen.extensions.lsp.codemods.base import CodeAction
2+
from codegen.extensions.lsp.codemods.split_tests import SplitTests
3+
4+
ACTIONS: list[CodeAction] = [SplitTests()]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from abc import ABC, abstractmethod
2+
from typing import TYPE_CHECKING, ClassVar
3+
4+
from lsprotocol import types
5+
6+
from codegen.sdk.core.interfaces.editable import Editable
7+
8+
if TYPE_CHECKING:
9+
from codegen.extensions.lsp.server import CodegenLanguageServer
10+
11+
12+
class CodeAction(ABC):
13+
name: str
14+
kind: ClassVar[types.CodeActionKind] = types.CodeActionKind.Refactor
15+
16+
def __init__(self):
17+
pass
18+
19+
@abstractmethod
20+
def execute(self, server: "CodegenLanguageServer", node: Editable) -> None: ...
21+
22+
@abstractmethod
23+
def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool: ...
24+
25+
def to_command(self, uri: str, range: types.Range) -> types.Command:
26+
return types.Command(
27+
title=self.name,
28+
command=self.command_name(),
29+
arguments=[uri, range],
30+
)
31+
32+
def to_lsp(self, uri: str, range: types.Range) -> types.CodeAction:
33+
return types.CodeAction(
34+
title=self.name,
35+
kind=self.kind,
36+
data=[self.command_name(), uri, range],
37+
)
38+
39+
@classmethod
40+
def command_name(cls) -> str:
41+
return f"codegen-{cls.__name__}"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import TYPE_CHECKING
2+
3+
from codegen.extensions.lsp.codemods.base import CodeAction
4+
from codegen.sdk.core.interfaces.editable import Editable
5+
6+
if TYPE_CHECKING:
7+
from codegen.extensions.lsp.server import CodegenLanguageServer
8+
9+
10+
class MoveSymbolToFile(CodeAction):
11+
name = "Move Symbol to File"
12+
13+
def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool:
14+
return True
15+
16+
def execute(self, server: "CodegenLanguageServer", node: Editable) -> None:
17+
target_file = server.window_show_message_request(
18+
"Select the file to move the symbol to",
19+
server.codebase.files,
20+
).result(timeout=10)
21+
if target_file is None:
22+
return
23+
server.codebase.move_symbol(node.parent_symbol, target_file)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import TYPE_CHECKING
2+
3+
from codegen.extensions.lsp.codemods.base import CodeAction
4+
from codegen.sdk.core.function import Function
5+
from codegen.sdk.core.interfaces.editable import Editable
6+
7+
if TYPE_CHECKING:
8+
from codegen.extensions.lsp.server import CodegenLanguageServer
9+
10+
11+
class SplitTests(CodeAction):
12+
name = "Split Tests"
13+
14+
def _get_targets(self, server: "CodegenLanguageServer", node: Editable) -> dict[Function, str]:
15+
targets = {}
16+
for function in node.file.functions:
17+
if function.name.startswith("test_"):
18+
target = f"{node.file.directory.path}/{function.name}.py"
19+
if not server.codebase.has_file(target):
20+
targets[function] = target
21+
return targets
22+
23+
def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool:
24+
if "tests" in str(node.file.path):
25+
return len(self._get_targets(server, node)) > 1
26+
return False
27+
28+
def execute(self, server: "CodegenLanguageServer", node: Editable) -> None:
29+
targets = self._get_targets(server, node)
30+
for function, target in targets.items():
31+
new_file = server.codebase.create_file(target)
32+
function.move_to_file(new_file, strategy="duplicate_dependencies")
33+
# node.file.remove()

src/codegen/extensions/lsp/execute.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import logging
2+
from typing import TYPE_CHECKING, Any, Callable
3+
4+
from lsprotocol import types
5+
from lsprotocol.types import Position, Range
6+
7+
from codegen.extensions.lsp.codemods.base import CodeAction
8+
9+
if TYPE_CHECKING:
10+
from codegen.extensions.lsp.server import CodegenLanguageServer
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
def process_args(args: Any) -> tuple[str, Range]:
16+
uri = args[0]
17+
range = args[1]
18+
range = Range(start=Position(line=range["start"]["line"], character=range["start"]["character"]), end=Position(line=range["end"]["line"], character=range["end"]["character"]))
19+
return uri, range
20+
21+
22+
def execute_action(server: "CodegenLanguageServer", action: CodeAction, args: Any) -> None:
23+
uri, range = process_args(args)
24+
node = server.get_node_under_cursor(uri, range.start, range.end)
25+
if node is None:
26+
logger.warning(f"No node found for range {range}")
27+
return
28+
action.execute(server, node, *args[2:])
29+
server.codebase.commit()
30+
31+
32+
def get_execute_action(action: CodeAction) -> Callable[["CodegenLanguageServer", Any], None]:
33+
def execute_action(server: "CodegenLanguageServer", args: Any) -> None:
34+
logger.info(f"Executing action {action.command_name()} with args {args}")
35+
execute_action(server, action, args)
36+
server.workspace_apply_edit(types.ApplyWorkspaceEditParams(edit=server.io.get_workspace_edit())).result()
37+
38+
return execute_action

src/codegen/extensions/lsp/io.py

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import logging
2+
import pprint
3+
from dataclasses import dataclass
24
from pathlib import Path
35

6+
from attr import asdict
47
from lsprotocol import types
5-
from lsprotocol.types import Position, Range, TextEdit
8+
from lsprotocol.types import CreateFile, CreateFileOptions, DeleteFile, Position, Range, RenameFile, TextEdit
69
from pygls.workspace import TextDocument, Workspace
710

811
from codegen.sdk.codebase.io.file_io import FileIO
@@ -11,61 +14,123 @@
1114
logger = logging.getLogger(__name__)
1215

1316

17+
@dataclass
18+
class File:
19+
doc: TextDocument
20+
path: Path
21+
change: TextEdit | None = None
22+
other_change: CreateFile | RenameFile | DeleteFile | None = None
23+
version: int = 0
24+
25+
@property
26+
def deleted(self) -> bool:
27+
return self.other_change is not None and self.other_change.kind == "delete"
28+
29+
@property
30+
def created(self) -> bool:
31+
return self.other_change is not None and self.other_change.kind == "create"
32+
33+
@property
34+
def identifier(self) -> types.OptionalVersionedTextDocumentIdentifier:
35+
return types.OptionalVersionedTextDocumentIdentifier(uri=self.path.as_uri(), version=self.version)
36+
37+
1438
class LSPIO(IO):
1539
base_io: FileIO
1640
workspace: Workspace
17-
changes: dict[str, TextEdit] = {}
41+
files: dict[Path, File]
1842

1943
def __init__(self, workspace: Workspace):
2044
self.workspace = workspace
2145
self.base_io = FileIO()
46+
self.files = {}
2247

23-
def _get_doc(self, path: Path) -> TextDocument | None:
48+
def _get_doc(self, path: Path) -> TextDocument:
2449
uri = path.as_uri()
2550
logger.info(f"Getting document for {uri}")
2651
return self.workspace.get_text_document(uri)
2752

53+
def _get_file(self, path: Path) -> File:
54+
if path not in self.files:
55+
doc = self._get_doc(path)
56+
self.files[path] = File(doc=doc, path=path, version=doc.version or 0)
57+
return self.files[path]
58+
59+
def read_text(self, path: Path) -> str:
60+
file = self._get_file(path)
61+
if file.deleted:
62+
msg = f"File {path} has been deleted"
63+
raise FileNotFoundError(msg)
64+
if file.change:
65+
return file.change.new_text
66+
if file.created:
67+
return ""
68+
return file.doc.source
69+
2870
def read_bytes(self, path: Path) -> bytes:
29-
if self.changes.get(path.as_uri()):
30-
return self.changes[path.as_uri()].new_text.encode("utf-8")
31-
if doc := self._get_doc(path):
32-
return doc.source.encode("utf-8")
33-
return self.base_io.read_bytes(path)
71+
file = self._get_file(path)
72+
if file.deleted:
73+
msg = f"File {path} has been deleted"
74+
raise FileNotFoundError(msg)
75+
if file.change:
76+
return file.change.new_text.encode("utf-8")
77+
if file.created:
78+
return b""
79+
return file.doc.source.encode("utf-8")
3480

3581
def write_bytes(self, path: Path, content: bytes) -> None:
3682
logger.info(f"Writing bytes to {path}")
3783
start = Position(line=0, character=0)
38-
if doc := self._get_doc(path):
39-
end = Position(line=len(doc.source), character=len(doc.source))
84+
file = self._get_file(path)
85+
if self.file_exists(path):
86+
lines = self.read_text(path).splitlines()
87+
if len(lines) == 0:
88+
end = Position(line=0, character=0)
89+
else:
90+
end = Position(line=len(lines) - 1, character=len(lines[-1]))
91+
file.change = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8"))
4092
else:
41-
end = Position(line=0, character=0)
42-
self.changes[path.as_uri()] = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8"))
93+
file.other_change = CreateFile(uri=path.as_uri(), options=CreateFileOptions())
94+
file.change = TextEdit(range=Range(start=start, end=start), new_text=content.decode("utf-8"))
4395

4496
def save_files(self, files: set[Path] | None = None) -> None:
45-
self.base_io.save_files(files)
97+
logger.info(f"Saving files {files}")
4698

4799
def check_changes(self) -> None:
48100
self.base_io.check_changes()
49101

50102
def delete_file(self, path: Path) -> None:
103+
file = self._get_file(path)
104+
file.other_change = DeleteFile(uri=path.as_uri())
51105
self.base_io.delete_file(path)
52106

53107
def file_exists(self, path: Path) -> bool:
54-
if doc := self._get_doc(path):
55-
try:
56-
doc.source
57-
except FileNotFoundError:
58-
return False
108+
file = self._get_file(path)
109+
if file.deleted:
110+
return False
111+
if file.change:
112+
return True
113+
if file.created:
114+
return True
115+
try:
116+
file.doc.source
59117
return True
60-
return self.base_io.file_exists(path)
118+
except FileNotFoundError:
119+
return False
61120

62121
def untrack_file(self, path: Path) -> None:
63122
self.base_io.untrack_file(path)
64123

65-
def get_document_changes(self) -> list[types.TextDocumentEdit]:
66-
ret = []
67-
for uri, change in self.changes.items():
68-
id = types.OptionalVersionedTextDocumentIdentifier(uri=uri)
69-
ret.append(types.TextDocumentEdit(text_document=id, edits=[change]))
70-
self.changes = {}
71-
return ret
124+
def get_workspace_edit(self) -> types.WorkspaceEdit:
125+
document_changes = []
126+
for _, file in self.files.items():
127+
id = file.identifier
128+
if file.other_change:
129+
document_changes.append(file.other_change)
130+
file.other_change = None
131+
if file.change:
132+
document_changes.append(types.TextDocumentEdit(text_document=id, edits=[file.change]))
133+
file.version += 1
134+
file.change = None
135+
logger.info(f"Workspace edit: {pprint.pformat(list(map(asdict, document_changes)))}")
136+
return types.WorkspaceEdit(document_changes=document_changes)

src/codegen/extensions/lsp/lsp.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.R
7676
logger.info(f"Renaming symbol {symbol.name} to {params.new_name}")
7777
symbol.rename(params.new_name)
7878
server.codebase.commit()
79-
return types.WorkspaceEdit(
80-
document_changes=server.io.get_document_changes(),
81-
)
79+
return server.io.get_workspace_edit()
8280

8381

8482
@server.feature(
@@ -104,6 +102,27 @@ def definition(server: CodegenLanguageServer, params: types.DefinitionParams):
104102
)
105103

106104

105+
@server.feature(
106+
types.TEXT_DOCUMENT_CODE_ACTION,
107+
options=types.CodeActionOptions(resolve_provider=True),
108+
)
109+
def code_action(server: CodegenLanguageServer, params: types.CodeActionParams) -> types.CodeActionResult:
110+
logger.info(f"Received code action: {params}")
111+
if params.context.only:
112+
only = [types.CodeActionKind(kind) for kind in params.context.only]
113+
else:
114+
only = None
115+
actions = server.get_actions_for_range(params.text_document.uri, params.range, only)
116+
return actions
117+
118+
119+
@server.feature(
120+
types.CODE_ACTION_RESOLVE,
121+
)
122+
def code_action_resolve(server: CodegenLanguageServer, params: types.CodeAction) -> types.CodeAction:
123+
return server.resolve_action(params)
124+
125+
107126
if __name__ == "__main__":
108127
logging.basicConfig(level=logging.INFO)
109128
server.start_io()

src/codegen/extensions/lsp/protocol.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import os
2-
import threading
32
from pathlib import Path
43
from typing import TYPE_CHECKING
54

6-
from lsprotocol.types import INITIALIZE, INITIALIZED, InitializedParams, InitializeParams, InitializeResult
5+
from lsprotocol.types import INITIALIZE, InitializeParams, InitializeResult, WorkDoneProgressBegin, WorkDoneProgressEnd
76
from pygls.protocol import LanguageServerProtocol, lsp_method
87

98
from codegen.extensions.lsp.io import LSPIO
@@ -30,23 +29,13 @@ def _init_codebase(self, params: InitializeParams) -> None:
3029
io = LSPIO(self.workspace)
3130
self._server.codebase = Codebase(repo_path=str(root), config=config, io=io)
3231
self._server.io = io
32+
if params.work_done_token:
33+
self._server.work_done_progress.end(params.work_done_token, WorkDoneProgressEnd(message="Parsing codebase..."))
3334

3435
@lsp_method(INITIALIZE)
3536
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
36-
if params.root_path:
37-
root = Path(params.root_path)
38-
elif params.root_uri:
39-
root = get_path(params.root_uri)
40-
else:
41-
root = os.getcwd()
42-
config = CodebaseConfig(feature_flags=CodebaseFeatureFlags(full_range_index=True))
4337
ret = super().lsp_initialize(params)
44-
45-
self._worker = threading.Thread(target=self._init_codebase, args=(params,))
46-
self._worker.start()
38+
if params.work_done_token:
39+
self._server.work_done_progress.begin(params.work_done_token, WorkDoneProgressBegin(title="Parsing codebase..."))
40+
self._init_codebase(params)
4741
return ret
48-
49-
@lsp_method(INITIALIZED)
50-
def lsp_initialized(self, params: InitializedParams) -> None:
51-
self._worker.join()
52-
super().lsp_initialized(params)

0 commit comments

Comments
 (0)