Skip to content

Commit e7c28b6

Browse files
authored
CG-10899: file.add_import (#673)
`file.add_import_from_string` and `file.add_symbol_import` has now been consolidated into one method `file.import` which takes in `Symbol | str` either the string import or the Symbol that needs to be imported and will handle the logic ---------
1 parent 3b3d135 commit e7c28b6

File tree

40 files changed

+545
-272
lines changed

40 files changed

+545
-272
lines changed

codegen-examples/examples/dict_to_schema/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def run(codebase: Codebase):
8484

8585
# Add imports if needed
8686
if needs_imports:
87-
file.add_import_from_import_string("from pydantic import BaseModel")
87+
file.add_import("from pydantic import BaseModel")
8888

8989
if file_modified:
9090
files_modified += 1

codegen-examples/examples/flask_to_fastapi_migration/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def setup_static_files(file):
5757
print(f"📁 Processing file: {file.filepath}")
5858

5959
# Add import for StaticFiles
60-
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
60+
file.add_import("from fastapi.staticfiles import StaticFiles")
6161
print("✅ Added import: from fastapi.staticfiles import StaticFiles")
6262

6363
# Add app.mount for static file handling

codegen-examples/examples/sqlalchemy_soft_delete/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ The codemod processes your codebase in several steps:
5858
```python
5959
def ensure_and_import(file):
6060
if not any("and_" in imp.name for imp in file.imports):
61-
file.add_import_from_import_string("from sqlalchemy import and_")
61+
file.add_import("from sqlalchemy import and_")
6262
```
6363

6464
- Automatically adds required SQLAlchemy imports (`and_`)

codegen-examples/examples/sqlalchemy_soft_delete/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def ensure_and_import(file):
5151
"""Ensure the file has the necessary and_ import."""
5252
if not any("and_" in imp.name for imp in file.imports):
5353
print(f"File {file.filepath} does not import and_. Adding import.")
54-
file.add_import_from_import_string("from sqlalchemy import and_")
54+
file.add_import("from sqlalchemy import and_")
5555

5656

5757
def clone_repo(repo_url: str, repo_path: Path) -> None:

codegen-examples/examples/sqlalchemy_type_annotations/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,16 @@ def run(codebase: Codebase):
100100

101101
# Add necessary imports
102102
if not cls.file.has_import("Mapped"):
103-
cls.file.add_import_from_import_string("from sqlalchemy.orm import Mapped\n")
103+
cls.file.add_import("from sqlalchemy.orm import Mapped\n")
104104

105105
if "Optional" in new_type and not cls.file.has_import("Optional"):
106-
cls.file.add_import_from_import_string("from typing import Optional\n")
106+
cls.file.add_import("from typing import Optional\n")
107107

108108
if "Decimal" in new_type and not cls.file.has_import("Decimal"):
109-
cls.file.add_import_from_import_string("from decimal import Decimal\n")
109+
cls.file.add_import("from decimal import Decimal\n")
110110

111111
if "datetime" in new_type and not cls.file.has_import("datetime"):
112-
cls.file.add_import_from_import_string("from datetime import datetime\n")
112+
cls.file.add_import("from datetime import datetime\n")
113113

114114
if class_modified:
115115
classes_modified += 1

codegen-examples/examples/unittest_to_pytest/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def convert_to_pytest_fixtures(file):
2424
print(f"🔍 Processing file: {file.filepath}")
2525

2626
if not any(imp.name == "pytest" for imp in file.imports):
27-
file.add_import_from_import_string("import pytest")
27+
file.add_import("import pytest")
2828
print(f"➕ Added pytest import to {file.filepath}")
2929

3030
for cls in file.classes:

codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ The script automates the entire migration process in a few key steps:
2525

2626
```python
2727
import_str = "import { useQuery, useSuspenseQueries } from '@tanstack/react-query'"
28-
file.add_import_from_import_string(import_str)
28+
file.add_import(import_str)
2929
```
3030

3131
- Uses Codegen's import analysis to add required imports

codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def run(codebase: Codebase):
2626

2727
print(f"Processing {file.filepath}")
2828
# Add the import statement
29-
file.add_import_from_import_string(import_str)
29+
file.add_import(import_str)
3030
file_modified = False
3131

3232
# Iterate through all functions in the file

docs/building-with-codegen/imports.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ for module, imports in module_imports.items():
120120
if len(imports) > 1:
121121
# Create combined import
122122
symbols = [imp.name for imp in imports]
123-
file.add_import_from_import_string(
123+
file.add_import(
124124
f"import {{ {', '.join(symbols)} }} from '{module}'"
125125
)
126126
# Remove old imports

docs/building-with-codegen/react-and-jsx.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,5 +136,5 @@ for function in codebase.functions:
136136

137137
# Add import if needed
138138
if not file.has_import("NewComponent"):
139-
file.add_symbol_import(new_component)
139+
file.add_import(new_component)
140140
```

docs/tutorials/flask-to-fastapi.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ FastAPI handles static files differently than Flask. We need to add the StaticFi
119119

120120
```python
121121
# Add StaticFiles import
122-
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
122+
file.add_import("from fastapi.staticfiles import StaticFiles")
123123

124124
# Mount static directory
125125
file.add_symbol_from_source(

docs/tutorials/modularity.mdx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,17 @@ def organize_file_imports(file):
116116
# Add imports back in organized groups
117117
if std_lib_imports:
118118
for imp in std_lib_imports:
119-
file.add_import_from_import_string(imp.source)
119+
file.add_import(imp.source)
120120
file.insert_after_imports("") # Add newline
121121

122122
if third_party_imports:
123123
for imp in third_party_imports:
124-
file.add_import_from_import_string(imp.source)
124+
file.add_import(imp.source)
125125
file.insert_after_imports("") # Add newline
126126

127127
if local_imports:
128128
for imp in local_imports:
129-
file.add_import_from_import_string(imp.source)
129+
file.add_import(imp.source)
130130

131131
# Organize imports in all files
132132
for file in codebase.files:

docs/tutorials/react-modernization.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) =
8282
# Add required imports
8383
file = class_def.file
8484
if not any("useState" in imp.source for imp in file.imports):
85-
file.add_import_from_import_string("import { useState, useEffect } from 'react';")
85+
file.add_import("import { useState, useEffect } from 'react';")
8686
```
8787

8888
## Migrating to Modern Hooks
@@ -100,7 +100,7 @@ for function in codebase.functions:
100100
# Convert withRouter to useNavigate
101101
if call.name == "withRouter":
102102
# Add useNavigate import
103-
function.file.add_import_from_import_string(
103+
function.file.add_import(
104104
"import { useNavigate } from 'react-router-dom';"
105105
)
106106
# Add navigate hook

src/codegen/cli/mcp/resources/system_prompt.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2909,7 +2909,7 @@ def validate_data(data: dict) -> bool:
29092909
if len(imports) > 1:
29102910
# Create combined import
29112911
symbols = [imp.name for imp in imports]
2912-
file.add_import_from_import_string(
2912+
file.add_import(
29132913
f"import {{ {', '.join(symbols)} }} from '{module}'"
29142914
)
29152915
# Remove old imports
@@ -5180,7 +5180,7 @@ def build_graph(func, depth=0):
51805180
51815181
# Add import if needed
51825182
if not file.has_import("NewComponent"):
5183-
file.add_symbol_import(new_component)
5183+
file.add_import(new_component)
51845184
```
51855185
51865186
@@ -7316,17 +7316,17 @@ def organize_file_imports(file):
73167316
# Add imports back in organized groups
73177317
if std_lib_imports:
73187318
for imp in std_lib_imports:
7319-
file.add_import_from_import_string(imp.source)
7319+
file.add_import(imp.source)
73207320
file.insert_after_imports("") # Add newline
73217321
73227322
if third_party_imports:
73237323
for imp in third_party_imports:
7324-
file.add_import_from_import_string(imp.source)
7324+
file.add_import(imp.source)
73257325
file.insert_after_imports("") # Add newline
73267326
73277327
if local_imports:
73287328
for imp in local_imports:
7329-
file.add_import_from_import_string(imp.source)
7329+
file.add_import(imp.source)
73307330
73317331
# Organize imports in all files
73327332
for file in codebase.files:
@@ -8593,7 +8593,7 @@ class FeatureFlags:
85938593
# Add required imports
85948594
file = class_def.file
85958595
if not any("useState" in imp.source for imp in file.imports):
8596-
file.add_import_from_import_string("import { useState, useEffect } from 'react';")
8596+
file.add_import("import { useState, useEffect } from 'react';")
85978597
```
85988598
85998599
## Migrating to Modern Hooks
@@ -8611,7 +8611,7 @@ class FeatureFlags:
86118611
# Convert withRouter to useNavigate
86128612
if call.name == "withRouter":
86138613
# Add useNavigate import
8614-
function.file.add_import_from_import_string(
8614+
function.file.add_import(
86158615
"import { useNavigate } from 'react-router-dom';"
86168616
)
86178617
# Add navigate hook
@@ -9813,7 +9813,7 @@ def create_user():
98139813
98149814
```python
98159815
# Add StaticFiles import
9816-
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
9816+
file.add_import("from fastapi.staticfiles import StaticFiles")
98179817
98189818
# Mount static directory
98199819
file.add_symbol_from_source(

src/codegen/sdk/core/class_definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,9 @@ def add_attribute(self, attribute: Attribute, include_dependencies: bool = False
378378
file = self.file
379379
for d in deps:
380380
if isinstance(d, Import):
381-
file.add_symbol_import(d.imported_symbol)
381+
file.add_import(d.imported_symbol)
382382
elif isinstance(d, Symbol):
383-
file.add_symbol_import(d)
383+
file.add_import(d)
384384

385385
@property
386386
@noapidoc

src/codegen/sdk/core/file.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -944,62 +944,56 @@ def update_filepath(self, new_filepath: str) -> None:
944944
imp.set_import_module(new_module_name)
945945

946946
@writer
947-
def add_symbol_import(
948-
self,
949-
symbol: Symbol,
950-
alias: str | None = None,
951-
import_type: ImportType = ImportType.UNKNOWN,
952-
is_type_import: bool = False,
953-
) -> Import | None:
954-
"""Adds an import to a file for a given symbol.
955-
956-
This method adds an import statement to the file for a specified symbol. If an import for the
957-
symbol already exists, it returns the existing import instead of creating a new one.
947+
def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None:
948+
"""Adds an import to the file.
958949
959-
Args:
960-
symbol (Symbol): The symbol to import.
961-
alias (str | None): Optional alias for the imported symbol. Defaults to None.
962-
import_type (ImportType): The type of import to use. Defaults to ImportType.UNKNOWN.
963-
is_type_import (bool): Whether this is a type-only import. Defaults to False.
964-
965-
Returns:
966-
Import | None: The existing import for the symbol or None if it was added.
967-
"""
968-
imports = self.imports
969-
match = next((x for x in imports if x.imported_symbol == symbol), None)
970-
if match:
971-
return match
972-
973-
import_string = symbol.get_import_string(alias, import_type=import_type, is_type_import=is_type_import)
974-
self.add_import_from_import_string(import_string)
975-
976-
@writer(commit=False)
977-
def add_import_from_import_string(self, import_string: str) -> None:
978-
"""Adds import to the file from a string representation of an import statement.
979-
980-
This method adds a new import statement to the file based on its string representation.
950+
This method adds an import statement to the file. It can handle both string imports and symbol imports.
981951
If the import already exists in the file, or is pending to be added, it won't be added again.
982952
If there are existing imports, the new import will be added before the first import,
983953
otherwise it will be added at the beginning of the file.
984954
985955
Args:
986-
import_string (str): The string representation of the import statement to add.
956+
imp (Symbol | str): Either a Symbol to import or a string representation of an import statement.
957+
alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None.
958+
import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN.
959+
is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False.
987960
988961
Returns:
989-
None
962+
Import | None: The existing import for the symbol if found, otherwise None.
990963
"""
991-
if any(import_string.strip() in imp.source for imp in self.imports):
992-
return
964+
# Handle Symbol imports
965+
if isinstance(imp, str):
966+
# Handle string imports
967+
import_string = imp
968+
# Check for duplicate imports
969+
if any(import_string.strip() in imp.source for imp in self.imports):
970+
return None
971+
else:
972+
# Check for existing imports of this symbol
973+
imports = self.imports
974+
match = next((x for x in imports if x.imported_symbol == imp), None)
975+
if match:
976+
return match
977+
978+
# Convert symbol to import string
979+
import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import)
980+
993981
if import_string.strip() in self._pending_imports:
994982
# Don't add the import string if it will already be added by another symbol
995-
return
983+
return None
984+
985+
# Add to pending imports and setup undo
996986
self._pending_imports.add(import_string.strip())
997987
self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear())
988+
989+
# Insert the import at the appropriate location
998990
if self.imports:
999991
self.imports[0].insert_before(import_string, priority=1)
1000992
else:
1001993
self.insert_before(import_string, priority=1)
1002994

995+
return None
996+
1003997
@writer
1004998
def add_symbol_from_source(self, source: str) -> None:
1005999
"""Adds a symbol to a file from a string representation.

src/codegen/sdk/core/symbol.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -329,19 +329,19 @@ def _move_to_file(
329329
# =====[ Imports - copy over ]=====
330330
elif isinstance(dep, Import):
331331
if dep.imported_symbol:
332-
file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source)
332+
file.add_import(imp=dep.imported_symbol, alias=dep.alias.source)
333333
else:
334-
file.add_import_from_import_string(dep.source)
334+
file.add_import(imp=dep.source)
335335
else:
336336
for dep in self.dependencies:
337337
# =====[ Symbols - add back edge ]=====
338338
if isinstance(dep, Symbol) and dep.is_top_level:
339-
file.add_symbol_import(symbol=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False)
339+
file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False)
340340
elif isinstance(dep, Import):
341341
if dep.imported_symbol:
342-
file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source)
342+
file.add_import(imp=dep.imported_symbol, alias=dep.alias.source)
343343
else:
344-
file.add_import_from_import_string(dep.source)
344+
file.add_import(imp=dep.source)
345345

346346
# =====[ Make a new symbol in the new file ]=====
347347
file.add_symbol(self)
@@ -364,7 +364,7 @@ def _move_to_file(
364364
# Here, we will add a "back edge" to the old file importing the symbol
365365
elif strategy == "add_back_edge":
366366
if is_used_in_file or any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages):
367-
self.file.add_import_from_import_string(import_line)
367+
self.file.add_import(imp=import_line)
368368
# Delete the original symbol
369369
self.remove()
370370

@@ -374,7 +374,7 @@ def _move_to_file(
374374
for usage in self.usages:
375375
if isinstance(usage.usage_symbol, Import) and usage.usage_symbol.file != file:
376376
# Add updated import
377-
usage.usage_symbol.file.add_import_from_import_string(import_line)
377+
usage.usage_symbol.file.add_import(import_line)
378378
usage.usage_symbol.remove()
379379
elif usage.usage_type == UsageType.CHAINED:
380380
# Update all previous usages of import * to the new import name
@@ -383,11 +383,11 @@ def _move_to_file(
383383
usage.match.get_name().edit(self.name)
384384
if isinstance(usage.match, ChainedAttribute):
385385
usage.match.edit(self.name)
386-
usage.usage_symbol.file.add_import_from_import_string(import_line)
386+
usage.usage_symbol.file.add_import(imp=import_line)
387387

388388
# Add the import to the original file
389389
if is_used_in_file:
390-
self.file.add_import_from_import_string(import_line)
390+
self.file.add_import(imp=import_line)
391391
# Delete the original symbol
392392
self.remove()
393393

0 commit comments

Comments
 (0)