Skip to content

feat: Codegen-lsp v0 #396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ types = [
"types-requests>=2.32.0.20241016",
"types-toml>=0.10.8.20240310",
]
lsp = ["pygls>=2.0.0a2", "lsprotocol==2024.0.0b1"]
[tool.uv]
cache-keys = [{ git = { commit = true, tags = true } }]
dev-dependencies = [
Expand Down Expand Up @@ -149,11 +150,12 @@ dev-dependencies = [
"isort>=5.13.2",
"emoji>=2.14.0",
"pytest-benchmark[histogram]>=5.1.0",
"pytest-asyncio<1.0.0,>=0.21.1",
"pytest-asyncio>=0.21.1,<1.0.0",
"loguru>=0.7.3",
"httpx<0.28.2,>=0.28.1",
"jupyterlab>=4.3.5",
"modal>=0.73.25",
"pytest-lsp>=1.0.0b1",
]


Expand Down Expand Up @@ -212,6 +214,8 @@ xfail_strict = true
junit_duration_report = "call"
junit_logging = "all"
tmp_path_retention_policy = "failed"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
[build-system]
requires = ["hatchling>=1.26.3", "hatch-vcs>=0.4.0", "setuptools-scm>=8.0.0"]
build-backend = "hatchling.build"
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions src/codegen/extensions/lsp/definition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import logging

from lsprotocol.types import Position

from codegen.sdk.core.assignment import Assignment
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute
from codegen.sdk.core.expressions.expression import Expression
from codegen.sdk.core.expressions.name import Name
from codegen.sdk.core.interfaces.editable import Editable
from codegen.sdk.core.interfaces.has_name import HasName

logger = logging.getLogger(__name__)


def go_to_definition(node: Editable | None, uri: str, position: Position) -> Editable | None:
    """Resolve the node under the cursor to the node that defines it.

    Args:
        node: The node found at the cursor, or ``None`` when nothing was found.
        uri: URI of the queried document; only used in log messages.
        position: Cursor position; only used in log messages.

    Returns:
        The name node of the resolved definition when it exposes one,
        otherwise the resolved node itself. ``None`` when ``node`` is not an
        ``Expression`` or when nothing could be resolved.
    """
    if node is None or not isinstance(node, Expression):
        logger.warning(f"No node found at {uri}:{position}")
        return None
    # Cursor sits on the attribute part of a chained attribute: resolve the whole chain.
    if isinstance(node, Name) and isinstance(node.parent, ChainedAttribute) and node.parent.attribute == node:
        node = node.parent
    # Cursor sits on the name of a call: resolve through the call itself.
    if isinstance(node.parent, FunctionCall) and node.parent.get_name() == node:
        node = node.parent
    logger.info(f"Resolving definition for {node}")
    if isinstance(node, FunctionCall):
        resolved = node.function_definition
    else:
        resolved = node.resolved_value
    if resolved is None:
        # Not every Expression is guaranteed to expose `.name`; fall back to the
        # node itself so that emitting this warning can never raise.
        label = getattr(node, "name", node)
        logger.warning(f"No resolved value found for {label} at {uri}:{position}")
        return None
    if isinstance(resolved, HasName):
        # Guard in case get_name() yields None (assumed possible for unnamed
        # definitions - TODO confirm); keep the definition node itself then.
        name_node = resolved.get_name()
        if name_node is not None:
            resolved = name_node
    # When the result is the value side of an assignment, point at the assigned name instead.
    if isinstance(resolved.parent, Assignment) and resolved.parent.value == resolved:
        resolved = resolved.parent.get_name()
    return resolved
26 changes: 26 additions & 0 deletions src/codegen/extensions/lsp/document_symbol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from lsprotocol.types import DocumentSymbol

from codegen.extensions.lsp.kind import get_kind
from codegen.extensions.lsp.range import get_range
from codegen.sdk.core.class_definition import Class
from codegen.sdk.core.interfaces.editable import Editable
from codegen.sdk.extensions.sort import sort_editables


def get_document_symbol(node: Editable) -> DocumentSymbol:
    """Build the LSP ``DocumentSymbol`` for ``node``.

    For a ``Class`` the methods, attributes and nested classes are collected,
    ordered with ``sort_editables`` and converted recursively into children.
    Any other node gets an empty children list.
    """
    members = []
    if isinstance(node, Class):
        members = sort_editables([*node.methods, *node.attributes, *node.nested_classes])
    # Children are built before the parent symbol, as in a depth-first walk.
    child_symbols = [get_document_symbol(member) for member in members]
    return DocumentSymbol(
        name=node.name,
        kind=get_kind(node),
        range=get_range(node),
        selection_range=get_range(node.get_name()),
        children=child_symbols,
    )
71 changes: 71 additions & 0 deletions src/codegen/extensions/lsp/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import logging
from pathlib import Path

from lsprotocol import types
from lsprotocol.types import Position, Range, TextEdit
from pygls.workspace import TextDocument, Workspace

from codegen.sdk.codebase.io.file_io import FileIO
from codegen.sdk.codebase.io.io import IO

logger = logging.getLogger(__name__)


class LSPIO(IO):
    """IO layer that reads from the LSP workspace and records writes as pending edits.

    Reads prefer, in order: a pending edit recorded by ``write_bytes``, the
    workspace document, and finally the wrapped ``FileIO``. Writes are never
    applied directly; each one is stored as a single ``TextEdit`` per document
    until ``get_document_changes`` hands the edits back and clears them.
    Saving, change checks, deletion and untracking are delegated to ``FileIO``.
    """

    base_io: FileIO
    workspace: Workspace
    # Pending edits keyed by document URI. Declared without a class-level value
    # on purpose: a shared `{}` default would leak edits between instances.
    changes: dict[str, TextEdit]

    def __init__(self, workspace: Workspace):
        """Wrap ``workspace`` and start with no pending edits."""
        self.workspace = workspace
        self.base_io = FileIO()
        self.changes = {}

    def _get_doc(self, path: Path) -> TextDocument | None:
        """Return the workspace document for ``path``, looked up by its file URI."""
        uri = path.as_uri()
        logger.info(f"Getting document for {uri}")
        return self.workspace.get_text_document(uri)

    def read_bytes(self, path: Path) -> bytes:
        """Return the UTF-8 content of ``path``, preferring pending edits over the workspace."""
        pending = self.changes.get(path.as_uri())
        if pending:
            return pending.new_text.encode("utf-8")
        if doc := self._get_doc(path):
            return doc.source.encode("utf-8")
        return self.base_io.read_bytes(path)

    def write_bytes(self, path: Path, content: bytes) -> None:
        """Record ``content`` (UTF-8) as the pending replacement text for ``path``."""
        logger.info(f"Writing bytes to {path}")
        start = Position(line=0, character=0)
        end = Position(line=0, character=0)
        if doc := self._get_doc(path):
            try:
                size = len(doc.source)
            except FileNotFoundError:
                # Same failure file_exists() guards against: the document has
                # no readable source yet, so there is nothing to replace.
                size = 0
            # The character count is used for both line and column.
            # NOTE(review): this looks like a deliberate over-estimate that
            # relies on the client clamping out-of-range positions - confirm.
            end = Position(line=size, character=size)
        self.changes[path.as_uri()] = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8"))

    def save_files(self, files: set[Path] | None = None) -> None:
        """Delegate saving to the wrapped ``FileIO``."""
        self.base_io.save_files(files)

    def check_changes(self) -> None:
        """Delegate the change check to the wrapped ``FileIO``."""
        self.base_io.check_changes()

    def delete_file(self, path: Path) -> None:
        """Delegate deletion to the wrapped ``FileIO``."""
        self.base_io.delete_file(path)

    def file_exists(self, path: Path) -> bool:
        """Return whether ``path`` is readable through the workspace or exists per ``FileIO``."""
        if doc := self._get_doc(path):
            try:
                doc.source
            except FileNotFoundError:
                return False
            return True
        return self.base_io.file_exists(path)

    def untrack_file(self, path: Path) -> None:
        """Delegate untracking to the wrapped ``FileIO``."""
        self.base_io.untrack_file(path)

    def get_document_changes(self) -> list[types.TextDocumentEdit]:
        """Return all pending edits as ``TextDocumentEdit``s and clear them."""
        document_edits = [
            types.TextDocumentEdit(
                text_document=types.OptionalVersionedTextDocumentIdentifier(uri=uri),
                edits=[change],
            )
            for uri, change in self.changes.items()
        ]
        self.changes = {}
        return document_edits
31 changes: 31 additions & 0 deletions src/codegen/extensions/lsp/kind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from lsprotocol.types import SymbolKind

from codegen.sdk.core.assignment import Assignment
from codegen.sdk.core.class_definition import Class
from codegen.sdk.core.file import File
from codegen.sdk.core.function import Function
from codegen.sdk.core.interface import Interface
from codegen.sdk.core.interfaces.editable import Editable
from codegen.sdk.core.statements.attribute import Attribute
from codegen.sdk.typescript.namespace import TSNamespace

# Maps SDK node classes to the LSP SymbolKind reported for them.
# get_kind() walks this mapping in insertion order and returns the first
# isinstance() match, so entry order matters if any of these classes
# subclass one another.
kinds = {
    File: SymbolKind.File,
    Class: SymbolKind.Class,
    Function: SymbolKind.Function,
    Assignment: SymbolKind.Variable,
    Interface: SymbolKind.Interface,
    TSNamespace: SymbolKind.Namespace,
    Attribute: SymbolKind.Variable,
}


def get_kind(node: Editable) -> SymbolKind:
    """Return the LSP ``SymbolKind`` for ``node``.

    Functions flagged as methods map to ``SymbolKind.Method``; everything else
    takes the first ``isinstance`` match in ``kinds``.

    Raises:
        ValueError: If no entry in ``kinds`` matches ``node``.
    """
    if isinstance(node, Function) and node.is_method:
        return SymbolKind.Method
    for node_type, symbol_kind in kinds.items():
        if isinstance(node, node_type):
            return symbol_kind
    msg = f"No kind found for {node}, {type(node)}"
    raise ValueError(msg)
109 changes: 109 additions & 0 deletions src/codegen/extensions/lsp/lsp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import logging

from lsprotocol import types

import codegen
from codegen.extensions.lsp.definition import go_to_definition
from codegen.extensions.lsp.document_symbol import get_document_symbol
from codegen.extensions.lsp.protocol import CodegenLanguageServerProtocol
from codegen.extensions.lsp.range import get_range
from codegen.extensions.lsp.server import CodegenLanguageServer
from codegen.extensions.lsp.utils import get_path
from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite
from codegen.sdk.core.file import SourceFile

# Use the package version when the codegen package exposes one, else "v0.1".
version = getattr(codegen, "__version__", "v0.1")
# Module-level server instance; the handlers below register on it via @server.feature.
server = CodegenLanguageServer("codegen", version, protocol_cls=CodegenLanguageServerProtocol)
logger = logging.getLogger(__name__)


@server.feature(types.TEXT_DOCUMENT_DID_OPEN)
def did_open(server: CodegenLanguageServer, params: types.DidOpenTextDocumentParams) -> None:
    """Register a newly opened document with the codebase when it is not tracked yet.

    A document is added only if the codebase does not already hold it as a
    ``SourceFile`` and its suffix is among the codebase's extensions.
    """
    logger.info(f"Document opened: {params.text_document.uri}")
    # pygls has already added the document to the workspace at this point.
    path = get_path(params.text_document.uri)
    existing = server.codebase.get_file(str(path), optional=True)
    if isinstance(existing, SourceFile):
        return
    if path.suffix not in server.codebase.ctx.extensions:
        return
    added = DiffLite(change_type=ChangeType.Added, path=path)
    server.codebase.ctx.apply_diffs([added])


@server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
def did_change(server: CodegenLanguageServer, params: types.DidChangeTextDocumentParams) -> None:
    """Tell the codebase that a document was modified."""
    document_uri = params.text_document.uri
    logger.info(f"Document changed: {document_uri}")
    # pygls has already updated the workspace copy; only the codebase needs syncing.
    modified = DiffLite(change_type=ChangeType.Modified, path=get_path(document_uri))
    server.codebase.ctx.apply_diffs([modified])


@server.feature(types.WORKSPACE_TEXT_DOCUMENT_CONTENT)
def workspace_text_document_content(server: CodegenLanguageServer, params: types.TextDocumentContentParams) -> types.TextDocumentContentResult:
    """Return the text of the requested document, or empty text when it does not exist."""
    logger.debug(f"Workspace text document content: {params.uri}")
    path = get_path(params.uri)
    if server.io.file_exists(path):
        return types.TextDocumentContentResult(text=server.io.read_text(path))
    logger.warning(f"File does not exist: {path}")
    return types.TextDocumentContentResult(text="")


@server.feature(types.TEXT_DOCUMENT_DID_CLOSE)
def did_close(server: CodegenLanguageServer, params: types.DidCloseTextDocumentParams) -> None:
    """Log that a document was closed; no further cleanup is performed here."""
    closed_uri = params.text_document.uri
    # pygls removes the document from the workspace on its own.
    logger.info(f"Document closed: {closed_uri}")


@server.feature(types.TEXT_DOCUMENT_RENAME)
def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.RenameResult:
    """Rename the symbol at the requested position and return the resulting edits.

    Returns ``None`` when no symbol is found at the position; otherwise the
    codebase is committed and the document changes collected by the server's
    IO layer are returned as a ``WorkspaceEdit``.
    """
    uri = params.text_document.uri
    position = params.position
    symbol = server.get_symbol(uri, position)
    if symbol is None:
        logger.warning(f"No symbol found at {uri}:{position}")
        return None
    logger.info(f"Renaming symbol {symbol.name} to {params.new_name}")
    symbol.rename(params.new_name)
    server.codebase.commit()
    edits = server.io.get_document_changes()
    return types.WorkspaceEdit(document_changes=edits)


@server.feature(types.TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def document_symbol(server: CodegenLanguageServer, params: types.DocumentSymbolParams) -> types.DocumentSymbolResult:
    """Return one ``DocumentSymbol`` per symbol of the requested file."""
    source_file = server.get_file(params.text_document.uri)
    return [get_document_symbol(symbol) for symbol in source_file.symbols]


@server.feature(
    types.TEXT_DOCUMENT_DEFINITION,
)
def definition(server: CodegenLanguageServer, params: types.DefinitionParams):
    """Return the location of the definition of the node under the cursor.

    Returns ``None`` when ``go_to_definition`` cannot resolve anything, so an
    unresolved lookup yields an empty result instead of an ``AttributeError``.
    """
    node = server.get_node_under_cursor(params.text_document.uri, params.position)
    resolved = go_to_definition(node, params.text_document.uri, params.position)
    if resolved is None:
        # go_to_definition has already logged why nothing was resolved.
        return None
    return types.Location(
        uri=resolved.file.path.as_uri(),
        range=get_range(resolved),
    )


# Script entry point: enable INFO logging, then start the server via start_io().
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    server.start_io()
51 changes: 51 additions & 0 deletions src/codegen/extensions/lsp/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import threading
from pathlib import Path
from typing import TYPE_CHECKING

from lsprotocol.types import INITIALIZE, INITIALIZED, InitializedParams, InitializeParams, InitializeResult
from pygls.protocol import LanguageServerProtocol, lsp_method

from codegen.extensions.lsp.io import LSPIO
from codegen.extensions.lsp.utils import get_path
from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags
from codegen.sdk.core.codebase import Codebase

if TYPE_CHECKING:
from codegen.extensions.lsp.server import CodegenLanguageServer


class CodegenLanguageServerProtocol(LanguageServerProtocol):
    """LSP protocol that builds the Codegen ``Codebase`` while the handshake completes.

    ``initialize`` starts codebase construction on a worker thread and returns
    immediately; ``initialized`` waits for that thread before continuing.
    """

    _server: "CodegenLanguageServer"

    def _init_codebase(self, params: InitializeParams) -> None:
        """Create the codebase and its LSP-backed IO and attach both to the server.

        The repository root is taken from ``root_path``, then ``root_uri``,
        and falls back to the current working directory.
        """
        if params.root_path:
            root = Path(params.root_path)
        elif params.root_uri:
            root = get_path(params.root_uri)
        else:
            root = os.getcwd()
        config = CodebaseConfig(feature_flags=GSFeatureFlags(full_range_index=True))
        io = LSPIO(self.workspace)
        self._server.codebase = Codebase(repo_path=str(root), config=config, io=io)
        self._server.io = io

    @lsp_method(INITIALIZE)
    def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
        """Run the base ``initialize`` handling, then start building the codebase."""
        ret = super().lsp_initialize(params)
        # The worker starts only after the base handler has run, because
        # _init_codebase reads self.workspace.
        # NOTE(review): presumably the base handler is what sets up the workspace - confirm.
        self._worker = threading.Thread(target=self._init_codebase, args=(params,))
        self._worker.start()
        return ret

    @lsp_method(INITIALIZED)
    def lsp_initialized(self, params: InitializedParams) -> None:
        """Wait for the codebase worker to finish, then run the base ``initialized`` handling."""
        self._worker.join()
        super().lsp_initialized(params)
32 changes: 32 additions & 0 deletions src/codegen/extensions/lsp/range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import tree_sitter
from lsprotocol.types import Position, Range
from pygls.workspace import TextDocument

from codegen.sdk.core.interfaces.editable import Editable


def get_range(node: Editable) -> Range:
    """Return the LSP ``Range`` covering ``node`` and its extended nodes.

    The start is the point with the smallest row and the end is the point with
    the largest row among the node and its ``extended_nodes``. Only rows are
    compared; on equal rows the earliest candidate (the node itself first) wins.
    """
    extended = list(node.extended_nodes)
    # min()/max() return the first extremal element, which matches a running
    # strict "<" / ">" comparison that starts from the node's own points.
    first = min([node.start_point, *(extra.start_point for extra in extended)], key=lambda point: point.row)
    last = max([node.end_point, *(extra.end_point for extra in extended)], key=lambda point: point.row)
    return Range(
        start=Position(line=first.row, character=first.column),
        end=Position(line=last.row, character=last.column),
    )


def get_tree_sitter_range(range: Range, document: TextDocument) -> tree_sitter.Range:
    """Convert an LSP ``Range`` into a ``tree_sitter.Range`` for ``document``.

    Points are copied from the LSP line/character pairs; the byte fields are
    filled from ``document.offset_at_position``.
    NOTE(review): whether those offsets are byte or character offsets cannot be
    told from here - confirm for non-ASCII text.
    """
    begin, finish = range.start, range.end
    return tree_sitter.Range(
        start_point=tree_sitter.Point(row=begin.line, column=begin.character),
        end_point=tree_sitter.Point(row=finish.line, column=finish.character),
        start_byte=document.offset_at_position(begin),
        end_byte=document.offset_at_position(finish),
    )
Loading