Skip to content

Split tests code action #454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ types = [
"types-requests>=2.32.0.20241016",
"types-toml>=0.10.8.20240310",
]
lsp = ["pygls>=2.0.0a2", "lsprotocol==2024.0.0b1"]
lsp = ["pygls>=2.0.0a2", "lsprotocol==2024.0.0b1", "attrs>=25.1.0"]
[tool.uv]
cache-keys = [{ git = { commit = true, tags = true } }]
dev-dependencies = [
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/extensions/lsp/codemods/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from codegen.extensions.lsp.codemods.base import CodeAction
from codegen.extensions.lsp.codemods.split_tests import SplitTests

ACTIONS: list[CodeAction] = [SplitTests()]
41 changes: 41 additions & 0 deletions src/codegen/extensions/lsp/codemods/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar

from lsprotocol import types

from codegen.sdk.core.interfaces.editable import Editable

if TYPE_CHECKING:
from codegen.extensions.lsp.server import CodegenLanguageServer


class CodeAction(ABC):
name: str
kind: ClassVar[types.CodeActionKind] = types.CodeActionKind.Refactor

def __init__(self):
pass

@abstractmethod
def execute(self, server: "CodegenLanguageServer", node: Editable) -> None: ...

@abstractmethod
def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool: ...

def to_command(self, uri: str, range: types.Range) -> types.Command:
return types.Command(
title=self.name,
command=self.command_name(),
arguments=[uri, range],
)

def to_lsp(self, uri: str, range: types.Range) -> types.CodeAction:
return types.CodeAction(
title=self.name,
kind=self.kind,
data=[self.command_name(), uri, range],
)

@classmethod
def command_name(cls) -> str:
return f"codegen-{cls.__name__}"
23 changes: 23 additions & 0 deletions src/codegen/extensions/lsp/codemods/move_symbol_to_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING

from codegen.extensions.lsp.codemods.base import CodeAction
from codegen.sdk.core.interfaces.editable import Editable

if TYPE_CHECKING:
from codegen.extensions.lsp.server import CodegenLanguageServer


class MoveSymbolToFile(CodeAction):
name = "Move Symbol to File"

def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool:
return True

def execute(self, server: "CodegenLanguageServer", node: Editable) -> None:
target_file = server.window_show_message_request(
"Select the file to move the symbol to",
server.codebase.files,
).result(timeout=10)
if target_file is None:
return
server.codebase.move_symbol(node.parent_symbol, target_file)
33 changes: 33 additions & 0 deletions src/codegen/extensions/lsp/codemods/split_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING

from codegen.extensions.lsp.codemods.base import CodeAction
from codegen.sdk.core.function import Function
from codegen.sdk.core.interfaces.editable import Editable

if TYPE_CHECKING:
from codegen.extensions.lsp.server import CodegenLanguageServer


class SplitTests(CodeAction):
name = "Split Tests"

def _get_targets(self, server: "CodegenLanguageServer", node: Editable) -> dict[Function, str]:
targets = {}
for function in node.file.functions:
if function.name.startswith("test_"):
target = f"{node.file.directory.path}/{function.name}.py"
if not server.codebase.has_file(target):
targets[function] = target
return targets

def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool:
if "tests" in str(node.file.path):
return len(self._get_targets(server, node)) > 1
return False

def execute(self, server: "CodegenLanguageServer", node: Editable) -> None:
targets = self._get_targets(server, node)
for function, target in targets.items():
new_file = server.codebase.create_file(target)
function.move_to_file(new_file, strategy="duplicate_dependencies")
# node.file.remove()
38 changes: 38 additions & 0 deletions src/codegen/extensions/lsp/execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging
from typing import TYPE_CHECKING, Any, Callable

from lsprotocol import types
from lsprotocol.types import Position, Range

from codegen.extensions.lsp.codemods.base import CodeAction

if TYPE_CHECKING:
from codegen.extensions.lsp.server import CodegenLanguageServer

logger = logging.getLogger(__name__)


def process_args(args: Any) -> tuple[str, Range]:
uri = args[0]
range = args[1]
range = Range(start=Position(line=range["start"]["line"], character=range["start"]["character"]), end=Position(line=range["end"]["line"], character=range["end"]["character"]))
return uri, range


def execute_action(server: "CodegenLanguageServer", action: CodeAction, args: Any) -> None:
uri, range = process_args(args)
node = server.get_node_under_cursor(uri, range.start, range.end)
if node is None:
logger.warning(f"No node found for range {range}")
return
action.execute(server, node, *args[2:])
server.codebase.commit()


def get_execute_action(action: CodeAction) -> Callable[["CodegenLanguageServer", Any], None]:
def execute_action(server: "CodegenLanguageServer", args: Any) -> None:
logger.info(f"Executing action {action.command_name()} with args {args}")
execute_action(server, action, args)
server.workspace_apply_edit(types.ApplyWorkspaceEditParams(edit=server.io.get_workspace_edit())).result()

return execute_action
117 changes: 91 additions & 26 deletions src/codegen/extensions/lsp/io.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import logging
import pprint
from dataclasses import dataclass
from pathlib import Path

from attr import asdict
from lsprotocol import types
from lsprotocol.types import Position, Range, TextEdit
from lsprotocol.types import CreateFile, CreateFileOptions, DeleteFile, Position, Range, RenameFile, TextEdit
from pygls.workspace import TextDocument, Workspace

from codegen.sdk.codebase.io.file_io import FileIO
Expand All @@ -11,61 +14,123 @@
logger = logging.getLogger(__name__)


@dataclass
class File:
doc: TextDocument
path: Path
change: TextEdit | None = None
other_change: CreateFile | RenameFile | DeleteFile | None = None
version: int = 0

@property
def deleted(self) -> bool:
return self.other_change is not None and self.other_change.kind == "delete"

@property
def created(self) -> bool:
return self.other_change is not None and self.other_change.kind == "create"

@property
def identifier(self) -> types.OptionalVersionedTextDocumentIdentifier:
return types.OptionalVersionedTextDocumentIdentifier(uri=self.path.as_uri(), version=self.version)


class LSPIO(IO):
base_io: FileIO
workspace: Workspace
changes: dict[str, TextEdit] = {}
files: dict[Path, File]

def __init__(self, workspace: Workspace):
self.workspace = workspace
self.base_io = FileIO()
self.files = {}

def _get_doc(self, path: Path) -> TextDocument | None:
def _get_doc(self, path: Path) -> TextDocument:
uri = path.as_uri()
logger.info(f"Getting document for {uri}")
return self.workspace.get_text_document(uri)

def _get_file(self, path: Path) -> File:
if path not in self.files:
doc = self._get_doc(path)
self.files[path] = File(doc=doc, path=path, version=doc.version or 0)
return self.files[path]

def read_text(self, path: Path) -> str:
file = self._get_file(path)
if file.deleted:
msg = f"File {path} has been deleted"
raise FileNotFoundError(msg)
if file.change:
return file.change.new_text
if file.created:
return ""
return file.doc.source

def read_bytes(self, path: Path) -> bytes:
if self.changes.get(path.as_uri()):
return self.changes[path.as_uri()].new_text.encode("utf-8")
if doc := self._get_doc(path):
return doc.source.encode("utf-8")
return self.base_io.read_bytes(path)
file = self._get_file(path)
if file.deleted:
msg = f"File {path} has been deleted"
raise FileNotFoundError(msg)
if file.change:
return file.change.new_text.encode("utf-8")
if file.created:
return b""
return file.doc.source.encode("utf-8")

def write_bytes(self, path: Path, content: bytes) -> None:
logger.info(f"Writing bytes to {path}")
start = Position(line=0, character=0)
if doc := self._get_doc(path):
end = Position(line=len(doc.source), character=len(doc.source))
file = self._get_file(path)
if self.file_exists(path):
lines = self.read_text(path).splitlines()
if len(lines) == 0:
end = Position(line=0, character=0)
else:
end = Position(line=len(lines) - 1, character=len(lines[-1]))
file.change = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8"))
else:
end = Position(line=0, character=0)
self.changes[path.as_uri()] = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8"))
file.other_change = CreateFile(uri=path.as_uri(), options=CreateFileOptions())
file.change = TextEdit(range=Range(start=start, end=start), new_text=content.decode("utf-8"))

def save_files(self, files: set[Path] | None = None) -> None:
self.base_io.save_files(files)
logger.info(f"Saving files {files}")

def check_changes(self) -> None:
self.base_io.check_changes()

def delete_file(self, path: Path) -> None:
file = self._get_file(path)
file.other_change = DeleteFile(uri=path.as_uri())
self.base_io.delete_file(path)

def file_exists(self, path: Path) -> bool:
if doc := self._get_doc(path):
try:
doc.source
except FileNotFoundError:
return False
file = self._get_file(path)
if file.deleted:
return False
if file.change:
return True
if file.created:
return True
try:
file.doc.source
return True
return self.base_io.file_exists(path)
except FileNotFoundError:
return False

def untrack_file(self, path: Path) -> None:
self.base_io.untrack_file(path)

def get_document_changes(self) -> list[types.TextDocumentEdit]:
ret = []
for uri, change in self.changes.items():
id = types.OptionalVersionedTextDocumentIdentifier(uri=uri)
ret.append(types.TextDocumentEdit(text_document=id, edits=[change]))
self.changes = {}
return ret
def get_workspace_edit(self) -> types.WorkspaceEdit:
document_changes = []
for _, file in self.files.items():
id = file.identifier
if file.other_change:
document_changes.append(file.other_change)
file.other_change = None
if file.change:
document_changes.append(types.TextDocumentEdit(text_document=id, edits=[file.change]))

Check failure on line 132 in src/codegen/extensions/lsp/io.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "append" of "list" has incompatible type "TextDocumentEdit"; expected "CreateFile | RenameFile | DeleteFile" [arg-type]
file.version += 1
file.change = None
logger.info(f"Workspace edit: {pprint.pformat(list(map(asdict, document_changes)))}")
return types.WorkspaceEdit(document_changes=document_changes)
25 changes: 22 additions & 3 deletions src/codegen/extensions/lsp/lsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.R
logger.info(f"Renaming symbol {symbol.name} to {params.new_name}")
symbol.rename(params.new_name)
server.codebase.commit()
return types.WorkspaceEdit(
document_changes=server.io.get_document_changes(),
)
return server.io.get_workspace_edit()


@server.feature(
Expand All @@ -104,6 +102,27 @@ def definition(server: CodegenLanguageServer, params: types.DefinitionParams):
)


@server.feature(
types.TEXT_DOCUMENT_CODE_ACTION,
options=types.CodeActionOptions(resolve_provider=True),
)
def code_action(server: CodegenLanguageServer, params: types.CodeActionParams) -> types.CodeActionResult:
logger.info(f"Received code action: {params}")
if params.context.only:
only = [types.CodeActionKind(kind) for kind in params.context.only]
else:
only = None
actions = server.get_actions_for_range(params.text_document.uri, params.range, only)
return actions


@server.feature(
types.CODE_ACTION_RESOLVE,
)
def code_action_resolve(server: CodegenLanguageServer, params: types.CodeAction) -> types.CodeAction:
return server.resolve_action(params)


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
server.start_io()
23 changes: 6 additions & 17 deletions src/codegen/extensions/lsp/protocol.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
import threading
from pathlib import Path
from typing import TYPE_CHECKING

from lsprotocol.types import INITIALIZE, INITIALIZED, InitializedParams, InitializeParams, InitializeResult
from lsprotocol.types import INITIALIZE, InitializeParams, InitializeResult, WorkDoneProgressBegin, WorkDoneProgressEnd
from pygls.protocol import LanguageServerProtocol, lsp_method

from codegen.extensions.lsp.io import LSPIO
Expand All @@ -30,23 +29,13 @@ def _init_codebase(self, params: InitializeParams) -> None:
io = LSPIO(self.workspace)
self._server.codebase = Codebase(repo_path=str(root), config=config, io=io)
self._server.io = io
if params.work_done_token:
self._server.work_done_progress.end(params.work_done_token, WorkDoneProgressEnd(message="Parsing codebase..."))

@lsp_method(INITIALIZE)
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
if params.root_path:
root = Path(params.root_path)
elif params.root_uri:
root = get_path(params.root_uri)
else:
root = os.getcwd()
config = CodebaseConfig(feature_flags=CodebaseFeatureFlags(full_range_index=True))
ret = super().lsp_initialize(params)

self._worker = threading.Thread(target=self._init_codebase, args=(params,))
self._worker.start()
if params.work_done_token:
self._server.work_done_progress.begin(params.work_done_token, WorkDoneProgressBegin(title="Parsing codebase..."))
self._init_codebase(params)
return ret

@lsp_method(INITIALIZED)
def lsp_initialized(self, params: InitializedParams) -> None:
self._worker.join()
super().lsp_initialized(params)
Loading
Loading