Skip to content

CG-10899: file.add_import #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codegen-examples/examples/dict_to_schema/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def run(codebase: Codebase):

# Add imports if needed
if needs_imports:
file.add_import_from_import_string("from pydantic import BaseModel")
file.add_import("from pydantic import BaseModel")

if file_modified:
files_modified += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def setup_static_files(file):
print(f"📁 Processing file: {file.filepath}")

# Add import for StaticFiles
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
file.add_import("from fastapi.staticfiles import StaticFiles")
print("✅ Added import: from fastapi.staticfiles import StaticFiles")

# Add app.mount for static file handling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ The codemod processes your codebase in several steps:
```python
def ensure_and_import(file):
if not any("and_" in imp.name for imp in file.imports):
file.add_import_from_import_string("from sqlalchemy import and_")
file.add_import("from sqlalchemy import and_")
```

- Automatically adds required SQLAlchemy imports (`and_`)
Expand Down
2 changes: 1 addition & 1 deletion codegen-examples/examples/sqlalchemy_soft_delete/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def ensure_and_import(file):
"""Ensure the file has the necessary and_ import."""
if not any("and_" in imp.name for imp in file.imports):
print(f"File {file.filepath} does not import and_. Adding import.")
file.add_import_from_import_string("from sqlalchemy import and_")
file.add_import("from sqlalchemy import and_")


def clone_repo(repo_url: str, repo_path: Path) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ def run(codebase: Codebase):

# Add necessary imports
if not cls.file.has_import("Mapped"):
cls.file.add_import_from_import_string("from sqlalchemy.orm import Mapped\n")
cls.file.add_import("from sqlalchemy.orm import Mapped\n")

if "Optional" in new_type and not cls.file.has_import("Optional"):
cls.file.add_import_from_import_string("from typing import Optional\n")
cls.file.add_import("from typing import Optional\n")

if "Decimal" in new_type and not cls.file.has_import("Decimal"):
cls.file.add_import_from_import_string("from decimal import Decimal\n")
cls.file.add_import("from decimal import Decimal\n")

if "datetime" in new_type and not cls.file.has_import("datetime"):
cls.file.add_import_from_import_string("from datetime import datetime\n")
cls.file.add_import("from datetime import datetime\n")

if class_modified:
classes_modified += 1
Expand Down
2 changes: 1 addition & 1 deletion codegen-examples/examples/unittest_to_pytest/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def convert_to_pytest_fixtures(file):
print(f"🔍 Processing file: {file.filepath}")

if not any(imp.name == "pytest" for imp in file.imports):
file.add_import_from_import_string("import pytest")
file.add_import("import pytest")
print(f"➕ Added pytest import to {file.filepath}")

for cls in file.classes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The script automates the entire migration process in a few key steps:

```python
import_str = "import { useQuery, useSuspenseQueries } from '@tanstack/react-query'"
file.add_import_from_import_string(import_str)
file.add_import(import_str)
```

- Uses Codegen's import analysis to add required imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run(codebase: Codebase):

print(f"Processing {file.filepath}")
# Add the import statement
file.add_import_from_import_string(import_str)
file.add_import(import_str)
file_modified = False

# Iterate through all functions in the file
Expand Down
2 changes: 1 addition & 1 deletion docs/building-with-codegen/imports.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ for module, imports in module_imports.items():
if len(imports) > 1:
# Create combined import
symbols = [imp.name for imp in imports]
file.add_import_from_import_string(
file.add_import(
f"import {{ {', '.join(symbols)} }} from '{module}'"
)
# Remove old imports
Expand Down
2 changes: 1 addition & 1 deletion docs/building-with-codegen/react-and-jsx.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,5 @@ for function in codebase.functions:

# Add import if needed
if not file.has_import("NewComponent"):
file.add_symbol_import(new_component)
file.add_import(new_component)
```
2 changes: 1 addition & 1 deletion docs/tutorials/flask-to-fastapi.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ FastAPI handles static files differently than Flask. We need to add the StaticFi

```python
# Add StaticFiles import
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
file.add_import("from fastapi.staticfiles import StaticFiles")

# Mount static directory
file.add_symbol_from_source(
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/modularity.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,17 @@ def organize_file_imports(file):
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline

if third_party_imports:
for imp in third_party_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline

if local_imports:
for imp in local_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)

# Organize imports in all files
for file in codebase.files:
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/react-modernization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) =
# Add required imports
file = class_def.file
if not any("useState" in imp.source for imp in file.imports):
file.add_import_from_import_string("import { useState, useEffect } from 'react';")
file.add_import("import { useState, useEffect } from 'react';")
```

## Migrating to Modern Hooks
Expand All @@ -100,7 +100,7 @@ for function in codebase.functions:
# Convert withRouter to useNavigate
if call.name == "withRouter":
# Add useNavigate import
function.file.add_import_from_import_string(
function.file.add_import(
"import { useNavigate } from 'react-router-dom';"
)
# Add navigate hook
Expand Down
16 changes: 8 additions & 8 deletions src/codegen/cli/mcp/resources/system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2909,7 +2909,7 @@ def validate_data(data: dict) -> bool:
if len(imports) > 1:
# Create combined import
symbols = [imp.name for imp in imports]
file.add_import_from_import_string(
file.add_import(
f"import {{ {', '.join(symbols)} }} from '{module}'"
)
# Remove old imports
Expand Down Expand Up @@ -5180,7 +5180,7 @@ def build_graph(func, depth=0):

# Add import if needed
if not file.has_import("NewComponent"):
file.add_symbol_import(new_component)
file.add_import(new_component)
```


Expand Down Expand Up @@ -7316,17 +7316,17 @@ def organize_file_imports(file):
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline

if third_party_imports:
for imp in third_party_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)
file.insert_after_imports("") # Add newline

if local_imports:
for imp in local_imports:
file.add_import_from_import_string(imp.source)
file.add_import(imp.source)

# Organize imports in all files
for file in codebase.files:
Expand Down Expand Up @@ -8593,7 +8593,7 @@ class FeatureFlags:
# Add required imports
file = class_def.file
if not any("useState" in imp.source for imp in file.imports):
file.add_import_from_import_string("import { useState, useEffect } from 'react';")
file.add_import("import { useState, useEffect } from 'react';")
```

## Migrating to Modern Hooks
Expand All @@ -8611,7 +8611,7 @@ class FeatureFlags:
# Convert withRouter to useNavigate
if call.name == "withRouter":
# Add useNavigate import
function.file.add_import_from_import_string(
function.file.add_import(
"import { useNavigate } from 'react-router-dom';"
)
# Add navigate hook
Expand Down Expand Up @@ -9813,7 +9813,7 @@ def create_user():

```python
# Add StaticFiles import
file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
file.add_import("from fastapi.staticfiles import StaticFiles")

# Mount static directory
file.add_symbol_from_source(
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/sdk/core/class_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,9 @@ def add_attribute(self, attribute: Attribute, include_dependencies: bool = False
file = self.file
for d in deps:
if isinstance(d, Import):
file.add_symbol_import(d.imported_symbol)
file.add_import(d.imported_symbol)
elif isinstance(d, Symbol):
file.add_symbol_import(d)
file.add_import(d)

@property
@noapidoc
Expand Down
70 changes: 32 additions & 38 deletions src/codegen/sdk/core/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,62 +944,56 @@ def update_filepath(self, new_filepath: str) -> None:
imp.set_import_module(new_module_name)

@writer
def add_symbol_import(
self,
symbol: Symbol,
alias: str | None = None,
import_type: ImportType = ImportType.UNKNOWN,
is_type_import: bool = False,
) -> Import | None:
"""Adds an import to a file for a given symbol.

This method adds an import statement to the file for a specified symbol. If an import for the
symbol already exists, it returns the existing import instead of creating a new one.
def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add typing overloads to make sure the user doesn't pass the other parameters if imp is of type str

"""Adds an import to the file.

Args:
symbol (Symbol): The symbol to import.
alias (str | None): Optional alias for the imported symbol. Defaults to None.
import_type (ImportType): The type of import to use. Defaults to ImportType.UNKNOWN.
is_type_import (bool): Whether this is a type-only import. Defaults to False.

Returns:
Import | None: The existing import for the symbol or None if it was added.
"""
imports = self.imports
match = next((x for x in imports if x.imported_symbol == symbol), None)
if match:
return match

import_string = symbol.get_import_string(alias, import_type=import_type, is_type_import=is_type_import)
self.add_import_from_import_string(import_string)

@writer(commit=False)
def add_import_from_import_string(self, import_string: str) -> None:
"""Adds import to the file from a string representation of an import statement.

This method adds a new import statement to the file based on its string representation.
This method adds an import statement to the file. It can handle both string imports and symbol imports.
If the import already exists in the file, or is pending to be added, it won't be added again.
If there are existing imports, the new import will be added before the first import,
otherwise it will be added at the beginning of the file.

Args:
import_string (str): The string representation of the import statement to add.
imp (Symbol | str): Either a Symbol to import or a string representation of an import statement.
alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None.
import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN.
is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False.

Returns:
None
Import | None: The existing import for the symbol if found, otherwise None.
"""
if any(import_string.strip() in imp.source for imp in self.imports):
return
# Handle Symbol imports
if isinstance(imp, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change the behavior to

  1. Convert import symbol to string (if it's a symbol)
  2. Insert the string with the same logic as the previous implementation?

# Handle string imports
import_string = imp
# Check for duplicate imports
if any(import_string.strip() in imp.source for imp in self.imports):
return None
else:
# Check for existing imports of this symbol
imports = self.imports
match = next((x for x in imports if x.imported_symbol == imp), None)
if match:
return match

# Convert symbol to import string
import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import)

if import_string.strip() in self._pending_imports:
# Don't add the import string if it will already be added by another symbol
return
return None

# Add to pending imports and setup undo
self._pending_imports.add(import_string.strip())
self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear())

# Insert the import at the appropriate location
if self.imports:
self.imports[0].insert_before(import_string, priority=1)
else:
self.insert_before(import_string, priority=1)

return None

@writer
def add_symbol_from_source(self, source: str) -> None:
"""Adds a symbol to a file from a string representation.
Expand Down
18 changes: 9 additions & 9 deletions src/codegen/sdk/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,19 +329,19 @@ def _move_to_file(
# =====[ Imports - copy over ]=====
elif isinstance(dep, Import):
if dep.imported_symbol:
file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source)
file.add_import(imp=dep.imported_symbol, alias=dep.alias.source)
else:
file.add_import_from_import_string(dep.source)
file.add_import(imp=dep.source)
else:
for dep in self.dependencies:
# =====[ Symbols - add back edge ]=====
if isinstance(dep, Symbol) and dep.is_top_level:
file.add_symbol_import(symbol=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False)
file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False)
elif isinstance(dep, Import):
if dep.imported_symbol:
file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source)
file.add_import(imp=dep.imported_symbol, alias=dep.alias.source)
else:
file.add_import_from_import_string(dep.source)
file.add_import(imp=dep.source)

# =====[ Make a new symbol in the new file ]=====
file.add_symbol(self)
Expand All @@ -364,7 +364,7 @@ def _move_to_file(
# Here, we will add a "back edge" to the old file importing the symbol
elif strategy == "add_back_edge":
if is_used_in_file or any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages):
self.file.add_import_from_import_string(import_line)
self.file.add_import(imp=import_line)
# Delete the original symbol
self.remove()

Expand All @@ -374,7 +374,7 @@ def _move_to_file(
for usage in self.usages:
if isinstance(usage.usage_symbol, Import) and usage.usage_symbol.file != file:
# Add updated import
usage.usage_symbol.file.add_import_from_import_string(import_line)
usage.usage_symbol.file.add_import(import_line)
usage.usage_symbol.remove()
elif usage.usage_type == UsageType.CHAINED:
# Update all previous usages of import * to the new import name
Expand All @@ -383,11 +383,11 @@ def _move_to_file(
usage.match.get_name().edit(self.name)
if isinstance(usage.match, ChainedAttribute):
usage.match.edit(self.name)
usage.usage_symbol.file.add_import_from_import_string(import_line)
usage.usage_symbol.file.add_import(imp=import_line)

# Add the import to the original file
if is_used_in_file:
self.file.add_import_from_import_string(import_line)
self.file.add_import(imp=import_line)
# Delete the original symbol
self.remove()

Expand Down
Loading
Loading