Skip to content

Commit 66ad2c6

Browse files
spezifantMatthias BarteltMatthias Bartelttomcodgenchristinewangcw
authored
feat: Allow custom overrides for package resolving and optional sys.path support (#601)
# Motivation Certain build systems or configurations have Python packages located in a non-default path. Those are not found by _Codegen_ per default. # Content _Codegen_ now considers the `PYTHONPATH` environment variable when resolving packages. # Testing Added unit tests which resolve packages only when the `PYTHONPATH` is set in a correct way. # Please check the following before marking your PR as ready for review - [x] I have added tests for my changes - [x] I have updated the documentation or added new documentation as needed --------- Co-authored-by: Matthias Bartelt <[email protected]> Co-authored-by: Matthias Bartelt <[email protected]> Co-authored-by: tomcodgen <[email protected]> Co-authored-by: Christine Wang <[email protected]>
1 parent ac86411 commit 66ad2c6

File tree

7 files changed

+282
-2
lines changed

7 files changed

+282
-2
lines changed

CONTRIBUTING.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ uv run pytest tests/unit -n auto
5151
uv run pytest tests/integration/codemod/test_codemods.py -n auto
5252
```
5353

54+
> [!TIP]
55+
>
56+
> - If on Linux the error `OSError: [Errno 24] Too many open files` appears then you might want to increase your _ulimit_
57+
5458
## Pull Request Process
5559

5660
1. Fork the repository and create your branch from `develop`.

docs/building-with-codegen/imports.mdx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ print(f"From file: {import_stmt.from_file.filepath}")
6969
print(f"To file: {import_stmt.to_file.filepath}")
7070
```
7171

72+
<Note>
73+
With Python one can specify the `PYTHONPATH` environment variable which is then considered when resolving
74+
packages.
75+
</Note>
76+
7277
## Working with External Modules
7378

7479
You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so:

src/codegen/cli/mcp/resources/system_prompt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2858,6 +2858,11 @@ def validate_data(data: dict) -> bool:
28582858
print(f"To file: {import_stmt.to_file.filepath}")
28592859
```
28602860
2861+
<Note>
2862+
With Python one can specify the `PYTHONPATH` environment variable which is then considered when resolving
2863+
packages.
2864+
</Note>
2865+
28612866
## Working with External Modules
28622867
28632868
You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so:

src/codegen/configs/models/codebase.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ def __init__(self, prefix: str = "CODEBASE", *args, **kwargs) -> None:
1717
disable_graph: bool = False
1818
disable_file_parse: bool = False
1919
generics: bool = True
20+
import_resolution_paths: list[str] = Field(default_factory=lambda: [])
2021
import_resolution_overrides: dict[str, str] = Field(default_factory=lambda: {})
22+
py_resolve_syspath: bool = False
2123
ts_dependency_manager: bool = False
2224
ts_language_engine: bool = False
2325
v8_ts_engine: bool = False

src/codegen/sdk/python/import_resolution.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
import sys
45
from typing import TYPE_CHECKING
56

67
from codegen.sdk.core.autocommit import reader
@@ -104,6 +105,15 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
104105
base_path,
105106
module_source.replace(".", "/") + "/" + symbol_name + ".py",
106107
)
108+
109+
# =====[ Check if we are importing an entire file with custom resolve path or sys.path enabled ]=====
110+
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
111+
# Handle resolve overrides first if both is set
112+
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
113+
if file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
114+
return ImportResolution(from_file=file, symbol=None, imports_file=True)
115+
116+
# =====[ Default path ]=====
107117
if file := self.ctx.get_file(filepath):
108118
return ImportResolution(from_file=file, symbol=None, imports_file=True)
109119

@@ -113,8 +123,16 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
113123
# You can't do `from a.b.c import foo` => `foo.utils.x` right now since `foo` is just a file...
114124
return ImportResolution(from_file=file, symbol=None, imports_file=True)
115125

116-
# =====[ Check if `module.py` file exists in the graph ]=====
126+
# =====[ Check if `module.py` file exists in the graph with custom resolve path or sys.path enabled ]=====
117127
filepath = module_source.replace(".", "/") + ".py"
128+
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
129+
# Handle resolve overrides first if both is set
130+
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
131+
if file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
132+
symbol = file.get_node_by_name(symbol_name)
133+
return ImportResolution(from_file=file, symbol=symbol)
134+
135+
# =====[ Check if `module.py` file exists in the graph ]=====
118136
filepath = os.path.join(base_path, filepath)
119137
if file := self.ctx.get_file(filepath):
120138
symbol = file.get_node_by_name(symbol_name)
@@ -163,6 +181,20 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
163181
# ext = ExternalModule.from_import(self)
164182
# return ImportResolution(symbol=ext)
165183

184+
@noapidoc
185+
@reader
186+
def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str) -> SourceFile | None:
187+
"""Check if a certain file import can be found within a set sys.path
188+
189+
Returns either None or the SourceFile.
190+
"""
191+
for resolve_path in resolve_paths:
192+
filepath_new: str = os.path.join(resolve_path, filepath)
193+
if file := self.ctx.get_file(filepath_new):
194+
return file
195+
196+
return None
197+
166198
@noapidoc
167199
@reader
168200
def _relative_to_absolute_import(self, relative_import: str) -> str:

src/codegen/sdk/system-prompt.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2879,6 +2879,11 @@ print(f"From file: {import_stmt.from_file.filepath}")
28792879
print(f"To file: {import_stmt.to_file.filepath}")
28802880
```
28812881

2882+
<Note>
2883+
With Python one can specify the `PYTHONPATH` environment variable which is then considered when resolving
2884+
packages.
2885+
</Note>
2886+
28822887
## Working with External Modules
28832888

28842889
You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so:

tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py

Lines changed: 228 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from typing import TYPE_CHECKING
23

34
from codegen.sdk.codebase.factory.get_session import get_codebase_session
@@ -191,7 +192,7 @@ def update():
191192
"consumer.py": """
192193
from a.b.c import src as operations
193194
194-
def func_1():
195+
def func():
195196
operations.update()
196197
""",
197198
},
@@ -215,6 +216,232 @@ def func_1():
215216
assert call_site.file == consumer_file
216217

217218

219+
def test_import_resolution_file_syspath_inactive(tmpdir: str, monkeypatch) -> None:
220+
"""Tests function.usages returns usages from file imports"""
221+
# language=python
222+
with get_codebase_session(
223+
tmpdir,
224+
files={
225+
"a/b/c/src.py": """
226+
def update():
227+
pass
228+
""",
229+
"consumer.py": """
230+
from b.c import src as operations
231+
232+
def func():
233+
operations.update()
234+
""",
235+
},
236+
) as codebase:
237+
src_file: SourceFile = codebase.get_file("a/b/c/src.py")
238+
consumer_file: SourceFile = codebase.get_file("consumer.py")
239+
240+
# Disable resolution via sys.path
241+
codebase.ctx.config.py_resolve_syspath = False
242+
243+
# =====[ Imports cannot be found without sys.path being set and not active ]=====
244+
assert len(consumer_file.imports) == 1
245+
src_import: Import = consumer_file.imports[0]
246+
src_import_resolution: ImportResolution = src_import.resolve_import()
247+
assert src_import_resolution is None
248+
249+
# Modify sys.path for this test only
250+
monkeypatch.syspath_prepend("a")
251+
252+
# =====[ Imports cannot be found with sys.path set but not active ]=====
253+
src_import_resolution = src_import.resolve_import()
254+
assert src_import_resolution is None
255+
256+
257+
def test_import_resolution_file_syspath_active(tmpdir: str, monkeypatch) -> None:
258+
"""Tests function.usages returns usages from file imports"""
259+
# language=python
260+
with get_codebase_session(
261+
tmpdir,
262+
files={
263+
"a/b/c/src.py": """
264+
def update():
265+
pass
266+
""",
267+
"consumer.py": """
268+
from b.c import src as operations
269+
270+
def func():
271+
operations.update()
272+
""",
273+
},
274+
) as codebase:
275+
src_file: SourceFile = codebase.get_file("a/b/c/src.py")
276+
consumer_file: SourceFile = codebase.get_file("consumer.py")
277+
278+
# Enable resolution via sys.path
279+
codebase.ctx.config.py_resolve_syspath = True
280+
281+
# =====[ Imports cannot be found without sys.path being set ]=====
282+
assert len(consumer_file.imports) == 1
283+
src_import: Import = consumer_file.imports[0]
284+
src_import_resolution: ImportResolution = src_import.resolve_import()
285+
assert src_import_resolution is None
286+
287+
# Modify sys.path for this test only
288+
monkeypatch.syspath_prepend("a")
289+
290+
# =====[ Imports can be found with sys.path set and active ]=====
291+
codebase.ctx.config.py_resolve_syspath = True
292+
src_import_resolution = src_import.resolve_import()
293+
assert src_import_resolution
294+
assert src_import_resolution.from_file is src_file
295+
assert src_import_resolution.imports_file is True
296+
297+
298+
def test_import_resolution_file_custom_resolve_path(tmpdir: str) -> None:
299+
"""Tests function.usages returns usages from file imports"""
300+
# language=python
301+
with get_codebase_session(
302+
tmpdir,
303+
files={
304+
"a/b/c/src.py": """
305+
def update():
306+
pass
307+
""",
308+
"consumer.py": """
309+
from b.c import src as operations
310+
from c import src as operations2
311+
312+
def func():
313+
operations.update()
314+
""",
315+
},
316+
) as codebase:
317+
src_file: SourceFile = codebase.get_file("a/b/c/src.py")
318+
consumer_file: SourceFile = codebase.get_file("consumer.py")
319+
320+
# =====[ Imports cannot be found without custom resolve path being set ]=====
321+
assert len(consumer_file.imports) == 2
322+
src_import: Import = consumer_file.imports[0]
323+
src_import_resolution: ImportResolution = src_import.resolve_import()
324+
assert src_import_resolution is None
325+
326+
# =====[ Imports cannot be found with custom resolve path set to invalid path ]=====
327+
codebase.ctx.config.import_resolution_paths = ["x"]
328+
src_import_resolution = src_import.resolve_import()
329+
assert src_import_resolution is None
330+
331+
# =====[ Imports can be found with custom resolve path set ]=====
332+
codebase.ctx.config.import_resolution_paths = ["a"]
333+
src_import_resolution = src_import.resolve_import()
334+
assert src_import_resolution
335+
assert src_import_resolution.from_file is src_file
336+
assert src_import_resolution.imports_file is True
337+
338+
# =====[ Imports can be found with custom resolve multi-path set ]=====
339+
src_import = consumer_file.imports[1]
340+
codebase.ctx.config.import_resolution_paths = ["a/b"]
341+
src_import_resolution = src_import.resolve_import()
342+
assert src_import_resolution
343+
assert src_import_resolution.from_file is src_file
344+
assert src_import_resolution.imports_file is True
345+
346+
347+
def test_import_resolution_file_custom_resolve_and_syspath_precedence(tmpdir: str, monkeypatch) -> None:
348+
"""Tests function.usages returns usages from file imports"""
349+
# language=python
350+
with get_codebase_session(
351+
tmpdir,
352+
files={
353+
"a/c/src.py": """
354+
def update1():
355+
pass
356+
""",
357+
"a/b/c/src.py": """
358+
def update2():
359+
pass
360+
""",
361+
"consumer.py": """
362+
from c import src as operations
363+
364+
def func():
365+
operations.update2()
366+
""",
367+
},
368+
) as codebase:
369+
src_file: SourceFile = codebase.get_file("a/b/c/src.py")
370+
consumer_file: SourceFile = codebase.get_file("consumer.py")
371+
372+
# Ensure we don't have overrites and enable syspath resolution
373+
codebase.ctx.config.import_resolution_paths = []
374+
codebase.ctx.config.py_resolve_syspath = True
375+
376+
# =====[ Import with sys.path set can be found ]=====
377+
assert len(consumer_file.imports) == 1
378+
# Modify sys.path for this test only
379+
monkeypatch.syspath_prepend("a")
380+
src_import: Import = consumer_file.imports[0]
381+
src_import_resolution = src_import.resolve_import()
382+
assert src_import_resolution
383+
assert src_import_resolution.from_file.file_path == "a/c/src.py"
384+
385+
# =====[ Imports can be found with custom resolve over sys.path ]=====
386+
codebase.ctx.config.import_resolution_paths = ["a/b"]
387+
src_import_resolution = src_import.resolve_import()
388+
assert src_import_resolution
389+
assert src_import_resolution.from_file is src_file
390+
assert src_import_resolution.imports_file is True
391+
392+
393+
def test_import_resolution_default_conflicts_overrite(tmpdir: str, monkeypatch) -> None:
394+
"""Tests function.usages returns usages from file imports"""
395+
# language=python
396+
with get_codebase_session(
397+
tmpdir,
398+
files={
399+
"a/src.py": """
400+
def update1():
401+
pass
402+
""",
403+
"b/a/src.py": """
404+
def update2():
405+
pass
406+
""",
407+
"consumer.py": """
408+
from a import src as operations
409+
410+
def func():
411+
operations.update2()
412+
""",
413+
},
414+
) as codebase:
415+
src_file: SourceFile = codebase.get_file("a/src.py")
416+
src_file_overrite: SourceFile = codebase.get_file("b/a/src.py")
417+
consumer_file: SourceFile = codebase.get_file("consumer.py")
418+
419+
# Ensure we don't have overrites and enable syspath resolution
420+
codebase.ctx.config.import_resolution_paths = []
421+
codebase.ctx.config.py_resolve_syspath = True
422+
423+
# =====[ Default import works ]=====
424+
assert len(consumer_file.imports) == 1
425+
src_import: Import = consumer_file.imports[0]
426+
src_import_resolution = src_import.resolve_import()
427+
assert src_import_resolution
428+
assert src_import_resolution.from_file is src_file
429+
430+
# =====[ Sys.path overrite has precedence ]=====
431+
monkeypatch.syspath_prepend("b")
432+
src_import_resolution = src_import.resolve_import()
433+
assert src_import_resolution
434+
assert src_import_resolution.from_file is not src_file
435+
assert src_import_resolution.from_file is src_file_overrite
436+
437+
# =====[ Custom overrite has precedence ]=====
438+
codebase.ctx.config.import_resolution_paths = ["b"]
439+
src_import_resolution = src_import.resolve_import()
440+
assert src_import_resolution
441+
assert src_import_resolution.from_file is not src_file
442+
assert src_import_resolution.from_file is src_file_overrite
443+
444+
218445
def test_import_resolution_init_wildcard(tmpdir: str) -> None:
219446
"""Tests that named import from a file with wildcard resolves properly"""
220447
# language=python

0 commit comments

Comments
 (0)