Skip to content

Fix module resolution bug #190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/codegen/sdk/core/external_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ class ExternalModule(
"""

node_type: Literal[NodeType.EXTERNAL] = NodeType.EXTERNAL
_import: Import | None = None

def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name) -> None:
def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name, import_node: Import | None = None) -> None:
self.node_id = G.add_node(self)
super().__init__(ts_node, file_node_id, G, None)
self._name_node = import_name
self.return_type = StubPlaceholder(parent=self)
assert self._idx_key not in self.G._ext_module_idx
self.G._ext_module_idx[self._idx_key] = self.node_id
self._import = import_node

@property
def _idx_key(self) -> str:
Expand All @@ -68,7 +70,7 @@ def from_import(cls, imp: Import) -> ExternalModule:
Returns:
ExternalModule: A new ExternalModule instance representing the external module.
"""
return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node)
return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node, imp)

@property
@reader
Expand Down Expand Up @@ -136,7 +138,7 @@ def viz(self) -> VizNode:
@noapidoc
@reader
def resolve_attribute(self, name: str) -> ExternalModule | None:
return self
return self._import.resolve_attribute(name) or self

@noapidoc
@commiter
Expand Down
16 changes: 14 additions & 2 deletions src/codegen/sdk/core/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from codegen.sdk.core.expressions.name import Name
from codegen.sdk.core.external_module import ExternalModule
from codegen.sdk.core.interfaces.chainable import Chainable
from codegen.sdk.core.interfaces.has_attribute import HasAttribute
from codegen.sdk.core.interfaces.usable import Usable
from codegen.sdk.core.statements.import_statement import ImportStatement
from codegen.sdk.enums import EdgeType, ImportType, NodeType
Expand Down Expand Up @@ -57,7 +58,7 @@ class ImportResolution(Generic[TSourceFile]):


@apidoc
class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile]):
class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile], HasAttribute[TSourceFile]):
"""Represents a single symbol being imported.

For example, this is one `Import` in Python (and similar applies to Typescript, etc.):
Expand Down Expand Up @@ -115,7 +116,7 @@ def __rich_repr__(self) -> rich.repr.Result:

@noapidoc
@abstractmethod
def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSourceFile] | None:
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSourceFile] | None:
"""Resolves the import to a symbol defined outside the file.

Returns an ImportResolution object.
Expand Down Expand Up @@ -662,6 +663,17 @@ def remove_if_unused(self) -> None:
):
self.remove()

@noapidoc
@reader
def resolve_attribute(self, attribute: str) -> TSourceFile | None:
# Handles implicit namespace imports in python
if not isinstance(self._imported_symbol(), ExternalModule):
return None
resolved = self.resolve_import(add_module_name=attribute)
if resolved:
return resolved.symbol or resolved.from_file
return None


TImport = TypeVar("TImport", bound="Import")

Expand Down
20 changes: 19 additions & 1 deletion src/codegen/sdk/python/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.interface import Interface
from codegen.sdk.enums import ImportType, ProgrammingLanguage
from codegen.sdk.extensions.utils import iter_all_descendants
from codegen.sdk.extensions.utils import cached_property, iter_all_descendants
from codegen.sdk.python import PyAssignment
from codegen.sdk.python.class_definition import PyClass
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock
Expand All @@ -20,6 +20,7 @@

if TYPE_CHECKING:
from codegen.sdk.codebase.codebase_graph import CodebaseGraph
from codegen.sdk.core.import_resolution import WildcardImport
from codegen.sdk.python.symbol import PySymbol


Expand Down Expand Up @@ -173,3 +174,20 @@ def add_import_from_import_string(self, import_string: str) -> None:
def remove_unused_exports(self) -> None:
"""Removes unused exports from the file. NO-OP for python"""
pass

@cached_property
@noapidoc
@reader(cache=True)
def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[PyImport]]:
"""Returns a dict mapping name => Symbol (or import) in this file that can be imported from
another file.
"""
if self.name == "__init__":
ret = {}
if self.directory:
for file in self.directory:
if file.name == "__init__":
continue
ret[file.name] = file
return ret
return super().valid_import_names
17 changes: 10 additions & 7 deletions src/codegen/sdk/python/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,13 @@ def imported_exports(self) -> list[Exportable]:

@noapidoc
@reader
def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFile] | None:
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[PyFile] | None:
base_path = base_path or self.G.projects[0].base_path or ""
module_source = self.module.source if self.module else ""

symbol_name = self.symbol_name.source if self.symbol_name else ""
if add_module_name:
module_source += f".{symbol_name}"
symbol_name = add_module_name
# If import is relative, convert to absolute path
if module_source.startswith("."):
module_source = self._relative_to_absolute_import(module_source)
Expand All @@ -99,7 +102,7 @@ def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFil
# `from a.b.c import foo`
filepath = os.path.join(
base_path,
module_source.replace(".", "/") + "/" + self.symbol_name.source + ".py",
module_source.replace(".", "/") + "/" + symbol_name + ".py",
)
if file := self.G.get_file(filepath):
return ImportResolution(from_file=file, symbol=None, imports_file=True)
Expand All @@ -114,22 +117,22 @@ def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFil
filepath = module_source.replace(".", "/") + ".py"
filepath = os.path.join(base_path, filepath)
if file := self.G.get_file(filepath):
symbol = file.get_node_by_name(self.symbol_name.source)
symbol = file.get_node_by_name(symbol_name)
return ImportResolution(from_file=file, symbol=symbol)

# =====[ Check if `module/__init__.py` file exists in the graph ]=====
filepath = filepath.replace(".py", "/__init__.py")
if from_file := self.G.get_file(filepath):
symbol = from_file.get_node_by_name(self.symbol_name.source)
symbol = from_file.get_node_by_name(symbol_name)
return ImportResolution(from_file=from_file, symbol=symbol)

# =====[ Case: Can't resolve the import ]=====
if base_path == "":
# Try to resolve with "src" as the base path
return self.resolve_import(base_path="src")
return self.resolve_import(base_path="src", add_module_name=add_module_name)
if base_path == "src":
# Try "test" next
return self.resolve_import(base_path="test")
return self.resolve_import(base_path="test", add_module_name=add_module_name)

# if not G_override:
# for resolver in G.import_resolvers:
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/sdk/typescript/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def resolved_symbol(self) -> Symbol | ExternalModule | TSFile | None:
return resolved_symbol

@reader
def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSFile] | None:
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSFile] | None:
"""Resolves an import statement to its target file and symbol.

This method is used by GraphBuilder to resolve import statements to their target files and symbols. It handles both relative and absolute imports,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,64 @@ def c_sym():
assert "c_sym" in b_file.valid_symbol_names
assert "a_sym" in c_file.valid_symbol_names
assert "b_sym" in c_file.valid_symbol_names.keys()


def test_import_resolution_nested_module(tmpdir: str) -> None:
"""Tests import resolution works with nested module imports"""
# language=python
with get_codebase_session(
tmpdir,
files={
"a/b/c.py": """
def d():
pass
""",
"consumer.py": """
from a import b

b.c.d()
""",
},
) as codebase:
consumer_file: SourceFile = codebase.get_file("consumer.py")
c_file: SourceFile = codebase.get_file("a/b/c.py")

# Verify import resolution
assert len(consumer_file.imports) == 1

# Verify function call resolution
d_func = c_file.get_function("d")
call_sites = d_func.call_sites
assert len(call_sites) == 1
assert call_sites[0].file == consumer_file


def test_import_resolution_nested_module_init(tmpdir: str) -> None:
"""Tests import resolution works with nested module imports"""
# language=python
with get_codebase_session(
tmpdir,
files={
"a/b/c.py": """
def d():
pass
""",
"a/b/__init__.py": """""",
"consumer.py": """
from a import b

b.c.d()
""",
},
) as codebase:
consumer_file: SourceFile = codebase.get_file("consumer.py")
c_file: SourceFile = codebase.get_file("a/b/c.py")

# Verify import resolution
assert len(consumer_file.imports) == 1

# Verify function call resolution
d_func = c_file.get_function("d")
call_sites = d_func.call_sites
assert len(call_sites) == 1
assert call_sites[0].file == consumer_file