Skip to content

Commit 3e70422

Browse files
authored
Mypyc/cython changes (#318)
1 parent dc31440 commit 3e70422

File tree

11 files changed

+83
-32
lines changed

11 files changed

+83
-32
lines changed

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,12 @@ keywords = [
102102
codegen = "codegen.cli.cli:main"
103103

104104
[project.optional-dependencies]
105-
types = ["types-networkx>=3.2.1.20240918", "types-tabulate>=0.9.0.20240106"]
105+
types = [
106+
"types-networkx>=3.2.1.20240918",
107+
"types-tabulate>=0.9.0.20240106",
108+
"types-requests>=2.32.0.20241016",
109+
"types-toml>=0.10.8.20240310",
110+
]
106111
[tool.uv]
107112
cache-keys = [{ git = { commit = true, tags = true } }]
108113
dev-dependencies = [
@@ -199,6 +204,7 @@ tmp_path_retention_policy = "failed"
199204
requires = ["hatchling>=1.26.3", "hatch-vcs>=0.4.0", "setuptools-scm>=8.0.0"]
200205
build-backend = "hatchling.build"
201206

207+
202208
[tool.deptry]
203209
extend_exclude = [".*/eval/test_files/.*.py", ".*conftest.py"]
204210
pep621_dev_dependency_groups = ["types"]

src/codegen/py.typed

Whitespace-only changes.

src/codegen/sdk/codebase/flagging/code_flag.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from dataclasses import dataclass
2-
from typing import TYPE_CHECKING
2+
from typing import Generic, TypeVar
33

44
from codegen.sdk.codebase.flagging.enums import MessageType
5+
from codegen.sdk.core.interfaces.editable import Editable
56

6-
if TYPE_CHECKING:
7-
from codegen.sdk.core.interfaces.editable import Editable
7+
Symbol = TypeVar("Symbol", bound=Editable | None)
88

99

1010
@dataclass
11-
class CodeFlag[Symbol: Editable | None]:
11+
class CodeFlag(Generic[Symbol]):
1212
symbol: Symbol
1313
message: str | None = None # a short desc of the code flag/violation. ex: enums should be ordered alphabetically
1414
message_type: MessageType = MessageType.GITHUB | MessageType.CODEGEN # where to send the message (either Github or Slack)

src/codegen/sdk/codebase/flagging/flags.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
from dataclasses import dataclass, field
2+
from typing import TypeVar
23

34
from codegen.sdk.codebase.flagging.code_flag import CodeFlag
45
from codegen.sdk.codebase.flagging.enums import MessageType
56
from codegen.sdk.codebase.flagging.group import Group
67
from codegen.sdk.core.interfaces.editable import Editable
78
from codegen.shared.decorators.docs import noapidoc
89

10+
Symbol = TypeVar("Symbol", bound=Editable)
11+
912

1013
@dataclass
1114
class Flags:
1215
_flags: list[CodeFlag] = field(default_factory=list)
1316
_find_mode: bool = False
1417
_active_group: list[CodeFlag] | None = None
1518

16-
def flag_instance[Symbol: Editable | None](
19+
def flag_instance(
1720
self,
18-
symbol: Symbol = None,
21+
symbol: Symbol | None = None,
1922
message: str | None = None,
2023
message_type: MessageType = MessageType.GITHUB | MessageType.CODEGEN,
2124
message_recipient: str | None = None,

src/codegen/sdk/codebase/multigraph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from collections import defaultdict
22
from dataclasses import dataclass, field
3+
from typing import Generic, TypeVar
34

45
from codegen.sdk import TYPE_CHECKING
56
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
67

78
if TYPE_CHECKING:
89
from codegen.sdk.core.function import Function
910

11+
TFunction = TypeVar("TFunction", bound=Function)
12+
1013

1114
@dataclass
12-
class MultiGraph[TFunction: Function]:
15+
class MultiGraph(Generic[TFunction]):
1316
"""Mapping of API endpoints to their definitions and usages across languages."""
1417

1518
api_definitions: dict[str, TFunction] = field(default_factory=dict)

src/codegen/sdk/core/file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def invalidate(self):
587587

588588
@classmethod
589589
@noapidoc
590-
def from_content(cls, filepath: str, content: str, G: CodebaseGraph, sync: bool = True, verify_syntax: bool = True) -> Self | None:
590+
def from_content(cls, filepath: str | PathLike | Path, content: str, G: CodebaseGraph, sync: bool = True, verify_syntax: bool = True) -> Self | None:
591591
"""Creates a new file from content and adds it to the graph."""
592592
path = G.to_absolute(filepath)
593593
ts_node = parse_file(path, content)
@@ -605,7 +605,7 @@ def from_content(cls, filepath: str, content: str, G: CodebaseGraph, sync: bool
605605
G.add_single_file(path)
606606
return G.get_file(filepath)
607607
else:
608-
return cls(ts_node, filepath, G)
608+
return cls(ts_node, Path(filepath), G)
609609

610610
@classmethod
611611
@noapidoc

src/codegen/sdk/core/interfaces/editable.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from codegen.shared.decorators.docs import apidoc, noapidoc
2323

2424
if TYPE_CHECKING:
25-
from collections.abc import Callable, Generator, Iterable
25+
from collections.abc import Callable, Generator, Iterable, Sequence
2626

2727
import rich.repr
2828
from rich.console import Console, ConsoleOptions, RenderResult
@@ -157,7 +157,7 @@ def __repr__(self) -> str:
157157
def __rich_repr__(self) -> rich.repr.Result:
158158
yield escape(self.filepath)
159159

160-
__rich_repr__.angular = ANGULAR_STYLE
160+
__rich_repr__.angular = ANGULAR_STYLE # type: ignore
161161

162162
def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
163163
yield Pretty(self, max_string=MAX_STRING_LENGTH)
@@ -315,14 +315,14 @@ def extended_source(self, value: str) -> None:
315315
@property
316316
@reader
317317
@noapidoc
318-
def children(self) -> list[Editable]:
318+
def children(self) -> list[Editable[Self]]:
319319
"""List of Editable instances that are children of this node."""
320320
return [self._parse_expression(child) for child in self.ts_node.named_children]
321321

322322
@property
323323
@reader
324324
@noapidoc
325-
def _anonymous_children(self) -> list[Editable]:
325+
def _anonymous_children(self) -> list[Editable[Self]]:
326326
"""All anonymous children of an editable."""
327327
return [self._parse_expression(child) for child in self.ts_node.children if not child.is_named]
328328

@@ -343,28 +343,28 @@ def next_sibling(self) -> Editable | None:
343343
@property
344344
@reader
345345
@noapidoc
346-
def next_named_sibling(self) -> Editable | None:
346+
def next_named_sibling(self) -> Editable[Parent] | None:
347347
if self.ts_node is None:
348348
return None
349349

350350
next_named_sibling_node = self.ts_node.next_named_sibling
351351
if next_named_sibling_node is None:
352352
return None
353353

354-
return self._parse_expression(next_named_sibling_node)
354+
return self.parent._parse_expression(next_named_sibling_node)
355355

356356
@property
357357
@reader
358358
@noapidoc
359-
def previous_named_sibling(self) -> Editable | None:
359+
def previous_named_sibling(self) -> Editable[Parent] | None:
360360
if self.ts_node is None:
361361
return None
362362

363363
previous_named_sibling_node = self.ts_node.prev_named_sibling
364364
if previous_named_sibling_node is None:
365365
return None
366366

367-
return self._parse_expression(previous_named_sibling_node)
367+
return self.parent._parse_expression(previous_named_sibling_node)
368368

369369
@property
370370
def file(self) -> SourceFile:
@@ -377,7 +377,7 @@ def file(self) -> SourceFile:
377377
"""
378378
if self._file is None:
379379
self._file = self.G.get_node(self.file_node_id)
380-
return self._file
380+
return self._file # type: ignore
381381

382382
@property
383383
def filepath(self) -> str:
@@ -391,7 +391,7 @@ def filepath(self) -> str:
391391
return self.file.file_path
392392

393393
@reader
394-
def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable]:
394+
def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable[Self]]:
395395
"""Returns a list of string literals within this node's source that match any of the given
396396
strings.
397397
@@ -400,19 +400,20 @@ def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool =
400400
fuzzy_match (bool): If True, matches substrings within string literals. If False, only matches exact strings. Defaults to False.
401401
402402
Returns:
403-
list[Editable]: A list of Editable objects representing the matching string literals.
403+
list[Editable[Self]]: A list of Editable objects representing the matching string literals.
404404
"""
405-
matches = []
405+
matches: list[Editable[Self]] = []
406406
for node in self.extended_nodes:
407407
matches.extend(node._find_string_literals(strings_to_match, fuzzy_match))
408408
return matches
409409

410410
@noapidoc
411411
@reader
412-
def _find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable]:
412+
def _find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> Sequence[Editable[Self]]:
413413
all_string_nodes = find_all_descendants(self.ts_node, type_names={"string"})
414414
editables = []
415415
for string_node in all_string_nodes:
416+
assert string_node.text is not None
416417
full_string = string_node.text.strip(b'"').strip(b"'")
417418
if fuzzy_match:
418419
if not any([str_to_match.encode("utf-8") in full_string for str_to_match in strings_to_match]):
@@ -461,7 +462,7 @@ def _replace(self, old: str, new: str, count: int = -1, is_regex: bool = False,
461462
if not is_regex:
462463
old = re.escape(old)
463464

464-
for match in re.finditer(old.encode("utf-8"), self.ts_node.text):
465+
for match in re.finditer(old.encode("utf-8"), self.ts_node.text): # type: ignore
465466
start_byte = self.ts_node.start_byte + match.start()
466467
end_byte = self.ts_node.start_byte + match.end()
467468
t = EditTransaction(
@@ -538,7 +539,7 @@ def _search(self, regex_pattern: str, include_strings: bool = True, include_comm
538539

539540
pattern = re.compile(regex_pattern.encode("utf-8"))
540541
start_byte_offset = self.ts_node.byte_range[0]
541-
for match in pattern.finditer(string):
542+
for match in pattern.finditer(string): # type: ignore
542543
matching_byte_ranges.append((match.start() + start_byte_offset, match.end() + start_byte_offset))
543544

544545
matches: list[Editable] = []
@@ -738,7 +739,7 @@ def should_keep(node: TSNode):
738739
# Delete the node
739740
t = RemoveTransaction(removed_start_byte, removed_end_byte, self.file, priority=priority, exec_func=exec_func)
740741
if self.transaction_manager.add_transaction(t, dedupe=dedupe):
741-
if exec_func:
742+
if exec_func is not None:
742743
self.parent._removed_child()
743744

744745
# If there are sibling nodes, delete the surrounding whitespace & formatting (commas)
@@ -873,11 +874,13 @@ def variable_usages(self) -> list[Editable]:
873874
Editable corresponds to a TreeSitter node instance where the variable
874875
is referenced.
875876
"""
876-
usages = []
877+
usages: Sequence[Editable[Self]] = []
877878
identifiers = get_all_identifiers(self.ts_node)
878879
for identifier in identifiers:
879880
# Excludes function names
880881
parent = identifier.parent
882+
if parent is None:
883+
continue
881884
if parent.type in ["call", "call_expression"]:
882885
continue
883886
# Excludes local import statements
@@ -899,7 +902,7 @@ def variable_usages(self) -> list[Editable]:
899902
return usages
900903

901904
@reader
902-
def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> list[Editable]:
905+
def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> Sequence[Editable[Self]]:
903906
"""Returns Editables for all TreeSitter nodes corresponding to instances of variable usage
904907
that matches the given variable name.
905908
@@ -917,6 +920,12 @@ def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> list[
917920
else:
918921
return [usage for usage in self.variable_usages if var_name == usage.source]
919922

923+
@overload
924+
def _parse_expression(self, node: TSNode, **kwargs) -> Expression[Self]: ...
925+
926+
@overload
927+
def _parse_expression(self, node: TSNode | None, **kwargs) -> Expression[Self] | None: ...
928+
920929
def _parse_expression(self, node: TSNode | None, **kwargs) -> Expression[Self] | None:
921930
return self.G.parser.parse_expression(node, self.file_node_id, self.G, self, **kwargs)
922931

src/codegen/sdk/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
type JSON = dict[str, JSON] | list[JSON] | str | int | float | bool | None
1+
from typing import TypeAlias
2+
3+
JSON: TypeAlias = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None

src/codegen/sdk/typescript/statements/switch_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
1212
from codegen.sdk.codebase.codebase_graph import CodebaseGraph
13-
from src.codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement
13+
from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement
1414

1515

1616
@ts_apidoc

tests/integration/codemod/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import shutil
33
from collections.abc import Generator
44
from pathlib import Path
5+
from typing import TYPE_CHECKING
56
from unittest.mock import MagicMock
67

78
import filelock
@@ -13,12 +14,14 @@
1314
from codegen.git.repo_operator.repo_operator import RepoOperator
1415
from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags, ProjectConfig
1516
from codegen.sdk.core.codebase import Codebase
16-
from codemods.codemod import Codemod
1717
from tests.shared.codemod.constants import DIFF_FILEPATH
1818
from tests.shared.codemod.models import BASE_PATH, BASE_TMP_DIR, VERIFIED_CODEMOD_DIFFS, CodemodMetadata, Repo, Size
1919
from tests.shared.codemod.test_discovery import find_codemod_test_cases, find_repos, find_verified_codemod_cases
2020
from tests.shared.utils.recursion import set_recursion_limit
2121

22+
if TYPE_CHECKING:
23+
from codemods.codemod import Codemod
24+
2225
logger = logging.getLogger(__name__)
2326

2427
ONLY_STORE_CHANGED_DIFFS = True
@@ -201,7 +204,7 @@ def codemod(raw_codemod: type["Codemod"]):
201204

202205

203206
@pytest.fixture
204-
def verified_codemod(codemod_metadata: CodemodMetadata, expected: Path) -> YieldFixture[Codemod]:
207+
def verified_codemod(codemod_metadata: CodemodMetadata, expected: Path) -> YieldFixture["Codemod"]:
205208
# write the diff to the file
206209
diff_path = expected
207210
diff_path.parent.mkdir(parents=True, exist_ok=True)

uv.lock

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)