Skip to content

CG-10899: file.add_import #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codegen-examples/examples/dict_to_schema/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def run(codebase: Codebase):

# Add imports if needed
if needs_imports:
file.add_import_from_import_string("from pydantic import BaseModel")
file.add_import("from pydantic import BaseModel")

if file_modified:
files_modified += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def setup_static_files(file):
print(f"📁 Processing file: {file.filepath}")

# Add import for StaticFiles
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
file.add_import("from fastapi.staticfiles import StaticFiles")
print("✅ Added import: from fastapi.staticfiles import StaticFiles")

# Add app.mount for static file handling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ The codemod processes your codebase in several steps:
```python
def ensure_and_import(file):
if not any("and_" in imp.name for imp in file.imports):
file.add_import_from_import_string("from sqlalchemy import and_")
file.add_import("from sqlalchemy import and_")
```

- Automatically adds required SQLAlchemy imports (`and_`)
Expand Down
2 changes: 1 addition & 1 deletion codegen-examples/examples/sqlalchemy_soft_delete/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def ensure_and_import(file):
"""Ensure the file has the necessary and_ import."""
if not any("and_" in imp.name for imp in file.imports):
print(f"File {file.filepath} does not import and_. Adding import.")
file.add_import_from_import_string("from sqlalchemy import and_")
file.add_import("from sqlalchemy import and_")


def clone_repo(repo_url: str, repo_path: Path) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ def run(codebase: Codebase):

# Add necessary imports
if not cls.file.has_import("Mapped"):
cls.file.add_import_from_import_string("from sqlalchemy.orm import Mapped\n")
cls.file.add_import("from sqlalchemy.orm import Mapped\n")

if "Optional" in new_type and not cls.file.has_import("Optional"):
cls.file.add_import_from_import_string("from typing import Optional\n")
cls.file.add_import("from typing import Optional\n")

if "Decimal" in new_type and not cls.file.has_import("Decimal"):
cls.file.add_import_from_import_string("from decimal import Decimal\n")
cls.file.add_import("from decimal import Decimal\n")

if "datetime" in new_type and not cls.file.has_import("datetime"):
cls.file.add_import_from_import_string("from datetime import datetime\n")
cls.file.add_import("from datetime import datetime\n")

if class_modified:
classes_modified += 1
Expand Down
2 changes: 1 addition & 1 deletion codegen-examples/examples/unittest_to_pytest/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def convert_to_pytest_fixtures(file):
print(f"🔍 Processing file: {file.filepath}")

if not any(imp.name == "pytest" for imp in file.imports):
file.add_import_from_import_string("import pytest")
file.add_import("import pytest")
print(f"➕ Added pytest import to {file.filepath}")

for cls in file.classes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The script automates the entire migration process in a few key steps:

```python
import_str = "import { useQuery, useSuspenseQueries } from '@tanstack/react-query'"
file.add_import_from_import_string(import_str)
file.add_import(import_str)
```

- Uses Codegen's import analysis to add required imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run(codebase: Codebase):

print(f"Processing {file.filepath}")
# Add the import statement
file.add_import_from_import_string(import_str)
file.add_import(import_str)
file_modified = False

# Iterate through all functions in the file
Expand Down
2 changes: 1 addition & 1 deletion docs/building-with-codegen/imports.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ for module, imports in module_imports.items():
if len(imports) > 1:
# Create combined import
symbols = [imp.name for imp in imports]
file.add_import_from_import_string(
file.add_import(
f"import {{ {', '.join(symbols)} }} from '{module}'"
)
# Remove old imports
Expand Down
2 changes: 1 addition & 1 deletion docs/building-with-codegen/react-and-jsx.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,5 @@ for function in codebase.functions:

# Add import if needed
if not file.has_import("NewComponent"):
file.add_symbol_import(new_component)
file.add_import(new_component)
```
2 changes: 1 addition & 1 deletion docs/tutorials/flask-to-fastapi.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ FastAPI handles static files differently than Flask. We need to add the StaticFi

```python
# Add StaticFiles import
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
file.add_import("from fastapi.staticfiles import StaticFiles")

# Mount static directory
file.add_symbol_from_source(
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/modularity.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,17 @@ def organize_file_imports(file):
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline

if third_party_imports:
for imp in third_party_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline

if local_imports:
for imp in local_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)

# Organize imports in all files
for file in codebase.files:
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/react-modernization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) =
# Add required imports
file = class_def.file
if not any("useState" in imp.source for imp in file.imports):
file.add_import_from_import_string("import { useState, useEffect } from 'react';")
file.add_import("import { useState, useEffect } from 'react';")
```

## Migrating to Modern Hooks
Expand All @@ -100,7 +100,7 @@ for function in codebase.functions:
# Convert withRouter to useNavigate
if call.name == "withRouter":
# Add useNavigate import
function.file.add_import_from_import_string(
function.file.add_import(
"import { useNavigate } from 'react-router-dom';"
)
# Add navigate hook
Expand Down
16 changes: 8 additions & 8 deletions src/codegen/cli/mcp/resources/system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2909,7 +2909,7 @@ def validate_data(data: dict) -> bool:
if len(imports) > 1:
# Create combined import
symbols = [imp.name for imp in imports]
file.add_import_from_import_string(
file.add_import(
f"import {{ {', '.join(symbols)} }} from '{module}'"
)
# Remove old imports
Expand Down Expand Up @@ -5180,7 +5180,7 @@ def build_graph(func, depth=0):
# Add import if needed
if not file.has_import("NewComponent"):
file.add_symbol_import(new_component)
file.add_import(new_component)
```
Expand Down Expand Up @@ -7316,17 +7316,17 @@ def organize_file_imports(file):
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline
if third_party_imports:
for imp in third_party_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline
if local_imports:
for imp in local_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
# Organize imports in all files
for file in codebase.files:
Expand Down Expand Up @@ -8593,7 +8593,7 @@ class FeatureFlags:
# Add required imports
file = class_def.file
if not any("useState" in imp.source for imp in file.imports):
file.add_import_from_import_string("import { useState, useEffect } from 'react';")
file.add_import("import { useState, useEffect } from 'react';")
```
## Migrating to Modern Hooks
Expand All @@ -8611,7 +8611,7 @@ class FeatureFlags:
# Convert withRouter to useNavigate
if call.name == "withRouter":
# Add useNavigate import
function.file.add_import_from_import_string(
function.file.add_import(
"import { useNavigate } from 'react-router-dom';"
)
# Add navigate hook
Expand Down Expand Up @@ -9813,7 +9813,7 @@ def create_user():
```python
# Add StaticFiles import
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
file.add_import("from fastapi.staticfiles import StaticFiles")
# Mount static directory
file.add_symbol_from_source(
Expand Down
3 changes: 2 additions & 1 deletion src/codegen/sdk/codebase/node_classes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.function import Function
from codegen.sdk.core.import_resolution import Import
from codegen.sdk.core.interfaces.editable import Editable
from codegen.sdk.core.statements.comment import Comment
from codegen.sdk.core.symbol import Symbol

Expand All @@ -33,7 +34,7 @@ class NodeClasses:
function_call_cls: type[FunctionCall]
comment_cls: type[Comment]
bool_conversion: dict[bool, str]
dynamic_import_parent_types: set[str]
dynamic_import_parent_types: set[type[Editable]]
symbol_map: dict[str, type[Symbol]] = field(default_factory=dict)
expression_map: dict[str, type[Expression]] = field(default_factory=dict)
type_map: dict[str, type[Type] | dict[str, type[Type]]] = field(default_factory=dict)
Expand Down
27 changes: 16 additions & 11 deletions src/codegen/sdk/codebase/node_classes/py_node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
from codegen.sdk.core.expressions.unpack import Unpack
from codegen.sdk.core.function import Function
from codegen.sdk.core.statements.comment import Comment
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
from codegen.sdk.core.statements.switch_statement import SwitchStatement
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
from codegen.sdk.core.statements.while_statement import WhileStatement
from codegen.sdk.core.symbol_groups.dict import Dict
from codegen.sdk.core.symbol_groups.list import List
from codegen.sdk.core.symbol_groups.tuple import Tuple
Expand All @@ -29,6 +35,8 @@
from codegen.sdk.python.expressions.string import PyString
from codegen.sdk.python.expressions.union_type import PyUnionType
from codegen.sdk.python.statements.import_statement import PyImportStatement
from codegen.sdk.python.statements.match_case import PyMatchCase
from codegen.sdk.python.statements.with_statement import WithStatement


def parse_subscript(node: TSNode, file_node_id, ctx, parent):
Expand Down Expand Up @@ -110,16 +118,13 @@ def parse_subscript(node: TSNode, file_node_id, ctx, parent):
False: "False",
},
dynamic_import_parent_types={
"function_definition",
"if_statement",
"try_statement",
"with_statement",
"else_clause",
"for_statement",
"except_clause",
"while_statement",
"match_statement",
"case_clause",
"finally_clause",
Function,
IfBlockStatement,
TryCatchStatement,
WithStatement,
ForLoopStatement,
WhileStatement,
SwitchStatement,
PyMatchCase,
},
)
27 changes: 14 additions & 13 deletions src/codegen/sdk/codebase/node_classes/ts_node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
from codegen.sdk.core.expressions.unpack import Unpack
from codegen.sdk.core.expressions.value import Value
from codegen.sdk.core.function import Function
from codegen.sdk.core.statements.comment import Comment
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
from codegen.sdk.core.statements.switch_case import SwitchCase
from codegen.sdk.core.statements.switch_statement import SwitchStatement
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
from codegen.sdk.core.statements.while_statement import WhileStatement
from codegen.sdk.core.symbol_groups.list import List
from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters
from codegen.sdk.typescript.class_definition import TSClass
Expand Down Expand Up @@ -166,18 +173,12 @@ def parse_new(node: TSNode, *args):
False: "false",
},
dynamic_import_parent_types={
"function_declaration",
"method_definition",
"arrow_function",
"if_statement",
"try_statement",
"else_clause",
"catch_clause",
"finally_clause",
"while_statement",
"for_statement",
"do_statement",
"switch_case",
"switch_statement",
Function,
IfBlockStatement,
TryCatchStatement,
ForLoopStatement,
WhileStatement,
SwitchStatement,
SwitchCase,
},
)
4 changes: 2 additions & 2 deletions src/codegen/sdk/core/class_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,9 @@ def add_attribute(self, attribute: Attribute, include_dependencies: bool = False
file = self.file
for d in deps:
if isinstance(d, Import):
file.add_symbol_import(d.imported_symbol)
file.add_import(d.imported_symbol)
elif isinstance(d, Symbol):
file.add_symbol_import(d)
file.add_import(d)

@property
@noapidoc
Expand Down
Loading
Loading