Skip to content

Commit 997e36f

Browse files
tomcodgen, frainfreeze, and tkfoss
authored
[CG-10888] fix: Wildcard resolution (#612)
Co-authored-by: tomcodegen <[email protected]> Co-authored-by: tomcodgen <[email protected]>
1 parent 5fab64a commit 997e36f

File tree

4 files changed

+346
-3
lines changed

4 files changed

+346
-3
lines changed

src/codegen/sdk/core/import_resolution.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalMo
324324
"""Returns the symbol directly being imported, including an indirect import and an External
325325
Module.
326326
"""
327+
from codegen.sdk.python.file import PyFile
327328
from codegen.sdk.typescript.file import TSFile
328329

329330
symbol = next(iter(self.ctx.successors(self.node_id, edge_type=EdgeType.IMPORT_SYMBOL_RESOLUTION, sort=False)), None)
@@ -341,6 +342,14 @@ def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalMo
341342
if self.import_type == ImportType.NAMED_EXPORT:
342343
if export := symbol.valid_import_names.get(name, None):
343344
return export
345+
elif resolve_exports and isinstance(symbol, PyFile):
346+
name = self.symbol_name.source if self.symbol_name else ""
347+
if self.import_type == ImportType.NAMED_EXPORT:
348+
if symbol.name == name:
349+
return symbol
350+
if imp := symbol.valid_import_names.get(name, None):
351+
return imp
352+
344353
if symbol is not self:
345354
return symbol
346355

@@ -632,6 +641,11 @@ def _compute_dependencies(self, *args, **kwargs) -> None:
632641
# if used_frame.parent_frame:
633642
# used_frame.parent_frame.add_usage(self.symbol_name or self.module, SymbolUsageType.IMPORTED_WILDCARD, self, self.ctx)
634643
# else:
644+
if isinstance(self, Import) and self.import_type == ImportType.NAMED_EXPORT:
645+
# It could be a wildcard import downstream, hence we have to pop the cache
646+
if file := self.from_file:
647+
file.invalidate()
648+
635649
for used_frame in self.resolved_type_frames:
636650
if used_frame.parent_frame:
637651
used_frame.parent_frame.add_usage(self._unique_node, UsageKind.IMPORTED, self, self.ctx)

src/codegen/sdk/python/file.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,57 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P
197197
ret[file.name] = file
198198
return ret
199199
return super().valid_import_names
200+
201+
@noapidoc
def get_node_from_wildcard_chain(self, symbol_name: str, _visited: set[int] | None = None) -> PySymbol | None:
    """Recursively searches for a symbol through wildcard import chains.

    Looks the name up in the current file first. If it is not defined here, each wildcard
    import (``from x import *``) in this file is resolved and the search is repeated in the
    file it points to. The first hit is returned immediately.

    Args:
        symbol_name (str): The name of the symbol to search for.
        _visited (set[int] | None): Internal recursion bookkeeping holding ``id()`` values of
            files already searched. Callers should leave this unset.

    Returns:
        PySymbol | None: The found symbol if it exists in this file or any file reachable
            through its wildcard imports, None otherwise.
    """
    if node := self.get_node_by_name(symbol_name):
        return node

    # Remember every file we have entered so that circular `import *` chains
    # (a -> b -> a) terminate instead of recursing without bound.
    if _visited is None:
        _visited = set()
    _visited.add(id(self))

    # A list (not a set) keeps the search order stable between runs.
    wildcard_imports = [imp for imp in self.imports if imp.is_wildcard_import()]

    # In Python a later `from x import *` rebinds names bound by an earlier one, so the
    # last provider wins. NOTE(review): assumes self.imports yields in source order - confirm.
    for wildcard_import in reversed(wildcard_imports):
        imp_resolution = wildcard_import.resolve_import()
        if not imp_resolution:
            continue
        from_file = imp_resolution.from_file
        # Skip unresolved targets and files that were already searched without a hit.
        if from_file is None or id(from_file) in _visited:
            continue
        # Return as soon as one chain yields the symbol; a miss in another
        # wildcard import must not overwrite an earlier hit.
        if node := from_file.get_node_from_wildcard_chain(symbol_name=symbol_name, _visited=_visited):
            return node

    return None
225+
226+
@noapidoc
def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbol | None:
    """Finds the wildcard import that resolves a given symbol name.

    Searches for a symbol by name, first in the current file, then through wildcard imports.
    Unlike get_node_from_wildcard_chain, this returns the wildcard import of *this* file
    through which the symbol is reachable, rather than the symbol itself.

    Args:
        symbol_name (str): The name of the symbol to search for.

    Returns:
        PyImport | PySymbol | None:
            - PySymbol if the symbol is found directly in this file
            - PyImport if the symbol is found through a wildcard import
            - None if the symbol cannot be found
    """
    if node := self.get_node_by_name(symbol_name):
        return node

    # A list (not a set) keeps the choice deterministic when several wildcard
    # imports can provide the same name.
    wildcard_imports = [imp for imp in self.imports if imp.is_wildcard_import()]

    # A later `from x import *` rebinds names bound by an earlier one, so check the last
    # import first. NOTE(review): assumes self.imports yields in source order - confirm.
    for wildcard_import in reversed(wildcard_imports):
        imp_resolution = wildcard_import.resolve_import()
        # Nothing to search when the import did not resolve to a file.
        if not imp_resolution or imp_resolution.from_file is None:
            continue
        if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name):
            return wildcard_import

    return None

src/codegen/sdk/python/import_resolution.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,28 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
118118
filepath = os.path.join(base_path, filepath)
119119
if file := self.ctx.get_file(filepath):
120120
symbol = file.get_node_by_name(symbol_name)
121-
return ImportResolution(from_file=file, symbol=symbol)
121+
if symbol is None:
122+
if file.get_node_from_wildcard_chain(symbol_name):
123+
return ImportResolution(from_file=file, symbol=None, imports_file=True)
124+
else:
125+
# This is most likely a broken import
126+
return ImportResolution(from_file=file, symbol=None)
127+
else:
128+
return ImportResolution(from_file=file, symbol=symbol)
122129

123130
# =====[ Check if `module/__init__.py` file exists in the graph ]=====
124131
filepath = filepath.replace(".py", "/__init__.py")
125132
if from_file := self.ctx.get_file(filepath):
126133
symbol = from_file.get_node_by_name(symbol_name)
127-
return ImportResolution(from_file=from_file, symbol=symbol)
134+
if symbol is None:
135+
if from_file.get_node_from_wildcard_chain(symbol_name):
136+
return ImportResolution(from_file=from_file, symbol=None, imports_file=True)
137+
else:
138+
# This is most likely a broken import
139+
return ImportResolution(from_file=from_file, symbol=None)
140+
141+
else:
142+
return ImportResolution(from_file=from_file, symbol=symbol)
128143

129144
# =====[ Case: Can't resolve the import ]=====
130145
if base_path == "":

tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py

Lines changed: 261 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,204 @@ def func_1():
215215
assert call_site.file == consumer_file
216216

217217

218+
def test_import_resolution_init_wildcard(tmpdir: str) -> None:
    """A named import from a package resolves through the package's wildcard import.

    `test3.py` imports `TEST_CONST` by name from `testdir`, whose `__init__.py` only
    exposes it via `from testdir.test1 import *`. The constant defined in
    `testdir/test1.py` must list the usage in `__init__.py`, the usage in `test3.py`
    and the named import itself.
    """
    # language=python
    files = {
        "testdir/test1.py": """TEST_CONST=2
foo=9
""",
        "testdir/__init__.py": """from testdir.test1 import *
bar=foo
test=TEST_CONST""",
        "test3.py": """from testdir import TEST_CONST
test3=TEST_CONST""",
    }
    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        origin_file: SourceFile = codebase.get_file("testdir/test1.py")
        package_init: SourceFile = codebase.get_file("testdir/__init__.py")
        consumer: SourceFile = codebase.get_file("test3.py")

        const_symbol = origin_file.get_symbol("TEST_CONST")
        init_usage = package_init.get_symbol("test")
        consumer_usage = consumer.get_symbol("test3")
        named_import = consumer.get_import("TEST_CONST")

        assert len(const_symbol.usages) == 3
        assert const_symbol.symbol_usages == [init_usage, consumer_usage, named_import]
241+
242+
243+
def test_import_resolution_wildcard_func(tmpdir: str) -> None:
    """A function called through a wildcard import is recorded as used; an uncalled one is not.

    `testb.py` does `from testa import *` and calls only `foo()`, so `foo` must have one
    usage and one call site while `bar` has none.
    """
    # language=python
    definitions = """
def foo():
    pass
def bar():
    pass
"""
    # language=python
    caller = """
from testa import *

foo()
"""
    files = {"testa.py": definitions, "testb.py": caller}

    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        defining_file: SourceFile = codebase.get_file("testa.py")
        calling_file: SourceFile = codebase.get_file("testb.py")

        called_func = defining_file.get_symbol("foo")
        idle_func = defining_file.get_symbol("bar")

        # The function invoked in testb.py is tracked exactly once.
        assert len(called_func.usages) == 1
        assert len(called_func.call_sites) == 1

        # The function that is only star-imported, never referenced, stays unused.
        assert len(idle_func.usages) == 0
        assert len(idle_func.call_sites) == 0
        assert len(calling_file.function_calls) == 1
270+
271+
272+
def test_import_resolution_chaining_wildcards(tmpdir: str) -> None:
    """Symbols stay resolvable across two chained wildcard imports.

    `test3.py` star-imports `testdir`, whose `__init__.py` star-imports `testdir.test1`.
    Checks the usages recorded on the original constant and on the intermediate
    wildcard import in `__init__.py`.
    """
    # language=python
    files = {
        "testdir/test1.py": """TEST_CONST=2
foo=9
""",
        "testdir/__init__.py": """from testdir.test1 import *
bar=foo
test=TEST_CONST""",
        "test3.py": """from testdir import *
test3=TEST_CONST""",
    }
    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        origin_file: SourceFile = codebase.get_file("testdir/test1.py")
        package_init: SourceFile = codebase.get_file("testdir/__init__.py")
        consumer: SourceFile = codebase.get_file("test3.py")

        const_symbol = origin_file.get_symbol("TEST_CONST")
        init_test = package_init.get_symbol("test")
        init_bar = package_init.get_symbol("bar")
        star_import = package_init.get_import("testdir.test1")
        consumer_test = consumer.get_symbol("test3")

        assert len(const_symbol.usages) == 2
        assert const_symbol.symbol_usages == [init_test, consumer_test]
        assert star_import.symbol_usages == [init_test, init_bar, consumer_test]
297+
298+
299+
def test_import_resolution_init_deep_nested_wildcards(tmpdir: str) -> None:
    """Wildcard imports chained through nested package `__init__.py` files resolve to the leaf module.

    `main.py` star-imports `test`, which re-exports through `test/nest` and
    `test/nest/nest2` down to `test/nest/nest2/test1.py`.
    """
    # language=python
    leaf_module = """test_const=5
test_not_used=2
test_used_parent=5
"""
    # language=python
    leaf_package_init = """from .test1 import *
t1=test_used_parent
"""
    # language=python
    entry_point = """
from test import *
main_test=test_const
"""
    files = {
        "test/nest/nest2/test1.py": leaf_module,
        "test/nest/nest2/__init__.py": leaf_package_init,
        "test/nest/__init__.py": """from .nest2 import *""",
        "test/__init__.py": """from .nest import *""",
        "main.py": entry_point,
    }
    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        leaf_file: SourceFile = codebase.get_file("test/nest/nest2/test1.py")
        main_file: SourceFile = codebase.get_file("main.py")
        leaf_init: SourceFile = codebase.get_file("test/nest/nest2/__init__.py")

        main_usage = main_file.get_symbol("main_test")
        init_usage = leaf_init.get_symbol("t1")
        used_from_main = leaf_file.get_symbol("test_const")
        never_used = leaf_file.get_symbol("test_not_used")
        used_from_init = leaf_file.get_symbol("test_used_parent")

        # Referenced only from main.py, four wildcard hops away.
        assert len(used_from_main.usages) == 1
        assert used_from_main.usages[0].usage_symbol == main_usage
        # Star-imported everywhere but referenced nowhere.
        assert len(never_used.usages) == 0
        # Referenced only from the directly enclosing package __init__.py.
        assert len(used_from_init.usages) == 1
        assert used_from_init.usages[0].usage_symbol == init_usage
334+
335+
336+
def test_import_resolution_chaining_many_wildcards(tmpdir: str) -> None:
    """Symbols resolve through a long chain of sibling modules linked by wildcard imports.

    The chain is main.py -> test4.py -> test3.py -> test2.py -> test1.py, each step
    being a `from <module> import *`.
    """
    # language=python
    origin_module = """
test_const=5
test_not_used=2
test_used_parent=5
"""
    # language=python
    first_hop = """from test1 import *
t1=test_used_parent
"""
    # language=python
    entry_point = """
from test4 import *
main_test=test_const
"""
    files = {
        "test1.py": origin_module,
        "test2.py": first_hop,
        "test3.py": """from test2 import *""",
        "test4.py": """from test3 import *""",
        "main.py": entry_point,
    }
    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        origin_file: SourceFile = codebase.get_file("test1.py")
        main_file: SourceFile = codebase.get_file("main.py")
        first_hop_file: SourceFile = codebase.get_file("test2.py")

        main_usage = main_file.get_symbol("main_test")
        hop_usage = first_hop_file.get_symbol("t1")
        used_from_main = origin_file.get_symbol("test_const")
        never_used = origin_file.get_symbol("test_not_used")
        used_from_hop = origin_file.get_symbol("test_used_parent")

        # Referenced only at the far end of the chain.
        assert len(used_from_main.usages) == 1
        assert used_from_main.usages[0].usage_symbol == main_usage
        # Star-imported along the whole chain but referenced nowhere.
        assert len(never_used.usages) == 0
        # Referenced only by the module that star-imports test1 directly.
        assert len(used_from_hop.usages) == 1
        assert used_from_hop.usages[0].usage_symbol == hop_usage
372+
373+
374+
def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None:
    """A named import at the top of a nested wildcard chain resolves to the leaf symbol.

    Same package layout as the deep-nested wildcard test, except `main.py` uses
    `from test import test_const` instead of a star import, so the named import
    itself is expected to show up as a second usage of `test_const`.
    """
    # language=python
    leaf_module = """test_const=5
test_not_used=2
test_used_parent=5
"""
    # language=python
    leaf_package_init = """from .test1 import *
t1=test_used_parent
"""
    # language=python
    entry_point = """
from test import test_const
main_test=test_const
"""
    files = {
        "test/nest/nest2/test1.py": leaf_module,
        "test/nest/nest2/__init__.py": leaf_package_init,
        "test/nest/__init__.py": """from .nest2 import *""",
        "test/__init__.py": """from .nest import *""",
        "main.py": entry_point,
    }
    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        leaf_file: SourceFile = codebase.get_file("test/nest/nest2/test1.py")
        main_file: SourceFile = codebase.get_file("main.py")
        leaf_init: SourceFile = codebase.get_file("test/nest/nest2/__init__.py")
        # Looked up as in the original test; not referenced by any assertion below.
        top_package_init: SourceFile = codebase.get_file("test/__init__.py")

        main_usage = main_file.get_symbol("main_test")
        init_usage = leaf_init.get_symbol("t1")
        used_from_main = leaf_file.get_symbol("test_const")
        never_used = leaf_file.get_symbol("test_not_used")
        used_from_init = leaf_file.get_symbol("test_used_parent")

        named_import = main_file.get_import("test_const")

        # Both the assignment in main.py and the named import count as usages.
        assert len(used_from_main.usages) == 2
        assert used_from_main.usages[0].usage_symbol == main_usage
        assert used_from_main.usages[1].usage_symbol == named_import

        assert len(never_used.usages) == 0
        assert len(used_from_init.usages) == 1
        assert used_from_init.usages[0].usage_symbol == init_usage
414+
415+
218416
def test_import_resolution_circular(tmpdir: str) -> None:
219417
"""Tests function.usages returns usages from file imports"""
220418
# language=python
@@ -367,4 +565,66 @@ def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None:
367565
) as codebase:
368566
mainfile: SourceFile = codebase.get_file("file.py")
369567

370-
assert len(mainfile.ctx.edges) == 5
568+
assert len(mainfile.ctx.edges) == 10
569+
570+
571+
def test_import_resolution_init_wildcard_no_dupe(tmpdir: str) -> None:
    """Named imports through a wildcard-re-exporting package do not duplicate usages.

    Two consumers import different names (`TEST_CONST` and `foo`) from `testdir`, whose
    `__init__.py` re-exports both via `from testdir.test1 import *`. `TEST_CONST` must
    still have exactly three usages.
    """
    # language=python
    files = {
        "testdir/test1.py": """TEST_CONST=2
foo=9
""",
        "testdir/__init__.py": """from testdir.test1 import *
bar=foo
test=TEST_CONST""",
        "test3.py": """from testdir import TEST_CONST
test3=TEST_CONST""",
        "test4.py": """from testdir import foo
test4=foo""",
    }
    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        origin_file: SourceFile = codebase.get_file("testdir/test1.py")
        package_init: SourceFile = codebase.get_file("testdir/__init__.py")
        const_consumer: SourceFile = codebase.get_file("test3.py")

        const_symbol = origin_file.get_symbol("TEST_CONST")
        init_usage = package_init.get_symbol("test")
        consumer_usage = const_consumer.get_symbol("test3")
        named_import = const_consumer.get_import("TEST_CONST")

        # Only TEST_CONST is asserted on; test4.py's import of `foo` is present
        # in the fixture to check that it adds no extra usages here.
        assert len(const_symbol.usages) == 3
        assert const_symbol.symbol_usages == [init_usage, consumer_usage, named_import]
598+
599+
600+
def test_import_resolution_init_wildcard_chainging_deep(tmpdir: str) -> None:
    """A named import resolves through four nested `dir` packages chained by wildcard imports.

    Each `dir/.../__init__.py` star-imports the next level down until
    `dir/dir/dir/dir/file1.py`, which defines `TEST_CONST`. `file2.py` imports the
    constant by name from the outermost package; the constant must have exactly two
    usages, without duplicates.
    """
    # language=python
    constant_module = """TEST_CONST=2
"""
    reexport_module = """from .file1 import *"""
    reexport_package = """from .dir import *"""
    # language=python
    consumer_module = """from .dir import TEST_CONST
test1=TEST_CONST"""

    files = {
        "dir/dir/dir/dir/file1.py": constant_module,
        "dir/dir/dir/dir/__init__.py": reexport_module,
        "dir/dir/dir/__init__.py": reexport_package,
        "dir/dir/__init__.py": reexport_package,
        "dir/__init__.py": reexport_package,
        "file2.py": consumer_module,
    }
    with get_codebase_session(tmpdir=tmpdir, files=files) as codebase:
        origin_file: SourceFile = codebase.get_file("dir/dir/dir/dir/file1.py")
        consumer: SourceFile = codebase.get_file("file2.py")

        const_symbol = origin_file.get_symbol("TEST_CONST")
        consumer_usage = consumer.get_symbol("test1")
        named_import = consumer.get_import("TEST_CONST")

        assert len(const_symbol.usages) == 2
        assert const_symbol.symbol_usages == [consumer_usage, named_import]

0 commit comments

Comments
 (0)