Skip to content

Commit bcaba53

Browse files
authored
LSP progress reporting (#456)
1 parent 7834af8 commit bcaba53

File tree

12 files changed

+254
-30
lines changed

12 files changed

+254
-30
lines changed

src/codegen/extensions/lsp/lsp.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,35 +71,46 @@ def did_close(server: CodegenLanguageServer, params: types.DidCloseTextDocumentP
7171

7272
@server.feature(
7373
types.TEXT_DOCUMENT_RENAME,
74+
options=types.RenameOptions(work_done_progress=True),
7475
)
7576
def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.RenameResult:
7677
symbol = server.get_symbol(params.text_document.uri, params.position)
7778
if symbol is None:
7879
logger.warning(f"No symbol found at {params.text_document.uri}:{params.position}")
7980
return
8081
logger.info(f"Renaming symbol {symbol.name} to {params.new_name}")
82+
task = server.progress_manager.begin_with_token(f"Renaming symbol {symbol.name} to {params.new_name}", params.work_done_token)
8183
symbol.rename(params.new_name)
84+
task.update("Committing changes")
8285
server.codebase.commit()
86+
task.end()
8387
return server.io.get_workspace_edit()
8488

8589

8690
@server.feature(
8791
types.TEXT_DOCUMENT_DOCUMENT_SYMBOL,
92+
options=types.DocumentSymbolOptions(work_done_progress=True),
8893
)
8994
def document_symbol(server: CodegenLanguageServer, params: types.DocumentSymbolParams) -> types.DocumentSymbolResult:
9095
file = server.get_file(params.text_document.uri)
9196
symbols = []
92-
for symbol in file.symbols:
97+
task = server.progress_manager.begin_with_token(f"Getting document symbols for {params.text_document.uri}", params.work_done_token, count=len(file.symbols))
98+
for idx, symbol in enumerate(file.symbols):
99+
task.update(f"Getting document symbols for {params.text_document.uri}", count=idx)
93100
symbols.append(get_document_symbol(symbol))
101+
task.end()
94102
return symbols
95103

96104

97105
@server.feature(
98106
types.TEXT_DOCUMENT_DEFINITION,
107+
options=types.DefinitionOptions(work_done_progress=True),
99108
)
100109
def definition(server: CodegenLanguageServer, params: types.DefinitionParams):
101110
node = server.get_node_under_cursor(params.text_document.uri, params.position)
111+
task = server.progress_manager.begin_with_token(f"Getting definition for {params.text_document.uri}", params.work_done_token)
102112
resolved = go_to_definition(node, params.text_document.uri, params.position)
113+
task.end()
103114
return types.Location(
104115
uri=resolved.file.path.as_uri(),
105116
range=get_range(resolved),
@@ -108,15 +119,11 @@ def definition(server: CodegenLanguageServer, params: types.DefinitionParams):
108119

109120
@server.feature(
110121
types.TEXT_DOCUMENT_CODE_ACTION,
111-
options=types.CodeActionOptions(resolve_provider=True),
122+
options=types.CodeActionOptions(resolve_provider=True, work_done_progress=True),
112123
)
113124
def code_action(server: CodegenLanguageServer, params: types.CodeActionParams) -> types.CodeActionResult:
114125
logger.info(f"Received code action: {params}")
115-
if params.context.only:
116-
only = [types.CodeActionKind(kind) for kind in params.context.only]
117-
else:
118-
only = None
119-
actions = server.get_actions_for_range(params.text_document.uri, params.range, only)
126+
actions = server.get_actions_for_range(params)
120127
return actions
121128

122129

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import uuid
2+
3+
from lsprotocol import types
4+
from lsprotocol.types import ProgressToken
5+
from pygls.lsp.server import LanguageServer
6+
7+
from codegen.sdk.codebase.progress.progress import Progress
8+
from codegen.sdk.codebase.progress.stub_task import StubTask
9+
from codegen.sdk.codebase.progress.task import Task
10+
11+
12+
class LSPTask(Task):
13+
count: int | None
14+
15+
def __init__(self, server: LanguageServer, message: str, token: ProgressToken, count: int | None = None, create_token: bool = True) -> None:
16+
self.token = token
17+
if create_token:
18+
server.work_done_progress.begin(self.token, types.WorkDoneProgressBegin(title=message))
19+
self.server = server
20+
self.message = message
21+
self.count = count
22+
self.create_token = create_token
23+
24+
def update(self, message: str, count: int | None = None) -> None:
25+
if self.count is not None and count is not None:
26+
percent = int(count * 100 / self.count)
27+
else:
28+
percent = None
29+
self.server.work_done_progress.report(self.token, types.WorkDoneProgressReport(message=message, percentage=percent))
30+
31+
def end(self) -> None:
32+
if self.create_token:
33+
self.server.work_done_progress.end(self.token, value=types.WorkDoneProgressEnd())
34+
35+
36+
class LSPProgress(Progress[LSPTask | StubTask]):
37+
initialized = False
38+
39+
def __init__(self, server: LanguageServer, initial_token: ProgressToken | None = None):
40+
self.server = server
41+
self.initial_token = initial_token
42+
if initial_token is not None:
43+
self.server.work_done_progress.begin(initial_token, types.WorkDoneProgressBegin(title="Parsing codebase..."))
44+
45+
def begin_with_token(self, message: str, token: ProgressToken | None = None, *, count: int | None = None, create_token: bool = True) -> LSPTask | StubTask:
46+
if token is None:
47+
return StubTask()
48+
return LSPTask(self.server, message, token, count, create_token=create_token)
49+
50+
def begin(self, message: str, count: int | None = None) -> LSPTask | StubTask:
51+
if self.initialized:
52+
token = str(uuid.uuid4())
53+
self.server.work_done_progress.create(token).result()
54+
return LSPTask(self.server, message, token, count, create_token=False)
55+
return self.begin_with_token(message, self.initial_token, count=None, create_token=False)
56+
57+
def finish_initialization(self) -> None:
58+
self.initialized = False # We can't initiate server work during syncs
59+
if self.initial_token is not None:
60+
self.server.work_done_progress.end(self.initial_token, value=types.WorkDoneProgressEnd())

src/codegen/extensions/lsp/protocol.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from pathlib import Path
33
from typing import TYPE_CHECKING
44

5-
from lsprotocol.types import INITIALIZE, InitializeParams, InitializeResult, WorkDoneProgressBegin, WorkDoneProgressEnd
5+
from lsprotocol.types import INITIALIZE, InitializeParams, InitializeResult
66
from pygls.protocol import LanguageServerProtocol, lsp_method
77

88
from codegen.extensions.lsp.io import LSPIO
9+
from codegen.extensions.lsp.progress import LSPProgress
910
from codegen.extensions.lsp.utils import get_path
1011
from codegen.sdk.codebase.config import CodebaseConfig
1112
from codegen.sdk.core.codebase import Codebase
@@ -19,6 +20,7 @@ class CodegenLanguageServerProtocol(LanguageServerProtocol):
1920
_server: "CodegenLanguageServer"
2021

2122
def _init_codebase(self, params: InitializeParams) -> None:
23+
progress = LSPProgress(self._server, params.work_done_token)
2224
if params.root_path:
2325
root = Path(params.root_path)
2426
elif params.root_uri:
@@ -27,15 +29,13 @@ def _init_codebase(self, params: InitializeParams) -> None:
2729
root = os.getcwd()
2830
config = CodebaseConfig(feature_flags=CodebaseFeatureFlags(full_range_index=True))
2931
io = LSPIO(self.workspace)
30-
self._server.codebase = Codebase(repo_path=str(root), config=config, io=io)
32+
self._server.codebase = Codebase(repo_path=str(root), config=config, io=io, progress=progress)
33+
self._server.progress_manager = progress
3134
self._server.io = io
32-
if params.work_done_token:
33-
self._server.work_done_progress.end(params.work_done_token, WorkDoneProgressEnd(message="Parsing codebase..."))
35+
progress.finish_initialization()
3436

3537
@lsp_method(INITIALIZE)
3638
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
3739
ret = super().lsp_initialize(params)
38-
if params.work_done_token:
39-
self._server.work_done_progress.begin(params.work_done_token, WorkDoneProgressBegin(title="Parsing codebase..."))
4040
self._init_codebase(params)
4141
return ret

src/codegen/extensions/lsp/server.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
from collections.abc import Sequence
32
from typing import Any, Optional
43

54
from lsprotocol import types
@@ -8,8 +7,9 @@
87

98
from codegen.extensions.lsp.codemods import ACTIONS
109
from codegen.extensions.lsp.codemods.base import CodeAction
11-
from codegen.extensions.lsp.execute import execute_action, get_execute_action
10+
from codegen.extensions.lsp.execute import execute_action
1211
from codegen.extensions.lsp.io import LSPIO
12+
from codegen.extensions.lsp.progress import LSPProgress
1313
from codegen.extensions.lsp.range import get_tree_sitter_range
1414
from codegen.extensions.lsp.utils import get_path
1515
from codegen.sdk.core.codebase import Codebase
@@ -23,13 +23,14 @@
2323
class CodegenLanguageServer(LanguageServer):
2424
codebase: Optional[Codebase]
2525
io: Optional[LSPIO]
26+
progress_manager: Optional[LSPProgress]
2627
actions: dict[str, CodeAction]
2728

2829
def __init__(self, *args: Any, **kwargs: Any) -> None:
2930
super().__init__(*args, **kwargs)
3031
self.actions = {action.command_name(): action for action in ACTIONS}
31-
for action in self.actions.values():
32-
self.command(action.command_name())(get_execute_action(action))
32+
# for action in self.actions.values():
33+
# self.command(action.command_name())(get_execute_action(action))
3334

3435
def get_file(self, uri: str) -> SourceFile | File:
3536
path = get_path(uri)
@@ -68,19 +69,25 @@ def get_node_for_range(self, uri: str, range: Range) -> Editable | None:
6869
return node
6970
return None
7071

71-
def get_actions_for_range(self, uri: str, range: Range, only: Sequence[types.CodeActionKind] | None = None) -> list[types.CodeAction]:
72-
node = self.get_node_under_cursor(uri, range.start, range.end)
72+
def get_actions_for_range(self, params: types.CodeActionParams) -> list[types.CodeAction]:
73+
if params.context.only is not None:
74+
only = [types.CodeActionKind(kind) for kind in params.context.only]
75+
else:
76+
only = None
77+
node = self.get_node_under_cursor(params.text_document.uri, params.range.start)
7378
if node is None:
74-
logger.warning(f"No node found for range {range} in {uri}")
79+
logger.warning(f"No node found for range {params.range} in {params.text_document.uri}")
7580
return []
7681
actions = []
77-
for action in self.actions.values():
82+
task = self.progress_manager.begin_with_token(f"Getting code actions for {params.text_document.uri}", params.work_done_token, count=len(self.actions))
83+
for idx, action in enumerate(self.actions.values()):
84+
task.update(f"Checking action {action.name}", idx)
7885
if only and action.kind not in only:
7986
logger.warning(f"Skipping action {action.kind} because it is not in {only}")
8087
continue
8188
if action.is_applicable(self, node):
82-
actions.append(action.to_lsp(uri, range))
83-
89+
actions.append(action.to_lsp(params.text_document.uri, params.range))
90+
task.end()
8491
return actions
8592

8693
def resolve_action(self, action: types.CodeAction) -> types.CodeAction:

src/codegen/sdk/codebase/codebase_context.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite
1717
from codegen.sdk.codebase.flagging.flags import Flags
1818
from codegen.sdk.codebase.io.file_io import FileIO
19+
from codegen.sdk.codebase.progress.stub_progress import StubProgress
1920
from codegen.sdk.codebase.transaction_manager import TransactionManager
2021
from codegen.sdk.codebase.validation import get_edges, post_reset_validation
2122
from codegen.sdk.core.autocommit import AutoCommit, commiter
@@ -39,6 +40,7 @@
3940
from codegen.git.repo_operator.repo_operator import RepoOperator
4041
from codegen.sdk.codebase.io.io import IO
4142
from codegen.sdk.codebase.node_classes.node_classes import NodeClasses
43+
from codegen.sdk.codebase.progress.progress import Progress
4244
from codegen.sdk.core.dataclasses.usage import Usage
4345
from codegen.sdk.core.expressions import Expression
4446
from codegen.sdk.core.external_module import ExternalModule
@@ -111,16 +113,19 @@ class CodebaseContext:
111113
projects: list[ProjectConfig]
112114
unapplied_diffs: list[DiffLite]
113115
io: IO
116+
progress: Progress
114117

115118
def __init__(
116119
self,
117120
projects: list[ProjectConfig],
118121
config: CodebaseConfig = DefaultConfig,
119122
io: IO | None = None,
123+
progress: Progress | None = None,
120124
) -> None:
121125
"""Initializes codebase graph and TransactionManager"""
122126
from codegen.sdk.core.parser import Parser
123127

128+
self.progress = progress or StubProgress()
124129
self._graph = PyDiGraph()
125130
self.filepath_idx = {}
126131
self._ext_module_idx = {}
@@ -371,7 +376,6 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr
371376
skip_uncache = incremental and ((len(files_to_sync[SyncType.DELETE]) + len(files_to_sync[SyncType.REPARSE])) == 0)
372377
if not skip_uncache:
373378
uncache_all()
374-
375379
# Step 0: Start the dependency manager and language engine if they exist
376380
# Start the dependency manager. This may or may not run asynchronously, depending on the implementation
377381
if self.dependency_manager is not None:
@@ -429,24 +433,29 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr
429433
file = self.get_file(file_path)
430434
file.remove_internal_edges()
431435

436+
task = self.progress.begin("Reparsing updated files", count=len(files_to_sync[SyncType.REPARSE]))
432437
files_to_resolve = []
433438
# Step 4: Reparse updated files
434-
for file_path in files_to_sync[SyncType.REPARSE]:
439+
for idx, file_path in enumerate(files_to_sync[SyncType.REPARSE]):
440+
task.update(f"Reparsing {self.to_relative(file_path)}", count=idx)
435441
file = self.get_file(file_path)
436442
to_resolve.extend(file.unparse(reparse=True))
437443
to_resolve = list(filter(lambda node: self.has_node(node.node_id) and node is not None, to_resolve))
438444
file.sync_with_file_content()
439445
files_to_resolve.append(file)
440-
446+
task.end()
441447
# Step 5: Add new files as nodes to graph (does not yet add edges)
442-
for filepath in files_to_sync[SyncType.ADD]:
448+
task = self.progress.begin("Adding new files", count=len(files_to_sync[SyncType.ADD]))
449+
for idx, filepath in enumerate(files_to_sync[SyncType.ADD]):
450+
task.update(f"Adding {self.to_relative(filepath)}", count=idx)
443451
content = self.io.read_text(filepath)
444452
# TODO: this is wrong with context changes
445453
if filepath.suffix in self.extensions:
446454
file_cls = self.node_classes.file_cls
447455
new_file = file_cls.from_content(filepath, content, self, sync=False, verify_syntax=False)
448456
if new_file is not None:
449457
files_to_resolve.append(new_file)
458+
task.end()
450459
for file in files_to_resolve:
451460
to_resolve.append(file)
452461
to_resolve.extend(file.get_nodes())
@@ -474,27 +483,35 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr
474483
self._computing = True
475484
try:
476485
logger.info(f"> Computing import resolution edges for {counter[NodeType.IMPORT]} imports")
486+
task = self.progress.begin("Resolving imports", count=counter[NodeType.IMPORT])
477487
for node in to_resolve:
478488
if node.node_type == NodeType.IMPORT:
489+
task.update(f"Resolving imports in {node.filepath}", count=idx)
479490
node._remove_internal_edges(EdgeType.IMPORT_SYMBOL_RESOLUTION)
480491
node.add_symbol_resolution_edge()
481492
to_resolve.extend(node.symbol_usages)
493+
task.end()
482494
if counter[NodeType.EXPORT] > 0:
483495
logger.info(f"> Computing export dependencies for {counter[NodeType.EXPORT]} exports")
496+
task = self.progress.begin("Computing export dependencies", count=counter[NodeType.EXPORT])
484497
for node in to_resolve:
485498
if node.node_type == NodeType.EXPORT:
499+
task.update(f"Computing export dependencies for {node.filepath}", count=idx)
486500
node._remove_internal_edges(EdgeType.EXPORT)
487501
node.compute_export_dependencies()
488502
to_resolve.extend(node.symbol_usages)
503+
task.end()
489504
if counter[NodeType.SYMBOL] > 0:
490505
from codegen.sdk.core.interfaces.inherits import Inherits
491506

492507
logger.info("> Computing superclass dependencies")
508+
task = self.progress.begin("Computing superclass dependencies", count=counter[NodeType.SYMBOL])
493509
for symbol in to_resolve:
494510
if isinstance(symbol, Inherits):
511+
task.update(f"Computing superclass dependencies for {symbol.filepath}", count=idx)
495512
symbol._remove_internal_edges(EdgeType.SUBCLASS)
496513
symbol.compute_superclass_dependencies()
497-
514+
task.end()
498515
if not skip_uncache:
499516
uncache_all()
500517
self._compute_dependencies(to_resolve, incremental)
@@ -504,17 +521,20 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr
504521
def _compute_dependencies(self, to_update: list[Importable], incremental: bool):
505522
seen = set()
506523
while to_update:
524+
task = self.progress.begin("Computing dependencies", count=len(to_update))
507525
step = to_update.copy()
508526
to_update.clear()
509527
logger.info(f"> Incrementally computing dependencies for {len(step)} nodes")
510-
for current in step:
528+
for idx, current in enumerate(step):
529+
task.update(f"Computing dependencies for {current.filepath}", count=idx)
511530
if current not in seen:
512531
seen.add(current)
513532
to_update.extend(current.recompute(incremental))
514533
if not incremental:
515534
for node in self._graph.nodes():
516535
if node not in seen:
517536
to_update.append(node)
537+
task.end()
518538
seen.clear()
519539

520540
def build_subgraph(self, nodes: list[NodeId]) -> PyDiGraph[Importable, Edge]:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from abc import ABC, abstractmethod
2+
from typing import TYPE_CHECKING, Generic, TypeVar
3+
4+
if TYPE_CHECKING:
5+
from codegen.sdk.codebase.progress.task import Task
6+
7+
T = TypeVar("T", bound="Task")
8+
9+
10+
class Progress(ABC, Generic[T]):
11+
@abstractmethod
12+
def begin(self, message: str, count: int | None = None) -> T:
13+
pass
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from codegen.sdk.codebase.progress.progress import Progress
2+
from codegen.sdk.codebase.progress.stub_task import StubTask
3+
4+
5+
class StubProgress(Progress[StubTask]):
6+
def begin(self, message: str, count: int | None = None) -> StubTask:
7+
return StubTask()

0 commit comments

Comments
 (0)