Skip to content

Commit 4a3f2d5

Browse files
authored
Fix module resolution bug (#190)
1 parent a700216 commit 4a3f2d5

File tree

6 files changed

+110
-14
lines changed

6 files changed

+110
-14
lines changed

src/codegen/sdk/core/external_module.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,16 @@ class ExternalModule(
3535
"""
3636

3737
node_type: Literal[NodeType.EXTERNAL] = NodeType.EXTERNAL
38+
_import: Import | None = None
3839

39-
def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name) -> None:
40+
def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name, import_node: Import | None = None) -> None:
4041
self.node_id = G.add_node(self)
4142
super().__init__(ts_node, file_node_id, G, None)
4243
self._name_node = import_name
4344
self.return_type = StubPlaceholder(parent=self)
4445
assert self._idx_key not in self.G._ext_module_idx
4546
self.G._ext_module_idx[self._idx_key] = self.node_id
47+
self._import = import_node
4648

4749
@property
4850
def _idx_key(self) -> str:
@@ -68,7 +70,7 @@ def from_import(cls, imp: Import) -> ExternalModule:
6870
Returns:
6971
ExternalModule: A new ExternalModule instance representing the external module.
7072
"""
71-
return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node)
73+
return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node, imp)
7274

7375
@property
7476
@reader
@@ -136,7 +138,7 @@ def viz(self) -> VizNode:
136138
@noapidoc
@reader
def resolve_attribute(self, name: str) -> ExternalModule | None:
    """Resolve the attribute `name` accessed on this external module.

    When this module was created from an import (see `from_import`), the
    lookup is delegated to that import first. If the import cannot resolve
    the attribute, or no import is attached, the module itself is returned.

    Args:
        name: The attribute being accessed on the module.

    Returns:
        Whatever the originating import resolves `name` to, otherwise `self`.
    """
    # `_import` defaults to None when the module is constructed without an
    # import node, so it must not be dereferenced unconditionally.
    if self._import is None:
        return self
    return self._import.resolve_attribute(name) or self
140142

141143
@noapidoc
142144
@commiter

src/codegen/sdk/core/import_resolution.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from codegen.sdk.core.expressions.name import Name
1212
from codegen.sdk.core.external_module import ExternalModule
1313
from codegen.sdk.core.interfaces.chainable import Chainable
14+
from codegen.sdk.core.interfaces.has_attribute import HasAttribute
1415
from codegen.sdk.core.interfaces.usable import Usable
1516
from codegen.sdk.core.statements.import_statement import ImportStatement
1617
from codegen.sdk.enums import EdgeType, ImportType, NodeType
@@ -57,7 +58,7 @@ class ImportResolution(Generic[TSourceFile]):
5758

5859

5960
@apidoc
60-
class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile]):
61+
class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile], HasAttribute[TSourceFile]):
6162
"""Represents a single symbol being imported.
6263
6364
For example, this is one `Import` in Python (and similar applies to Typescript, etc.):
@@ -115,7 +116,7 @@ def __rich_repr__(self) -> rich.repr.Result:
115116

116117
@noapidoc
117118
@abstractmethod
118-
def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSourceFile] | None:
119+
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSourceFile] | None:
119120
"""Resolves the import to a symbol defined outside the file.
120121
121122
Returns an ImportResolution object.
@@ -662,6 +663,17 @@ def remove_if_unused(self) -> None:
662663
):
663664
self.remove()
664665

666+
@noapidoc
@reader
def resolve_attribute(self, attribute: str) -> TSourceFile | None:
    """Try to resolve `attribute` accessed on the name bound by this import.

    Handles implicit namespace imports in python: the lookup is only
    attempted when the import currently points at an `ExternalModule`. In
    that case the import is re-resolved with `attribute` passed as
    `add_module_name`.

    Args:
        attribute: The attribute name accessed on the imported name.

    Returns:
        The resolved symbol if there is one, otherwise the resolved file;
        `None` when the import does not point at an `ExternalModule` or the
        re-resolution yields nothing.
    """
    target = self._imported_symbol()
    if isinstance(target, ExternalModule):
        resolution = self.resolve_import(add_module_name=attribute)
        if resolution:
            return resolution.symbol or resolution.from_file
    return None
676+
665677

666678
TImport = TypeVar("TImport", bound="Import")
667679

src/codegen/sdk/python/file.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from codegen.sdk.core.file import SourceFile
77
from codegen.sdk.core.interface import Interface
88
from codegen.sdk.enums import ImportType, ProgrammingLanguage
9-
from codegen.sdk.extensions.utils import iter_all_descendants
9+
from codegen.sdk.extensions.utils import cached_property, iter_all_descendants
1010
from codegen.sdk.python import PyAssignment
1111
from codegen.sdk.python.class_definition import PyClass
1212
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock
@@ -20,6 +20,7 @@
2020

2121
if TYPE_CHECKING:
2222
from codegen.sdk.codebase.codebase_graph import CodebaseGraph
23+
from codegen.sdk.core.import_resolution import WildcardImport
2324
from codegen.sdk.python.symbol import PySymbol
2425

2526

@@ -173,3 +174,20 @@ def add_import_from_import_string(self, import_string: str) -> None:
173174
def remove_unused_exports(self) -> None:
174175
"""Removes unused exports from the file. NO-OP for python"""
175176
pass
177+
178+
@cached_property
@noapidoc
@reader(cache=True)
def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[PyImport]]:
    """Map each importable name in this file to what it refers to.

    For an `__init__` file the mapping is built from the entries of the
    containing directory (keyed by each entry's `name`, skipping entries
    named `__init__`); it is empty when the file has no directory. Any
    other file defers to the parent implementation.
    """
    if self.name != "__init__":
        return super().valid_import_names
    if not self.directory:
        return {}
    return {sibling.name: sibling for sibling in self.directory if sibling.name != "__init__"}

src/codegen/sdk/python/import_resolution.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,13 @@ def imported_exports(self) -> list[Exportable]:
8282

8383
@noapidoc
8484
@reader
85-
def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFile] | None:
85+
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[PyFile] | None:
8686
base_path = base_path or self.G.projects[0].base_path or ""
8787
module_source = self.module.source if self.module else ""
88-
88+
symbol_name = self.symbol_name.source if self.symbol_name else ""
89+
if add_module_name:
90+
module_source += f".{symbol_name}"
91+
symbol_name = add_module_name
8992
# If import is relative, convert to absolute path
9093
if module_source.startswith("."):
9194
module_source = self._relative_to_absolute_import(module_source)
@@ -99,7 +102,7 @@ def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFil
99102
# `from a.b.c import foo`
100103
filepath = os.path.join(
101104
base_path,
102-
module_source.replace(".", "/") + "/" + self.symbol_name.source + ".py",
105+
module_source.replace(".", "/") + "/" + symbol_name + ".py",
103106
)
104107
if file := self.G.get_file(filepath):
105108
return ImportResolution(from_file=file, symbol=None, imports_file=True)
@@ -114,22 +117,22 @@ def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFil
114117
filepath = module_source.replace(".", "/") + ".py"
115118
filepath = os.path.join(base_path, filepath)
116119
if file := self.G.get_file(filepath):
117-
symbol = file.get_node_by_name(self.symbol_name.source)
120+
symbol = file.get_node_by_name(symbol_name)
118121
return ImportResolution(from_file=file, symbol=symbol)
119122

120123
# =====[ Check if `module/__init__.py` file exists in the graph ]=====
121124
filepath = filepath.replace(".py", "/__init__.py")
122125
if from_file := self.G.get_file(filepath):
123-
symbol = from_file.get_node_by_name(self.symbol_name.source)
126+
symbol = from_file.get_node_by_name(symbol_name)
124127
return ImportResolution(from_file=from_file, symbol=symbol)
125128

126129
# =====[ Case: Can't resolve the import ]=====
127130
if base_path == "":
128131
# Try to resolve with "src" as the base path
129-
return self.resolve_import(base_path="src")
132+
return self.resolve_import(base_path="src", add_module_name=add_module_name)
130133
if base_path == "src":
131134
# Try "test" next
132-
return self.resolve_import(base_path="test")
135+
return self.resolve_import(base_path="test", add_module_name=add_module_name)
133136

134137
# if not G_override:
135138
# for resolver in G.import_resolvers:

src/codegen/sdk/typescript/import_resolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def resolved_symbol(self) -> Symbol | ExternalModule | TSFile | None:
197197
return resolved_symbol
198198

199199
@reader
200-
def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSFile] | None:
200+
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSFile] | None:
201201
"""Resolves an import statement to its target file and symbol.
202202
203203
This method is used by GraphBuilder to resolve import statements to their target files and symbols. It handles both relative and absolute imports,

tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,64 @@ def c_sym():
249249
assert "c_sym" in b_file.valid_symbol_names
250250
assert "a_sym" in c_file.valid_symbol_names
251251
assert "b_sym" in c_file.valid_symbol_names.keys()
252+
253+
254+
def test_import_resolution_nested_module(tmpdir: str) -> None:
255+
"""Tests that a call made through an imported package (`from a import b` then `b.c.d()`) is recorded as a call site of `d` when `a/b` has no `__init__.py`"""
256+
# language=python
257+
with get_codebase_session(
258+
tmpdir,
259+
files={
260+
"a/b/c.py": """
261+
def d():
262+
pass
263+
""",
264+
"consumer.py": """
265+
from a import b
266+
267+
b.c.d()
268+
""",
269+
},
270+
) as codebase:
271+
consumer_file: SourceFile = codebase.get_file("consumer.py")
272+
c_file: SourceFile = codebase.get_file("a/b/c.py")
273+
274+
# Verify import resolution
275+
assert len(consumer_file.imports) == 1
276+
277+
# Verify function call resolution
278+
d_func = c_file.get_function("d")
279+
call_sites = d_func.call_sites
280+
assert len(call_sites) == 1
281+
assert call_sites[0].file == consumer_file
282+
283+
284+
def test_import_resolution_nested_module_init(tmpdir: str) -> None:
285+
"""Tests that a call made through an imported package (`from a import b` then `b.c.d()`) is recorded as a call site of `d` when `a/b` contains an empty `__init__.py`"""
286+
# language=python
287+
with get_codebase_session(
288+
tmpdir,
289+
files={
290+
"a/b/c.py": """
291+
def d():
292+
pass
293+
""",
294+
"a/b/__init__.py": """""",
295+
"consumer.py": """
296+
from a import b
297+
298+
b.c.d()
299+
""",
300+
},
301+
) as codebase:
302+
consumer_file: SourceFile = codebase.get_file("consumer.py")
303+
c_file: SourceFile = codebase.get_file("a/b/c.py")
304+
305+
# Verify import resolution
306+
assert len(consumer_file.imports) == 1
307+
308+
# Verify function call resolution
309+
d_func = c_file.get_function("d")
310+
call_sites = d_func.call_sites
311+
assert len(call_sites) == 1
312+
assert call_sites[0].file == consumer_file

0 commit comments

Comments
 (0)