Skip to content

[CG-10888] fix: Wildcard resolution #612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/codegen/sdk/core/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
imports_file: bool = False # True when we import the entire file (e.g. `from a.b.c import foo`)


TSourceFile = TypeVar("TSourceFile", bound="SourceFile")

Check failure on line 57 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot redefine "TSourceFile" as a type variable [misc]

Check failure on line 57 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Invalid assignment target [misc]


@apidoc
Expand All @@ -76,7 +76,7 @@
module: Editable | None
symbol_name: Editable | None
alias: Editable | None
node_type: ClassVar[Literal[NodeType.IMPORT]] = NodeType.IMPORT

Check failure on line 79 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "Expression") with class variable [misc]
import_type: ImportType
import_statement: ImportStatement

Expand All @@ -96,7 +96,7 @@
self.module = self.ctx.parser.parse_expression(module_node, self.file_node_id, ctx, self, default=Name) if module_node else None
self.alias = self.ctx.parser.parse_expression(alias_node, self.file_node_id, ctx, self, default=Name) if alias_node else None
self.symbol_name = self.ctx.parser.parse_expression(name_node, self.file_node_id, ctx, self, default=Name) if name_node else None
self._name_node = self._parse_expression(name_node, default=Name)

Check failure on line 99 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible types in assignment (expression has type "Expression[Import[TSourceFile]] | None", variable has type "Name[Any] | ChainedAttribute[Any, Any, Any] | DefinedName[Any] | None") [assignment]
self.import_type = import_type

def __rich_repr__(self) -> rich.repr.Result:
Expand All @@ -109,7 +109,7 @@
yield "wildcard", self.is_wildcard_import(), False
yield from super().__rich_repr__()

__rich_repr__.angular = ANGULAR_STYLE

Check failure on line 112 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: "Callable[[Import[TSourceFile]], Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]]" has no attribute "angular" [attr-defined]

@noapidoc
@abstractmethod
Expand Down Expand Up @@ -140,7 +140,7 @@
# =====[ Case: Can resolve the filepath ]=====
elif resolution.symbol:
if resolution.symbol.node_id == self.node_id:
return [] # Circular to self

Check failure on line 143 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: No return value expected [return-value]
self.ctx.add_edge(
self.node_id,
resolution.symbol.node_id,
Expand All @@ -148,7 +148,7 @@
)

elif resolution.imports_file:
self.ctx.add_edge(self.node_id, resolution.from_file.node_id, type=EdgeType.IMPORT_SYMBOL_RESOLUTION)

Check failure on line 151 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "None" of "TSourceFile | None" has no attribute "node_id" [union-attr]
# for symbol in resolution.from_file.symbols:
# usage = SymbolUsage(parent_symbol_name=self.name, child_symbol_name=self.name, type=SymbolUsageType.IMPORTED, match=self, usage_type=UsageType.DIRECT)
# self.ctx.add_edge(self.node_id, symbol.node_id, type=EdgeType.SYMBOL_USAGE, usage=usage)
Expand All @@ -160,7 +160,7 @@
# - an indirect import of an external module
# TODO: add as external module only if it resolves to an external module from resolution.from_file
# Solution: return the resolution object to be processed in a separate loop in `compute_codebase_graph`
return []

Check failure on line 163 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: No return value expected [return-value]

@property
@reader
Expand Down Expand Up @@ -271,9 +271,9 @@
elif imported.node_type == NodeType.EXTERNAL:
return None
elif imported.__class__.__name__.endswith("SourceFile"): # TODO - this is a hack for when you import a full file/module
return imported

Check failure on line 274 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "Symbol[Any, Any] | TSourceFile | Import[Any]", expected "TSourceFile | None") [return-value]
else:
return imported.file

Check failure on line 276 in src/codegen/sdk/core/import_resolution.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "SourceFile[Any, Any, Any, Any, Any, Any]", expected "TSourceFile | None") [return-value]

@property
@reader
Expand Down Expand Up @@ -324,6 +324,7 @@
"""Returns the symbol directly being imported, including an indirect import and an External
Module.
"""
from codegen.sdk.python.file import PyFile
from codegen.sdk.typescript.file import TSFile

symbol = next(iter(self.ctx.successors(self.node_id, edge_type=EdgeType.IMPORT_SYMBOL_RESOLUTION, sort=False)), None)
Expand All @@ -341,6 +342,14 @@
if self.import_type == ImportType.NAMED_EXPORT:
if export := symbol.valid_import_names.get(name, None):
return export
elif resolve_exports and isinstance(symbol, PyFile):
name = self.symbol_name.source if self.symbol_name else ""
if self.import_type == ImportType.NAMED_EXPORT:
if symbol.name == name:
return symbol
if imp := symbol.valid_import_names.get(name, None):
return imp

if symbol is not self:
return symbol

Expand Down Expand Up @@ -632,6 +641,11 @@
# if used_frame.parent_frame:
# used_frame.parent_frame.add_usage(self.symbol_name or self.module, SymbolUsageType.IMPORTED_WILDCARD, self, self.ctx)
# else:
if isinstance(self, Import) and self.import_type == ImportType.NAMED_EXPORT:
# It could be a wildcard import downstream, hence we have to pop the cache
if file := self.from_file:
file.invalidate()

for used_frame in self.resolved_type_frames:
if used_frame.parent_frame:
used_frame.parent_frame.add_usage(self._unique_node, UsageKind.IMPORTED, self, self.ctx)
Expand Down
54 changes: 54 additions & 0 deletions src/codegen/sdk/python/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,57 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P
ret[file.name] = file
return ret
return super().valid_import_names

@noapidoc
def get_node_from_wildcard_chain(self, symbol_name: str) -> PySymbol | None:
"""Recursively searches for a symbol through wildcard import chains.

Attempts to find a symbol by name in the current file, and if not found, recursively searches
through any wildcard imports (from x import *) to find the symbol in imported modules.

Args:
symbol_name (str): The name of the symbol to search for.

Returns:
PySymbol | None: The found symbol if it exists in this file or any of its wildcard
imports, None otherwise.
"""
node = None
if node := self.get_node_by_name(symbol_name):
return node

if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}:
for wildcard_import in wildcard_imports:
if imp_resolution := wildcard_import.resolve_import():
node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name)

return node

@noapidoc
def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbol | None:
"""Finds the wildcard import that resolves a given symbol name.

Searches for a symbol by name, first in the current file, then through wildcard imports.
Unlike get_node_from_wildcard_chain, this returns the wildcard import that contains
the symbol rather than the symbol itself.

Args:
symbol_name (str): The name of the symbol to search for.

Returns:
PyImport | PySymbol | None:
- PySymbol if the symbol is found directly in this file
- PyImport if the symbol is found through a wildcard import
- None if the symbol cannot be found
"""
node = None
if node := self.get_node_by_name(symbol_name):
return node

if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}:
for wildcard_import in wildcard_imports:
if imp_resolution := wildcard_import.resolve_import():
if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name):
node = wildcard_import

return node
19 changes: 17 additions & 2 deletions src/codegen/sdk/python/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,28 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
filepath = os.path.join(base_path, filepath)
if file := self.ctx.get_file(filepath):
symbol = file.get_node_by_name(symbol_name)
return ImportResolution(from_file=file, symbol=symbol)
if symbol is None:
if file.get_node_from_wildcard_chain(symbol_name):
return ImportResolution(from_file=file, symbol=None, imports_file=True)
else:
# This is most likely a broken import
return ImportResolution(from_file=file, symbol=None)
else:
return ImportResolution(from_file=file, symbol=symbol)

# =====[ Check if `module/__init__.py` file exists in the graph ]=====
filepath = filepath.replace(".py", "/__init__.py")
if from_file := self.ctx.get_file(filepath):
symbol = from_file.get_node_by_name(symbol_name)
return ImportResolution(from_file=from_file, symbol=symbol)
if symbol is None:
if from_file.get_node_from_wildcard_chain(symbol_name):
return ImportResolution(from_file=from_file, symbol=None, imports_file=True)
else:
# This is most likely a broken import
return ImportResolution(from_file=from_file, symbol=None)

else:
return ImportResolution(from_file=from_file, symbol=symbol)

# =====[ Case: Can't resolve the import ]=====
if base_path == "":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,204 @@ def func_1():
assert call_site.file == consumer_file


def test_import_resolution_init_wildcard(tmpdir: str) -> None:
"""Tests that named import from a file with wildcard resolves properly"""
# language=python
content1 = """TEST_CONST=2
foo=9
"""
content2 = """from testdir.test1 import *
bar=foo
test=TEST_CONST"""
content3 = """from testdir import TEST_CONST
test3=TEST_CONST"""
with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase:
file1: SourceFile = codebase.get_file("testdir/test1.py")
file2: SourceFile = codebase.get_file("testdir/__init__.py")
file3: SourceFile = codebase.get_file("test3.py")

symb = file1.get_symbol("TEST_CONST")
test = file2.get_symbol("test")
test3 = file3.get_symbol("test3")
test3_import = file3.get_import("TEST_CONST")

assert len(symb.usages) == 3
assert symb.symbol_usages == [test, test3, test3_import]


def test_import_resolution_wildcard_func(tmpdir: str) -> None:
"""Tests that named import from a file with wildcard resolves properly"""
# language=python
content1 = """
def foo():
pass
def bar():
pass
"""
content2 = """
from testa import *

foo()
"""

with get_codebase_session(tmpdir=tmpdir, files={"testa.py": content1, "testb.py": content2}) as codebase:
testa: SourceFile = codebase.get_file("testa.py")
testb: SourceFile = codebase.get_file("testb.py")

foo = testa.get_symbol("foo")
bar = testa.get_symbol("bar")
assert len(foo.usages) == 1
assert len(foo.call_sites) == 1

assert len(bar.usages) == 0
assert len(bar.call_sites) == 0
assert len(testb.function_calls) == 1


def test_import_resolution_chaining_wildcards(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python
content1 = """TEST_CONST=2
foo=9
"""
content2 = """from testdir.test1 import *
bar=foo
test=TEST_CONST"""
content3 = """from testdir import *
test3=TEST_CONST"""
with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase:
file1: SourceFile = codebase.get_file("testdir/test1.py")
file2: SourceFile = codebase.get_file("testdir/__init__.py")
file3: SourceFile = codebase.get_file("test3.py")

symb = file1.get_symbol("TEST_CONST")
test = file2.get_symbol("test")
bar = file2.get_symbol("bar")
mid_import = file2.get_import("testdir.test1")
test3 = file3.get_symbol("test3")

assert len(symb.usages) == 2
assert symb.symbol_usages == [test, test3]
assert mid_import.symbol_usages == [test, bar, test3]


def test_import_resolution_init_deep_nested_wildcards(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python

files = {
"test/nest/nest2/test1.py": """test_const=5
test_not_used=2
test_used_parent=5
""",
"test/nest/nest2/__init__.py": """from .test1 import *
t1=test_used_parent
""",
"test/nest/__init__.py": """from .nest2 import *""",
"test/__init__.py": """from .nest import *""",
"main.py": """
from test import *
main_test=test_const
""",
}
with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py")
main: SourceFile = codebase.get_file("main.py")
parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py")

main_test = main.get_symbol("main_test")
t1 = parent_file.get_symbol("t1")
test_const = deepest_layer.get_symbol("test_const")
test_not_used = deepest_layer.get_symbol("test_not_used")
test_used_parent = deepest_layer.get_symbol("test_used_parent")

assert len(test_const.usages) == 1
assert test_const.usages[0].usage_symbol == main_test
assert len(test_not_used.usages) == 0
assert len(test_used_parent.usages) == 1
assert test_used_parent.usages[0].usage_symbol == t1


def test_import_resolution_chaining_many_wildcards(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python

files = {
"test1.py": """
test_const=5
test_not_used=2
test_used_parent=5
""",
"test2.py": """from test1 import *
t1=test_used_parent
""",
"test3.py": """from test2 import *""",
"test4.py": """from test3 import *""",
"main.py": """
from test4 import *
main_test=test_const
""",
}
with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
furthest_layer: SourceFile = codebase.get_file("test1.py")
main: SourceFile = codebase.get_file("main.py")
parent_file: SourceFile = codebase.get_file("test2.py")

main_test = main.get_symbol("main_test")
t1 = parent_file.get_symbol("t1")
test_const = furthest_layer.get_symbol("test_const")
test_not_used = furthest_layer.get_symbol("test_not_used")
test_used_parent = furthest_layer.get_symbol("test_used_parent")

assert len(test_const.usages) == 1
assert test_const.usages[0].usage_symbol == main_test
assert len(test_not_used.usages) == 0
assert len(test_used_parent.usages) == 1
assert test_used_parent.usages[0].usage_symbol == t1


def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None:
"""Tests that chaining wildcard imports resolves properly"""
# language=python

files = {
"test/nest/nest2/test1.py": """test_const=5
test_not_used=2
test_used_parent=5
""",
"test/nest/nest2/__init__.py": """from .test1 import *
t1=test_used_parent
""",
"test/nest/__init__.py": """from .nest2 import *""",
"test/__init__.py": """from .nest import *""",
"main.py": """
from test import test_const
main_test=test_const
""",
}
with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py")
main: SourceFile = codebase.get_file("main.py")
parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py")
test_nest: SourceFile = codebase.get_file("test/__init__.py")

main_test = main.get_symbol("main_test")
t1 = parent_file.get_symbol("t1")
test_const = deepest_layer.get_symbol("test_const")
test_not_used = deepest_layer.get_symbol("test_not_used")
test_used_parent = deepest_layer.get_symbol("test_used_parent")

test_const_imp = main.get_import("test_const")

assert len(test_const.usages) == 2
assert test_const.usages[0].usage_symbol == main_test
assert test_const.usages[1].usage_symbol == test_const_imp

assert len(test_not_used.usages) == 0
assert len(test_used_parent.usages) == 1
assert test_used_parent.usages[0].usage_symbol == t1


def test_import_resolution_circular(tmpdir: str) -> None:
"""Tests function.usages returns usages from file imports"""
# language=python
Expand Down Expand Up @@ -367,4 +565,66 @@ def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None:
) as codebase:
mainfile: SourceFile = codebase.get_file("file.py")

assert len(mainfile.ctx.edges) == 5
assert len(mainfile.ctx.edges) == 10


def test_import_resolution_init_wildcard_no_dupe(tmpdir: str) -> None:
"""Tests that named import from a file with wildcard resolves properly and doesn't
result in duplicate usages
"""
# language=python
content1 = """TEST_CONST=2
foo=9
"""
content2 = """from testdir.test1 import *
bar=foo
test=TEST_CONST"""
content3 = """from testdir import TEST_CONST
test3=TEST_CONST"""
content4 = """from testdir import foo
test4=foo"""
with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3, "test4.py": content4}) as codebase:
file1: SourceFile = codebase.get_file("testdir/test1.py")
file2: SourceFile = codebase.get_file("testdir/__init__.py")
file3: SourceFile = codebase.get_file("test3.py")

symb = file1.get_symbol("TEST_CONST")
test = file2.get_symbol("test")
test3 = file3.get_symbol("test3")
test3_import = file3.get_import("TEST_CONST")

assert len(symb.usages) == 3
assert symb.symbol_usages == [test, test3, test3_import]


def test_import_resolution_init_wildcard_chainging_deep(tmpdir: str) -> None:
"""Tests that named import from a file with wildcard resolves properly and doesn't
result in duplicate usages
"""
# language=python
content1 = """TEST_CONST=2
"""
content2 = """from .file1 import *"""
content3 = """from .dir import *"""
content4 = """from .dir import TEST_CONST
test1=TEST_CONST"""
with get_codebase_session(
tmpdir=tmpdir,
files={
"dir/dir/dir/dir/file1.py": content1,
"dir/dir/dir/dir/__init__.py": content2,
"dir/dir/dir/__init__.py": content3,
"dir/dir/__init__.py": content3,
"dir/__init__.py": content3,
"file2.py": content4,
},
) as codebase:
file1: SourceFile = codebase.get_file("dir/dir/dir/dir/file1.py")
file2: SourceFile = codebase.get_file("file2.py")

symb = file1.get_symbol("TEST_CONST")
test1 = file2.get_symbol("test1")
imp = file2.get_import("TEST_CONST")

assert len(symb.usages) == 2
assert symb.symbol_usages == [test1, imp]