22
22
from codegen .shared .decorators .docs import apidoc , noapidoc
23
23
24
24
if TYPE_CHECKING :
25
- from collections .abc import Callable , Generator , Iterable
25
+ from collections .abc import Callable , Generator , Iterable , Sequence
26
26
27
27
import rich .repr
28
28
from rich .console import Console , ConsoleOptions , RenderResult
@@ -157,7 +157,7 @@ def __repr__(self) -> str:
157
157
def __rich_repr__ (self ) -> rich .repr .Result :
158
158
yield escape (self .filepath )
159
159
160
- __rich_repr__ .angular = ANGULAR_STYLE
160
+ __rich_repr__ .angular = ANGULAR_STYLE # type: ignore
161
161
162
162
def __rich_console__ (self , console : Console , options : ConsoleOptions ) -> RenderResult :
163
163
yield Pretty (self , max_string = MAX_STRING_LENGTH )
@@ -315,14 +315,14 @@ def extended_source(self, value: str) -> None:
315
315
@property
316
316
@reader
317
317
@noapidoc
318
- def children (self ) -> list [Editable ]:
318
+ def children (self ) -> list [Editable [ Self ] ]:
319
319
"""List of Editable instances that are children of this node."""
320
320
return [self ._parse_expression (child ) for child in self .ts_node .named_children ]
321
321
322
322
@property
323
323
@reader
324
324
@noapidoc
325
- def _anonymous_children (self ) -> list [Editable ]:
325
+ def _anonymous_children (self ) -> list [Editable [ Self ] ]:
326
326
"""All anonymous children of an editable."""
327
327
return [self ._parse_expression (child ) for child in self .ts_node .children if not child .is_named ]
328
328
@@ -343,28 +343,28 @@ def next_sibling(self) -> Editable | None:
343
343
@property
344
344
@reader
345
345
@noapidoc
346
- def next_named_sibling (self ) -> Editable | None :
346
+ def next_named_sibling (self ) -> Editable [ Parent ] | None :
347
347
if self .ts_node is None :
348
348
return None
349
349
350
350
next_named_sibling_node = self .ts_node .next_named_sibling
351
351
if next_named_sibling_node is None :
352
352
return None
353
353
354
- return self ._parse_expression (next_named_sibling_node )
354
+ return self .parent . _parse_expression (next_named_sibling_node )
355
355
356
356
@property
357
357
@reader
358
358
@noapidoc
359
- def previous_named_sibling (self ) -> Editable | None :
359
+ def previous_named_sibling (self ) -> Editable [ Parent ] | None :
360
360
if self .ts_node is None :
361
361
return None
362
362
363
363
previous_named_sibling_node = self .ts_node .prev_named_sibling
364
364
if previous_named_sibling_node is None :
365
365
return None
366
366
367
- return self ._parse_expression (previous_named_sibling_node )
367
+ return self .parent . _parse_expression (previous_named_sibling_node )
368
368
369
369
@property
370
370
def file (self ) -> SourceFile :
@@ -377,7 +377,7 @@ def file(self) -> SourceFile:
377
377
"""
378
378
if self ._file is None :
379
379
self ._file = self .G .get_node (self .file_node_id )
380
- return self ._file
380
+ return self ._file # type: ignore
381
381
382
382
@property
383
383
def filepath (self ) -> str :
@@ -391,7 +391,7 @@ def filepath(self) -> str:
391
391
return self .file .file_path
392
392
393
393
@reader
394
- def find_string_literals (self , strings_to_match : list [str ], fuzzy_match : bool = False ) -> list [Editable ]:
394
+ def find_string_literals (self , strings_to_match : list [str ], fuzzy_match : bool = False ) -> list [Editable [ Self ] ]:
395
395
"""Returns a list of string literals within this node's source that match any of the given
396
396
strings.
397
397
@@ -400,19 +400,20 @@ def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool =
400
400
fuzzy_match (bool): If True, matches substrings within string literals. If False, only matches exact strings. Defaults to False.
401
401
402
402
Returns:
403
- list[Editable]: A list of Editable objects representing the matching string literals.
403
+ list[Editable[Self] ]: A list of Editable objects representing the matching string literals.
404
404
"""
405
- matches = []
405
+ matches : list [ Editable [ Self ]] = []
406
406
for node in self .extended_nodes :
407
407
matches .extend (node ._find_string_literals (strings_to_match , fuzzy_match ))
408
408
return matches
409
409
410
410
@noapidoc
411
411
@reader
412
- def _find_string_literals (self , strings_to_match : list [str ], fuzzy_match : bool = False ) -> list [Editable ]:
412
+ def _find_string_literals (self , strings_to_match : list [str ], fuzzy_match : bool = False ) -> Sequence [Editable [ Self ] ]:
413
413
all_string_nodes = find_all_descendants (self .ts_node , type_names = {"string" })
414
414
editables = []
415
415
for string_node in all_string_nodes :
416
+ assert string_node .text is not None
416
417
full_string = string_node .text .strip (b'"' ).strip (b"'" )
417
418
if fuzzy_match :
418
419
if not any ([str_to_match .encode ("utf-8" ) in full_string for str_to_match in strings_to_match ]):
@@ -461,7 +462,7 @@ def _replace(self, old: str, new: str, count: int = -1, is_regex: bool = False,
461
462
if not is_regex :
462
463
old = re .escape (old )
463
464
464
- for match in re .finditer (old .encode ("utf-8" ), self .ts_node .text ):
465
+ for match in re .finditer (old .encode ("utf-8" ), self .ts_node .text ): # type: ignore
465
466
start_byte = self .ts_node .start_byte + match .start ()
466
467
end_byte = self .ts_node .start_byte + match .end ()
467
468
t = EditTransaction (
@@ -538,7 +539,7 @@ def _search(self, regex_pattern: str, include_strings: bool = True, include_comm
538
539
539
540
pattern = re .compile (regex_pattern .encode ("utf-8" ))
540
541
start_byte_offset = self .ts_node .byte_range [0 ]
541
- for match in pattern .finditer (string ):
542
+ for match in pattern .finditer (string ): # type: ignore
542
543
matching_byte_ranges .append ((match .start () + start_byte_offset , match .end () + start_byte_offset ))
543
544
544
545
matches : list [Editable ] = []
@@ -738,7 +739,7 @@ def should_keep(node: TSNode):
738
739
# Delete the node
739
740
t = RemoveTransaction (removed_start_byte , removed_end_byte , self .file , priority = priority , exec_func = exec_func )
740
741
if self .transaction_manager .add_transaction (t , dedupe = dedupe ):
741
- if exec_func :
742
+ if exec_func is not None :
742
743
self .parent ._removed_child ()
743
744
744
745
# If there are sibling nodes, delete the surrounding whitespace & formatting (commas)
@@ -873,11 +874,13 @@ def variable_usages(self) -> list[Editable]:
873
874
Editable corresponds to a TreeSitter node instance where the variable
874
875
is referenced.
875
876
"""
876
- usages = []
877
+ usages : Sequence [ Editable [ Self ]] = []
877
878
identifiers = get_all_identifiers (self .ts_node )
878
879
for identifier in identifiers :
879
880
# Excludes function names
880
881
parent = identifier .parent
882
+ if parent is None :
883
+ continue
881
884
if parent .type in ["call" , "call_expression" ]:
882
885
continue
883
886
# Excludes local import statements
@@ -899,7 +902,7 @@ def variable_usages(self) -> list[Editable]:
899
902
return usages
900
903
901
904
@reader
902
- def get_variable_usages (self , var_name : str , fuzzy_match : bool = False ) -> list [Editable ]:
905
+ def get_variable_usages (self , var_name : str , fuzzy_match : bool = False ) -> Sequence [Editable [ Self ] ]:
903
906
"""Returns Editables for all TreeSitter nodes corresponding to instances of variable usage
904
907
that matches the given variable name.
905
908
@@ -917,6 +920,12 @@ def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> list[
917
920
else :
918
921
return [usage for usage in self .variable_usages if var_name == usage .source ]
919
922
923
+ @overload
924
+ def _parse_expression (self , node : TSNode , ** kwargs ) -> Expression [Self ]: ...
925
+
926
+ @overload
927
+ def _parse_expression (self , node : TSNode | None , ** kwargs ) -> Expression [Self ] | None : ...
928
+
920
929
def _parse_expression (self , node : TSNode | None , ** kwargs ) -> Expression [Self ] | None :
921
930
return self .G .parser .parse_expression (node , self .file_node_id , self .G , self , ** kwargs )
922
931
0 commit comments