Skip to content

Docs for import loops #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/images/import-loops.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/large-import-loop.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/problematic-import-loop.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/valid-import-loop.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/introduction/getting-started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ uv tool install codegen

## Quick Start with Jupyter

The [codgen notebook](/cli/notebook) command creates a virtual environment and opens a Jupyter notebook for quick prototyping. This is often the fastest way to get up and running.
The [codegen notebook](/cli/notebook) command creates a virtual environment and opens a Jupyter notebook for quick prototyping. This is often the fastest way to get up and running.

```bash
# Launch Jupyter with a demo notebook
Expand Down
1 change: 1 addition & 0 deletions docs/mint.json
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"tutorials/react-modernization",
"tutorials/unittest-to-pytest",
"tutorials/sqlalchemy-1.6-to-2.0",
"tutorials/fixing-import-loops-in-pytorch",
"tutorials/python2-to-python3",
"tutorials/flask-to-fastapi"
]
Expand Down
260 changes: 260 additions & 0 deletions docs/tutorials/fixing-import-loops-in-pytorch.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
---
title: "Fixing Import Loops"
description: "Learn how to identify and fix problematic import loops using Codegen."
icon: "arrows-rotate"
iconType: "solid"
---
<Frame caption="Import loops in pytorch/torchgen/model.py">
<iframe
width="100%"
height="500px"
scrolling="no"
src={`https://www.codegen.sh/embedded/graph/?id=8b575318-ff94-41f1-94df-6e21d9de45d1&zoom=1&targetNodeName=model`}
className="rounded-xl"
style={{
backgroundColor: "#15141b",
}}
></iframe>
</Frame>


Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain.

In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen.

<Info>
You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-examples/tree/main/examples/removing_import_loops_in_pytorch).
</Info>

## Overview

The steps to identify and fix import loops are as follows:
1. Detect import loops
2. Visualize them
3. Identify problematic cycles with mixed static/dynamic imports
4. Fix these cycles using Codegen

# Step 1: Detect Import Loops
- Create a graph
- Loop through imports in the codebase and add edges between the import files
- Find strongly connected components using Networkx (the import loops)
```python
G = nx.MultiDiGraph()

# Add all edges to the graph
for imp in codebase.imports:
if imp.from_file and imp.to_file:
edge_color = "red" if imp.is_dynamic else "black"
edge_label = "dynamic" if imp.is_dynamic else "static"

# Store the import statement and its metadata
G.add_edge(
imp.to_file.filepath,
imp.from_file.filepath,
color=edge_color,
label=edge_label,
is_dynamic=imp.is_dynamic,
import_statement=imp, # Store the whole import object
key=id(imp.import_statement),
)
# Find strongly connected components
cycles = [scc for scc in nx.strongly_connected_components(G) if len(scc) > 1]

print(f"🔄 Found {len(cycles)} import cycles:")
for i, cycle in enumerate(cycles, 1):
print(f"\nCycle #{i}:")
print(f"Size: {len(cycle)} files")

# Create subgraph for this cycle to count edges
cycle_subgraph = G.subgraph(cycle)

# Count total edges
total_edges = cycle_subgraph.number_of_edges()
print(f"Total number of imports in cycle: {total_edges}")

# Count dynamic and static imports separately
dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "red")
static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "black")

print(f"Number of dynamic imports: {dynamic_imports}")
print(f"Number of static imports: {static_imports}")
```


## Understanding Import Cycles

Not all import cycles are problematic! Here's an example of a cycle that one may think would cause an error but it does not because due to using dynamic imports.

```python
# top level import in in APoT_tensor.py
from quantizer.py import objectA
```

```python
# dynamic import in quantizer.py
def some_func():
# dynamic import (evaluated when some_func() is called)
from APoT_tensor.py import objectB
```

<img src="/images/valid-import-loop.png" />

A dynamic import is an import defined inside of a function, method or any executable body of code which delays the import execution until that function, method or body of code is called.

You can use `imp.is_dynamic` to check if the import is dynamic allowing you to investigate imports that are handled more intentionally.

# Step 2: Visualize Import Loops
- Create a new subgraph to visualize one cycle
- color and label the edges based on their type (dynamic/static)
- visualize the cycle graph using `codebase.visualize(graph)`

```python
cycle = cycles[0]

def create_single_loop_graph(cycle):
cycle_graph = nx.MultiDiGraph() # Changed to MultiDiGraph to support multiple edges
cycle = list(cycle)
for i in range(len(cycle)):
for j in range(len(cycle)):
# Get all edges between these nodes from original graph
edge_data_dict = G.get_edge_data(cycle[i], cycle[j])
if edge_data_dict:
# For each edge between these nodes
for edge_key, edge_data in edge_data_dict.items():
# Add edge with all its attributes to cycle graph
cycle_graph.add_edge(cycle[i], cycle[j], **edge_data)
return cycle_graph


cycle_graph = create_single_loop_graph(cycle)
codebase.visualize(cycle_graph)
```

<Frame caption="Import loops in pytorch/torchgen/model.py">
<iframe
width="100%"
height="500px"
scrolling="no"
src={`https://www.codegen.sh/embedded/graph/?id=8b575318-ff94-41f1-94df-6e21d9de45d1&zoom=1&targetNodeName=model`}
className="rounded-xl"
style={{
backgroundColor: "#15141b",
}}
></iframe>
</Frame>


# Step 3: Identify problematic cycles with mixed static & dynamic imports

The import loops that we are really concerned about are those that have mixed static/dynamic imports.

Here's an example of a problematic cycle that we want to fix:

```python
# In flex_decoding.py
from .flex_attention import (
compute_forward_block_mn,
compute_forward_inner,
# ... more static imports
)

# Also in flex_decoding.py
def create_flex_decoding_kernel(*args, **kwargs):
from .flex_attention import set_head_dim_values # dynamic import
```

It's clear that there is both a top level and a dynamic import that imports from the *same* module. Thus, this can cause issues if not handled carefully.

<img src="/images/problematic-import-loop.png" />

Let's find these problematic cycles:

```python
def find_problematic_import_loops(G, sccs):
"""Find cycles where files have both static and dynamic imports between them."""
problematic_cycles = []

for i, scc in enumerate(sccs):
if i == 2: # skipping the second import loop as it's incredibly long (it's also invalid)
continue
mixed_import_files = {} # (from_file, to_file) -> {dynamic: count, static: count}

# Check all file pairs in the cycle
for from_file in scc:
for to_file in scc:
if G.has_edge(from_file, to_file):
# Get all edges between these files
edges = G.get_edge_data(from_file, to_file)

# Count imports by type
dynamic_count = sum(1 for e in edges.values() if e["color"] == "red")
static_count = sum(1 for e in edges.values() if e["color"] == "black")

# If we have both types between same files, this is problematic
if dynamic_count > 0 and static_count > 0:
mixed_import_files[(from_file, to_file)] = {"dynamic": dynamic_count, "static": static_count, "edges": edges}

if mixed_import_files:
problematic_cycles.append({"files": scc, "mixed_imports": mixed_import_files, "index": i})

# Print findings
print(f"Found {len(problematic_cycles)} cycles with mixed imports:")
for i, cycle in enumerate(problematic_cycles):
print(f"\n⚠️ Problematic Cycle #{i + 1}:")
print(f"\n⚠️ Index #{cycle['index']}:")
print(f"Size: {len(cycle['files'])} files")

for (from_file, to_file), data in cycle["mixed_imports"].items():
print("\n📁 Mixed imports detected:")
print(f" From: {from_file}")
print(f" To: {to_file}")
print(f" Dynamic imports: {data['dynamic']}")
print(f" Static imports: {data['static']}")

return problematic_cycles

problematic_cycles = find_problematic_import_loops(G, cycles)
```

# Step 4: Fix the loop by moving the shared symbols to a separate `utils.py` file
One common fix to this problem to break this cycle is to move all the shared symbols to a separate `utils.py` file. We can do this using the method `symbol.move_to_file`:

```python
# Create new utils file
utils_file = codebase.create_file("torch/_inductor/kernel/flex_utils.py")

# Get the two files involved in the import cycle
decoding_file = codebase.get_file("torch/_inductor/kernel/flex_decoding.py")
attention_file = codebase.get_file("torch/_inductor/kernel/flex_attention.py")
attention_file_path = "torch/_inductor/kernel/flex_attention.py"
decoding_file_path = "torch/_inductor/kernel/flex_decoding.py"

# Track symbols to move
symbols_to_move = set()

# Find imports from flex_attention in flex_decoding
for imp in decoding_file.imports:
if imp.from_file and imp.from_file.filepath == attention_file_path:
# Get the actual symbol from flex_attention
if imp.imported_symbol:
symbols_to_move.add(imp.imported_symbol)

# Move identified symbols to utils file
for symbol in symbols_to_move:
symbol.move_to_file(utils_file)

print(f"🔄 Moved {len(symbols_to_move)} symbols to flex_utils.py")
for symbol in symbols_to_move:
print(symbol.name)
```

```python
# run this command to have the changes take effect in the codebase
codebase.commit()
```

Next Steps
Verify all tests pass after the migration and fix other problematic import loops using the suggested strategies:
1. Move the shared symbols to a separate file
2. If a module needs imports only for type hints, consider using `if TYPE_CHECKING` from the `typing` module
3. Use lazy imports using `importlib` to load imports dynamically
2 changes: 1 addition & 1 deletion docs/tutorials/migrating-apis.mdx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
title: "Migrating APIs"
sidebarTitle: "API Migrations"
icon: "arrows-rotate"
icon: "webhook"
iconType: "solid"
---

Expand Down
2 changes: 1 addition & 1 deletion src/codegen/cli/utils/notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
]

DEMO_CELLS = [
##### [ CODGEN DEMO ] #####
##### [ CODEGEN DEMO ] #####
{
"cell_type": "markdown",
"source": """# Codegen Demo: FastAPI
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Typescript Analyzer Specific GitIgnores
node_modules
dist
package-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def setup():

@pytest.mark.skip("Skipping this test for now")
@pytest.mark.timeout(5, func_only=True)
@pytest.mark.skip(reason="Test is timing out and needs investigation") # Skip this test for now
@pytest.mark.parametrize("extension", ["txt", "py"])
def test_codebase_reset_correctness(extension: str, tmp_path):
codebase, files = setup_codebase(NUM_FILES, extension, tmp_path)
Expand Down