Skip to content

Commit ee3de3b

Browse files
tomcodgentkfossfrainfreeze
authored andcommitted
[CG-10935] fix: issues with assigment (#737)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --------- Co-authored-by: tomcodgen <[email protected]> Co-authored-by: tomcodegen <[email protected]>
1 parent e756e15 commit ee3de3b

25 files changed

+390
-41
lines changed

src/codegen/sdk/core/expressions/name.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from codegen.sdk.core.autocommit import reader, writer
66
from codegen.sdk.core.dataclasses.usage import UsageKind
77
from codegen.sdk.core.expressions.expression import Expression
8+
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
89
from codegen.sdk.core.interfaces.resolvable import Resolvable
910
from codegen.sdk.extensions.autocommit import commiter
1011
from codegen.shared.decorators.docs import apidoc, noapidoc
1112

1213
if TYPE_CHECKING:
14+
from codegen.sdk.core.import_resolution import Import, WildcardImport
1315
from codegen.sdk.core.interfaces.has_name import HasName
14-
16+
from codegen.sdk.core.symbol import Symbol
1517

1618
Parent = TypeVar("Parent", bound="Expression")
1719

@@ -29,10 +31,9 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]):
2931
@override
3032
def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]:
3133
"""Resolve the types used by this symbol."""
32-
if used := self.resolve_name(self.source, self.start_byte):
34+
for used in self.resolve_name(self.source, self.start_byte):
3335
yield from self.with_resolution_frame(used)
3436

35-
@noapidoc
3637
@commiter
3738
def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None:
3839
"""Compute the dependencies of the export object."""
@@ -48,3 +49,25 @@ def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName |
4849
def rename_if_matching(self, old: str, new: str):
4950
if self.source == old:
5051
self.edit(new)
52+
53+
@noapidoc
54+
@reader
55+
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]:
56+
resolved_name = next(super().resolve_name(name, start_byte or self.start_byte, strict=strict), None)
57+
if resolved_name:
58+
yield resolved_name
59+
else:
60+
return
61+
62+
if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)):
63+
top_of_conditional = conditional_parent.start_byte
64+
if self.parent_of_type(ConditionalBlock) == conditional_parent:
65+
# Use in the same block, should only depend on the inside of the block
66+
return
67+
for other_conditional in conditional_parent.other_possible_blocks:
68+
if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None):
69+
if cond_name.start_byte >= other_conditional.start_byte:
70+
yield cond_name
71+
top_of_conditional = min(top_of_conditional, other_conditional.start_byte)
72+
73+
yield from self.resolve_name(name, top_of_conditional, strict=False)

src/codegen/sdk/core/file.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import resource
44
import sys
55
from abc import abstractmethod
6-
from collections.abc import Sequence
6+
from collections.abc import Generator, Sequence
77
from functools import cached_property
88
from os import PathLike
99
from pathlib import Path
@@ -744,7 +744,7 @@ def get_symbol(self, name: str) -> Symbol | None:
744744
Returns:
745745
Symbol | None: The found symbol, or None if not found.
746746
"""
747-
if symbol := self.resolve_name(name, self.end_byte):
747+
if symbol := next(self.resolve_name(name, self.end_byte), None):
748748
if isinstance(symbol, Symbol):
749749
return symbol
750750
return next((x for x in self.symbols if x.name == name), None)
@@ -819,7 +819,7 @@ def get_class(self, name: str) -> TClass | None:
819819
Returns:
820820
TClass | None: The matching Class object if found, None otherwise.
821821
"""
822-
if symbol := self.resolve_name(name, self.end_byte):
822+
if symbol := next(self.resolve_name(name, self.end_byte), None):
823823
if isinstance(symbol, Class):
824824
return symbol
825825

@@ -880,13 +880,41 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp
880880

881881
@noapidoc
882882
@reader
883-
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
883+
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
884+
"""Resolves a name to a symbol, import, or wildcard import within the file's scope.
885+
886+
Performs name resolution by first checking the file's valid symbols and imports. When a start_byte
887+
is provided, ensures proper scope handling by only resolving to symbols that are defined before
888+
that position in the file.
889+
890+
Args:
891+
name (str): The name to resolve.
892+
start_byte (int | None): If provided, only resolves to symbols defined before this byte position
893+
in the file. Used for proper scope handling. Defaults to None.
894+
strict (bool): When True and using start_byte, only yields symbols if found in the correct scope.
895+
When False, allows falling back to global scope. Defaults to True.
896+
897+
Yields:
898+
Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import that matches
899+
the name and scope requirements. Yields at most one result.
900+
"""
884901
if resolved := self.valid_symbol_names.get(name):
902+
# If we have a start_byte and the resolved symbol is after it,
903+
# we need to look for earlier definitions of the symbol
885904
if start_byte is not None and resolved.end_byte > start_byte:
886-
for symbol in self.symbols:
905+
# Search backwards through symbols to find the most recent definition
906+
# that comes before our start_byte position
907+
for symbol in reversed(self.symbols):
887908
if symbol.start_byte <= start_byte and symbol.name == name:
888-
return symbol
889-
return resolved
909+
yield symbol
910+
return
911+
# If strict mode and no valid symbol found, return nothing
912+
if not strict:
913+
return
914+
# Either no start_byte constraint or symbol is before start_byte
915+
yield resolved
916+
return
917+
return
890918

891919
@property
892920
@reader

src/codegen/sdk/core/function.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,14 @@ def is_async(self) -> bool:
141141

142142
@noapidoc
143143
@reader
144-
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
144+
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
145145
from codegen.sdk.core.class_definition import Class
146146

147147
for symbol in self.valid_symbol_names:
148148
if symbol.name == name and (start_byte is None or (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte):
149-
return symbol
150-
return super().resolve_name(name, start_byte)
149+
yield symbol
150+
return
151+
yield from super().resolve_name(name, start_byte)
151152

152153
@cached_property
153154
@noapidoc
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from abc import ABC, abstractmethod
2+
from collections.abc import Sequence
3+
4+
from codegen.sdk.core.statements.statement import Statement
5+
6+
7+
class ConditionalBlock(Statement, ABC):
8+
"""An interface for any code block that might not be executed in the code, e.g if block/else block/try block/catch block ect."""
9+
10+
@property
11+
@abstractmethod
12+
def other_possible_blocks(self) -> Sequence["ConditionalBlock"]:
13+
"""Should return all other "branches" that might be executed instead."""
14+
15+
@property
16+
def end_byte_for_condition_block(self) -> int:
17+
return self.end_byte

src/codegen/sdk/core/interfaces/editable.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,10 +1003,11 @@ def viz(self) -> VizNode:
10031003

10041004
@noapidoc
10051005
@reader
1006-
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
1006+
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
10071007
if self.parent is not None:
1008-
return self.parent.resolve_name(name, start_byte or self.start_byte)
1009-
return self.file.resolve_name(name, start_byte or self.start_byte)
1008+
yield from self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict)
1009+
else:
1010+
yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict)
10101011

10111012
@cached_property
10121013
@noapidoc

src/codegen/sdk/core/statements/catch_statement.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING, Generic, Self, TypeVar
44

5+
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
56
from codegen.sdk.core.statements.block_statement import BlockStatement
67
from codegen.sdk.extensions.autocommit import commiter
78
from codegen.shared.decorators.docs import apidoc, noapidoc
@@ -17,7 +18,7 @@
1718

1819

1920
@apidoc
20-
class CatchStatement(BlockStatement[Parent], Generic[Parent]):
21+
class CatchStatement(ConditionalBlock, BlockStatement[Parent], Generic[Parent]):
2122
"""Abstract representation catch clause.
2223
2324
Attributes:

src/codegen/sdk/core/statements/for_loop_statement.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from codegen.shared.decorators.docs import apidoc, noapidoc
1313

1414
if TYPE_CHECKING:
15+
from collections.abc import Generator
16+
1517
from codegen.sdk.core.detached_symbols.code_block import CodeBlock
1618
from codegen.sdk.core.expressions import Expression
1719
from codegen.sdk.core.import_resolution import Import, WildcardImport
@@ -36,19 +38,23 @@ class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]):
3638

3739
@noapidoc
3840
@reader
39-
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
41+
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
4042
if self.item and isinstance(self.iterable, Chainable):
4143
if start_byte is None or start_byte > self.iterable.end_byte:
4244
if name == self.item:
4345
for frame in self.iterable.resolved_type_frames:
4446
if frame.generics:
45-
return next(iter(frame.generics.values()))
46-
return frame.top.node
47+
yield next(iter(frame.generics.values()))
48+
return
49+
yield frame.top.node
50+
return
4751
elif isinstance(self.item, Collection):
4852
for idx, item in enumerate(self.item):
4953
if item == name:
5054
for frame in self.iterable.resolved_type_frames:
5155
if frame.generics and len(frame.generics) > idx:
52-
return list(frame.generics.values())[idx]
53-
return frame.top.node
54-
return super().resolve_name(name, start_byte)
56+
yield list(frame.generics.values())[idx]
57+
return
58+
yield frame.top.node
59+
return
60+
yield from super().resolve_name(name, start_byte)

src/codegen/sdk/core/statements/if_block_statement.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
from codegen.sdk.core.autocommit import reader, writer
99
from codegen.sdk.core.dataclasses.usage import UsageKind
1010
from codegen.sdk.core.function import Function
11+
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
1112
from codegen.sdk.core.statements.statement import Statement, StatementType
1213
from codegen.sdk.extensions.autocommit import commiter
1314
from codegen.shared.decorators.docs import apidoc, noapidoc
1415

1516
if TYPE_CHECKING:
17+
from collections.abc import Sequence
18+
1619
from codegen.sdk.core.detached_symbols.code_block import CodeBlock
1720
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
1821
from codegen.sdk.core.expressions import Expression
@@ -26,7 +29,7 @@
2629

2730

2831
@apidoc
29-
class IfBlockStatement(Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]):
32+
class IfBlockStatement(ConditionalBlock, Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]):
3033
"""Abstract representation of the if/elif/else if/else statement block.
3134
3235
For example, if there is a code block like:
@@ -271,3 +274,26 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -
271274
self.remove_byte_range(self.ts_node.start_byte, remove_end)
272275
else:
273276
self.remove()
277+
278+
@property
279+
def other_possible_blocks(self) -> Sequence[ConditionalBlock]:
280+
if self.is_if_statement:
281+
return self._main_if_block.alternative_blocks
282+
elif self.is_elif_statement:
283+
main = self._main_if_block
284+
statements = [main]
285+
if main.else_statement:
286+
statements.append(main.else_statement)
287+
for statement in main.elif_statements:
288+
if statement != self:
289+
statements.append(statement)
290+
return statements
291+
else:
292+
main = self._main_if_block
293+
return [main, *main.elif_statements]
294+
295+
@property
296+
def end_byte_for_condition_block(self) -> int:
297+
if self.is_if_statement:
298+
return self.consequence_block.end_byte
299+
return self.end_byte

src/codegen/sdk/core/statements/switch_case.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING, Generic, Self, TypeVar
44

5+
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
56
from codegen.sdk.core.statements.block_statement import BlockStatement
67
from codegen.sdk.extensions.autocommit import commiter
78
from codegen.shared.decorators.docs import apidoc, noapidoc
@@ -18,7 +19,7 @@
1819

1920

2021
@apidoc
21-
class SwitchCase(BlockStatement[Parent], Generic[Parent]):
22+
class SwitchCase(ConditionalBlock, BlockStatement[Parent], Generic[Parent]):
2223
"""Abstract representation for a switch case.
2324
2425
Attributes:
@@ -34,3 +35,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa
3435
if self.condition:
3536
self.condition._compute_dependencies(usage_type, dest)
3637
super()._compute_dependencies(usage_type, dest)
38+
39+
@property
40+
def other_possible_blocks(self) -> list[ConditionalBlock]:
41+
return [case for case in self.parent.cases if case != self]

src/codegen/sdk/core/statements/try_catch_statement.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from abc import ABC
44
from typing import TYPE_CHECKING, Generic, TypeVar
55

6+
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
67
from codegen.sdk.core.interfaces.has_block import HasBlock
78
from codegen.sdk.core.statements.block_statement import BlockStatement
89
from codegen.sdk.core.statements.statement import StatementType
@@ -16,7 +17,7 @@
1617

1718

1819
@apidoc
19-
class TryCatchStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]):
20+
class TryCatchStatement(ConditionalBlock, BlockStatement[Parent], HasBlock, ABC, Generic[Parent]):
2021
"""Abstract representation of the try catch statement block.
2122
2223
Attributes:

src/codegen/sdk/python/function.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from codegen.shared.logging.get_logger import get_logger
2020

2121
if TYPE_CHECKING:
22+
from collections.abc import Generator
23+
2224
from tree_sitter import Node as TSNode
2325

2426
from codegen.sdk.codebase.codebase_context import CodebaseContext
@@ -119,15 +121,17 @@ def is_class_method(self) -> bool:
119121

120122
@noapidoc
121123
@reader
122-
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
124+
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
123125
if self.is_method:
124126
if not self.is_static_method:
125127
if len(self.parameters.symbols) > 0:
126128
if name == self.parameters[0].name:
127-
return self.parent_class
129+
yield self.parent_class
130+
return
128131
if name == "super()":
129-
return self.parent_class
130-
return super().resolve_name(name, start_byte)
132+
yield self.parent_class
133+
return
134+
yield from super().resolve_name(name, start_byte)
131135

132136
@noapidoc
133137
@commiter

src/codegen/sdk/python/import_resolution.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,11 @@ def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str)
211211
"""
212212
for resolve_path in resolve_paths:
213213
filepath_new: str = os.path.join(resolve_path, filepath)
214-
if file := self.ctx.get_file(filepath_new):
214+
try:
215+
file = self.ctx.get_file(filepath_new)
216+
except AssertionError as e:
217+
file = None
218+
if file:
215219
return file
216220

217221
return None

src/codegen/sdk/python/statements/catch_statement.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tree_sitter import Node as PyNode
1212

1313
from codegen.sdk.codebase.codebase_context import CodebaseContext
14+
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
1415
from codegen.sdk.core.node_id_factory import NodeId
1516

1617

@@ -26,3 +27,7 @@ class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement):
2627
def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None:
2728
super().__init__(ts_node, file_node_id, ctx, parent, pos)
2829
self.condition = self.children[0]
30+
31+
@property
32+
def other_possible_blocks(self) -> list[ConditionalBlock]:
33+
return [clause for clause in self.parent.except_clauses if clause != self] + [self.parent]

src/codegen/sdk/python/statements/if_block_statement.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from codegen.sdk.core.node_id_factory import NodeId
1515
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock
1616

17-
1817
Parent = TypeVar("Parent", bound="PyCodeBlock")
1918

2019

0 commit comments

Comments
 (0)