Skip to content

[CG-10888] fix: Wildcard resolution #612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/codegen/sdk/python/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,28 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P
ret[file.name] = file
return ret
return super().valid_import_names

def get_node_from_wildcard_chain(self, symbol_name: str):
node = None
if node := self.get_node_by_name(symbol_name):
return node

if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}:
for wildcard_import in wildcard_imports:
if imp_resolution := wildcard_import.resolve_import():
node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name)

return node

def get_node_wildcard_resolves_for(self, symbol_name: str):
node = None
if node := self.get_node_by_name(symbol_name):
return node

if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}:
for wildcard_import in wildcard_imports:
if imp_resolution := wildcard_import.resolve_import():
if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name):
node = wildcard_import

return node
40 changes: 37 additions & 3 deletions src/codegen/sdk/python/import_resolution.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self, override

from codegen.sdk.core.autocommit import reader
from codegen.sdk.core.expressions import Name
from codegen.sdk.core.import_resolution import ExternalImportResolver, Import, ImportResolution
from codegen.sdk.enums import ImportType, NodeType
from codegen.sdk.extensions.resolution import ResolutionStack
from codegen.shared.decorators.docs import noapidoc, py_apidoc

if TYPE_CHECKING:
from collections.abc import Generator

from tree_sitter import Node as TSNode

from codegen.sdk.codebase.codebase_context import CodebaseContext
Expand All @@ -28,6 +31,10 @@
class PyImport(Import["PyFile"]):
"""Extends Import for Python codebases."""

def __init__(self, ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type=ImportType.UNKNOWN):
super().__init__(ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type)
self._requesting_names = set()

@reader
def is_module_import(self) -> bool:
"""Determines if the import is a module-level or wildcard import.
Expand Down Expand Up @@ -55,7 +62,7 @@

resolved_symbol = self.resolved_symbol
if resolved_symbol is not None and resolved_symbol.node_type == NodeType.FILE:
return self.alias.source

Check failure on line 65 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "None" of "Editable[Any] | None" has no attribute "source" [union-attr]
return None

@property
Expand All @@ -76,9 +83,9 @@
return []

if not self.is_module_import():
return [self.imported_symbol]

Check failure on line 86 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: List item 0 has incompatible type "Symbol[Any, Any] | ExternalModule | PyFile | Import[Any]"; expected "Exportable[Any]" [list-item]

return self.imported_symbol.symbols + self.imported_symbol.imports

Check failure on line 88 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "Symbol[Any, Any]" of "Symbol[Any, Any] | ExternalModule | PyFile | Import[Any]" has no attribute "symbols" [union-attr]

Check failure on line 88 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "ExternalModule" of "Symbol[Any, Any] | ExternalModule | PyFile | Import[Any]" has no attribute "symbols" [union-attr]

Check failure on line 88 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "Import[Any]" of "Symbol[Any, Any] | ExternalModule | PyFile | Import[Any]" has no attribute "symbols" [union-attr]

Check failure on line 88 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "Symbol[Any, Any]" of "Symbol[Any, Any] | ExternalModule | PyFile | Import[Any]" has no attribute "imports" [union-attr]

Check failure on line 88 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "ExternalModule" of "Symbol[Any, Any] | ExternalModule | PyFile | Import[Any]" has no attribute "imports" [union-attr]

Check failure on line 88 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "Import[Any]" of "Symbol[Any, Any] | ExternalModule | PyFile | Import[Any]" has no attribute "imports" [union-attr]

@noapidoc
@reader
Expand All @@ -104,8 +111,8 @@
base_path,
module_source.replace(".", "/") + "/" + symbol_name + ".py",
)
if file := self.ctx.get_file(filepath):

Check failure on line 114 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "get_file" of "CodebaseContext" has incompatible type "str"; expected "PathLike[Any]" [arg-type]
return ImportResolution(from_file=file, symbol=None, imports_file=True)

Check failure on line 115 in src/codegen/sdk/python/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument "from_file" to "ImportResolution" has incompatible type "SourceFile[Any, Any, Any, Any, Any, Any]"; expected "PyFile | None" [arg-type]

filepath = filepath.replace(".py", "/__init__.py")
if file := self.ctx.get_file(filepath):
Expand All @@ -117,13 +124,13 @@
filepath = module_source.replace(".", "/") + ".py"
filepath = os.path.join(base_path, filepath)
if file := self.ctx.get_file(filepath):
symbol = file.get_node_by_name(symbol_name)
symbol = file.get_node_wildcard_resolves_for(symbol_name)
return ImportResolution(from_file=file, symbol=symbol)

# =====[ Check if `module/__init__.py` file exists in the graph ]=====
filepath = filepath.replace(".py", "/__init__.py")
if from_file := self.ctx.get_file(filepath):
symbol = from_file.get_node_by_name(symbol_name)
symbol = from_file.get_node_wildcard_resolves_for(symbol_name)
return ImportResolution(from_file=from_file, symbol=symbol)

# =====[ Case: Can't resolve the import ]=====
Expand Down Expand Up @@ -232,6 +239,33 @@
imports.append(imp)
return imports

@reader
@noapidoc
@override
def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]:
"""Resolve the types used by this import."""
ix_seen = set()

aliased = self.is_aliased_import()
if imported := self._imported_symbol(resolve_exports=True):
if isinstance(imported, PyImport) and imported.is_wildcard_import:
imported.set_requesting_names(self)
yield from self.with_resolution_frame(imported, direct=False, aliased=aliased)
else:
yield ResolutionStack(self, aliased=aliased)

if self.is_wildcard_import():
for name, wildcard_import in self.names:
if name in self._requesting_names:
yield from [frame.parent_frame for frame in wildcard_import.resolved_type_frames]

@noapidoc
def set_requesting_names(self, requester: PyImport):
if requester.is_wildcard_import():
self._requesting_names.update(requester._requesting_names)
else:
self._requesting_names.add(requester.name)

@property
@reader
def import_specifier(self) -> Editable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,206 @@ def func_1():
assert call_site.file == consumer_file


def test_import_resolution_init_wildcard(tmpdir: str) -> None:
"""Tests that named import from a file with wildcard resolves properly"""
# language=python
content1 = """TEST_CONST=2
foo=9
"""
content2 = """from testdir.test1 import *
bar=foo
test=TEST_CONST"""
content3 = """from testdir import TEST_CONST
test3=TEST_CONST"""
with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase:
file1: SourceFile = codebase.get_file("testdir/test1.py")
file2: SourceFile = codebase.get_file("testdir/__init__.py")
file3: SourceFile = codebase.get_file("test3.py")

symb = file1.get_symbol("TEST_CONST")
test = file2.get_symbol("test")
test3 = file3.get_symbol("test3")
test3_import = file3.get_import("TEST_CONST")

assert len(symb.usages) == 3
assert symb.symbol_usages == [test, test3, test3_import]


def test_import_resolution_wildcard_func(tmpdir: str) -> None:
"""Tests that named import from a file with wildcard resolves properly"""
# language=python
content1 = """
def foo():
pass
def bar():
pass
"""
content2 = """
from testa import *

foo()
"""

with get_codebase_session(tmpdir=tmpdir, files={"testa.py": content1, "testb.py": content2}) as codebase:
testa: SourceFile = codebase.get_file("testa.py")
testb: SourceFile = codebase.get_file("testb.py")

foo = testa.get_symbol("foo")
bar = testa.get_symbol("bar")
assert len(foo.usages) == 1
assert len(foo.call_sites) == 1

assert len(bar.usages) == 0
assert len(bar.call_sites) == 0
assert len(testb.function_calls) == 1


def test_import_resolution_chaining_wildcards(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python
content1 = """TEST_CONST=2
foo=9
"""
content2 = """from testdir.test1 import *
bar=foo
test=TEST_CONST"""
content3 = """from testdir import *
test3=TEST_CONST"""
with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase:
file1: SourceFile = codebase.get_file("testdir/test1.py")
file2: SourceFile = codebase.get_file("testdir/__init__.py")
file3: SourceFile = codebase.get_file("test3.py")

symb = file1.get_symbol("TEST_CONST")
test = file2.get_symbol("test")
bar = file2.get_symbol("bar")
mid_import = file2.get_import("testdir.test1")
test3 = file3.get_symbol("test3")

assert len(symb.usages) == 2
assert symb.symbol_usages == [test, test3]
assert mid_import.symbol_usages == [test, bar, test3]


def test_import_resolution_init_deep_nested_wildcards(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python

files = {
"test/nest/nest2/test1.py": """test_const=5
test_not_used=2
test_used_parent=5
""",
"test/nest/nest2/__init__.py": """from .test1 import *
t1=test_used_parent
""",
"test/nest/__init__.py": """from .nest2 import *""",
"test/__init__.py": """from .nest import *""",
"main.py": """
from test import *
main_test=test_const
""",
}
with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py")
main: SourceFile = codebase.get_file("main.py")
parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py")

main_test = main.get_symbol("main_test")
t1 = parent_file.get_symbol("t1")
test_const = deepest_layer.get_symbol("test_const")
test_not_used = deepest_layer.get_symbol("test_not_used")
test_used_parent = deepest_layer.get_symbol("test_used_parent")

assert len(test_const.usages) == 1
assert test_const.usages[0].usage_symbol == main_test
assert len(test_not_used.usages) == 0
assert len(test_used_parent.usages) == 1
assert test_used_parent.usages[0].usage_symbol == t1


def test_import_resolution_chaining_many_wildcards(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python

files = {
"test1.py": """
test_const=5
test_not_used=2
test_used_parent=5
""",
"test2.py": """from test1 import *
t1=test_used_parent
""",
"test3.py": """from test2 import *""",
"test4.py": """from test3 import *""",
"main.py": """
from test4 import *
main_test=test_const
""",
}
with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
furthest_layer: SourceFile = codebase.get_file("test1.py")
main: SourceFile = codebase.get_file("main.py")
parent_file: SourceFile = codebase.get_file("test2.py")

main_test = main.get_symbol("main_test")
t1 = parent_file.get_symbol("t1")
test_const = furthest_layer.get_symbol("test_const")
test_not_used = furthest_layer.get_symbol("test_not_used")
test_used_parent = furthest_layer.get_symbol("test_used_parent")

assert len(test_const.usages) == 1
assert test_const.usages[0].usage_symbol == main_test
assert len(test_not_used.usages) == 0
assert len(test_used_parent.usages) == 1
assert test_used_parent.usages[0].usage_symbol == t1


def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python

files = {
"test/nest/nest2/test1.py": """test_const=5
test_not_used=2
test_used_parent=5
""",
"test/nest/nest2/__init__.py": """from .test1 import *
t1=test_used_parent
""",
"test/nest/__init__.py": """from .nest2 import *""",
"test/__init__.py": """from .nest import *""",
"main.py": """
from test import test_const
main_test=test_const
""",
}
with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py")
main: SourceFile = codebase.get_file("main.py")
parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py")
test_nest: SourceFile = codebase.get_file("test/__init__.py")

main_test = main.get_symbol("main_test")
t1 = parent_file.get_symbol("t1")
test_const = deepest_layer.get_symbol("test_const")
test_not_used = deepest_layer.get_symbol("test_not_used")
test_used_parent = deepest_layer.get_symbol("test_used_parent")

test_const_imp = main.get_import("test_const")
test_const_imp_2 = test_nest.get_import(".nest")

assert len(test_const.usages) == 3
assert test_const.usages[0].usage_symbol == main_test
assert test_const.usages[1].usage_symbol == test_const_imp
assert test_const.usages[2].usage_symbol == test_const_imp_2

assert len(test_not_used.usages) == 0
assert len(test_used_parent.usages) == 1
assert test_used_parent.usages[0].usage_symbol == t1


def test_import_resolution_circular(tmpdir: str) -> None:
"""Tests function.usages returns usages from file imports"""
# language=python
Expand Down Expand Up @@ -343,7 +543,7 @@ def some_func():
assert len(some_func.symbol_usages) > 0


def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None:
def test_import_wildcard_preserves_import_resultion(tmpdir: str) -> None:
"""Tests importing from a file that contains a wildcard import doesn't break further resolution.
This could occur depending on to_resolve ordering, if the outer file is processed first _wildcards will not be filled in time.
"""
Expand All @@ -367,4 +567,4 @@ def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None:
) as codebase:
mainfile: SourceFile = codebase.get_file("file.py")

assert len(mainfile.ctx.edges) == 5
assert len(mainfile.ctx.edges) == 12
Loading