---
title: "Import Loops in PyTorch"
icon: "arrows-rotate"
iconType: "solid"
description: "Identifying and visualizing import loops in the PyTorch codebase"
---

In this post, we will visualize all import loops in the [PyTorch](https://github.com/pytorch/pytorch) codebase, propose a fix for one potentially unstable case, and apply that fix with Codegen.

<Info>
You can find the complete Jupyter notebook in our [examples repository](https://github.com/codegen-sh/codegen-examples/tree/main/examples/removing_import_loops_in_pytorch).
</Info>

Import loops (or circular dependencies) occur when two or more Python modules depend on each other, creating a cycle. For example:

```python
# module_a.py
from module_b import function_b

# module_b.py
from module_a import function_a
```

While Python can handle some import cycles through its import machinery, they can lead to runtime errors, import deadlocks, or initialization order problems.
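
For instance (a minimal sketch; the exact error text varies across Python versions), running `module_a.py` from the example above fails at import time because the cycle is re-entered before `module_b` has finished initializing:

```python
# module_a.py
from module_b import function_b  # Python starts loading module_b here

def function_a():
    return "a"

# module_b.py
from module_a import function_a  # re-imports module_a, which re-enters module_b

def function_b():
    return "b"

# $ python module_a.py
# ImportError: cannot import name 'function_b' from partially initialized
# module 'module_b' (most likely due to a circular import)
```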

Debugging import cycle errors can be a challenge, especially in large codebases. Codegen, however, lets us identify these loops with its visualization tools and fix them deterministically, at scale.
| 27 | + |
| 28 | +<Frame caption="Import loop in pytorch/torchgen/model.py"> |
| 29 | + <iframe |
| 30 | + width="100%" |
| 31 | + height="500px" |
| 32 | + scrolling="no" |
| 33 | + src={`https://www.codegen.sh/embedded/graph/?id=8b575318-ff94-41f1-94df-6e21d9de45d1&zoom=1&targetNodeName=model`} |
| 34 | + className="rounded-xl" |
| 35 | + style={{ |
| 36 | + backgroundColor: "#15141b", |
| 37 | + }} |
| 38 | + ></iframe> |
| 39 | +</Frame> |
| 40 | + |
| 41 | + |
| 42 | + |

## Visualize Import Loops in PyTorch

Using Codegen, we discovered several import cycles in PyTorch's codebase. The code to gather and visualize these loops is as follows:

```python
import networkx as nx

# `codebase` is a Codegen codebase object loaded with the PyTorch repo
G = nx.MultiDiGraph()

# Add an edge for every resolved import in the codebase
for imp in codebase.imports:
    if imp.from_file and imp.to_file:
        edge_color = "red" if imp.is_dynamic else "black"
        edge_label = "dynamic" if imp.is_dynamic else "static"

        # Store the import statement and its metadata
        G.add_edge(
            imp.to_file.filepath,
            imp.from_file.filepath,
            color=edge_color,
            label=edge_label,
            is_dynamic=imp.is_dynamic,
            import_statement=imp,  # Store the whole import object
            key=id(imp.import_statement),
        )

# Find strongly connected components (sets of files that import each other)
cycles = [scc for scc in nx.strongly_connected_components(G) if len(scc) > 1]

print(f"Found {len(cycles)} import cycles:")
for i, cycle in enumerate(cycles, 1):
    print(f"\nCycle #{i}:")
    print(f"Size: {len(cycle)} files")

    # Create a subgraph for this cycle to count edges
    cycle_subgraph = G.subgraph(cycle)

    # Count total edges
    total_edges = cycle_subgraph.number_of_edges()
    print(f"Total number of imports in cycle: {total_edges}")

    # Count dynamic and static imports separately
    dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "red")
    static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "black")

    print(f"Number of dynamic imports: {dynamic_imports}")
    print(f"Number of static imports: {static_imports}")
```
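
For a quick local rendering of one of these cycles, you can also draw the subgraph directly (a sketch using plain networkx and matplotlib rather than Codegen's hosted graph viewer, reusing the `G` and `cycles` built above):

```python
import matplotlib.pyplot as plt

if cycles:
    # Extract the first detected cycle as its own subgraph
    sub = G.subgraph(cycles[0])
    pos = nx.spring_layout(sub, seed=42)

    # Color edges red for dynamic imports, black for static ones
    edge_colors = [data.get("color", "black") for _, _, data in sub.edges(data=True)]
    nx.draw_networkx(sub, pos, edge_color=edge_colors, node_size=600, font_size=8)
    plt.title("Import cycle #1")
    plt.show()
```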

Here is one example visualized ⤵️

<Frame caption="Import loops in pytorch/torchgen/model.py">
  <iframe
    width="100%"
    height="500px"
    scrolling="no"
    src={`https://www.codegen.sh/embedded/graph/?id=8b575318-ff94-41f1-94df-6e21d9de45d1&zoom=1&targetNodeName=model`}
    className="rounded-xl"
    style={{
      backgroundColor: "#15141b",
    }}
  ></iframe>
</Frame>

Not all import cycles are problematic! Some cycles using dynamic imports can work perfectly fine:

<Frame>
  <img src="/images/valid-import-loop.png" alt="Valid import loop example" />
</Frame>

PyTorch prevents most circular import issues through dynamic imports, which can be inspected via the `import_symbol.is_dynamic` property. Because a dynamic import does not execute until the surrounding code runs, a strongly connected component that contains at least one dynamic edge typically resolves without runtime conflicts.
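
Here is the shape of that pattern (a minimal sketch with hypothetical module names, not PyTorch's actual code):

```python
# module_a.py
def function_a():
    # Dynamic import: module_b is only loaded when function_a is first
    # called, by which point both modules are fully initialized
    from module_b import function_b
    return function_b()

# module_b.py
from module_a import function_a  # Static import is now safe: module_a no
                                 # longer imports module_b at load time

def function_b():
    return "b"
```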

However, we discovered an import loop worth investigating between [`flex_decoding.py`](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_decoding.py) and [`flex_attention.py`](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_attention.py):

<Frame>
  <img src="/images/problematic-import-loop.png" alt="Invalid import loop example" />
</Frame>

`flex_decoding.py` imports `flex_attention.py` *twice*: once dynamically and once at top level. This mixed static/dynamic import pattern from the same module creates potential runtime instability.
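
The shape of the problem looks roughly like this (a simplified sketch; the helper names are hypothetical, not the actual symbols in PyTorch):

```python
# flex_decoding.py (simplified)
from .flex_attention import helper_a  # static: resolved while the module loads


def create_flex_decoding_kernel():
    from .flex_attention import helper_b  # dynamic: resolved at call time
    ...
```

Whether the static edge completes a cycle now depends on which module happens to load first, which is exactly the kind of order-dependent behavior worth refactoring away.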

*Thus, we propose the following refactoring using Codegen*:

## Move Shared Code to a Separate `flex_utils.py` File

```python
# `codebase` is a Codegen codebase object loaded with the PyTorch repo

# The two files involved in the import cycle
attention_file_path = "torch/_inductor/kernel/flex_attention.py"
decoding_file_path = "torch/_inductor/kernel/flex_decoding.py"
decoding_file = codebase.get_file(decoding_file_path)

# Create the new shared-utilities file
utils_file = codebase.create_file("torch/_inductor/kernel/flex_utils.py")

# Track symbols to move
symbols_to_move = set()

# Find imports from flex_attention in flex_decoding
for imp in decoding_file.imports:
    if imp.from_file and imp.from_file.filepath == attention_file_path:
        # Get the actual symbol defined in flex_attention
        if imp.imported_symbol:
            symbols_to_move.add(imp.imported_symbol)

# Move the identified symbols to the utils file; Codegen rewrites the
# imports in both files to point at the new module
for symbol in symbols_to_move:
    symbol.move_to_file(utils_file)

print(f"Moved {len(symbols_to_move)} symbols to flex_utils.py")
for symbol in symbols_to_move:
    print(symbol.name)
```

Running this codemod moves the shared symbols into the new `flex_utils.py` module and updates the imports in both files to point at it, removing the mixed static/dynamic import pattern and the unpredictable failures it could cause later on.
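
To confirm the loop is actually gone, you can rebuild the import graph and check that the two files no longer share a strongly connected component (a sketch reusing the graph-building approach and path variables from above):

```python
# Rebuild the import graph after the move
G2 = nx.MultiDiGraph()
for imp in codebase.imports:
    if imp.from_file and imp.to_file:
        G2.add_edge(imp.to_file.filepath, imp.from_file.filepath)

# The cycle is resolved if no strongly connected component contains both files
in_same_cycle = any(
    {attention_file_path, decoding_file_path} <= scc
    for scc in nx.strongly_connected_components(G2)
)
print("Cycle still present!" if in_same_cycle else "Cycle resolved.")
```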

## Conclusion

Import loops are a common challenge in large Python codebases. With Codegen, no matter the repo size, you can gain new insight into your codebase's import structure and perform deterministic manipulations at scale, saving developer hours and preventing future runtime errors.

Want to try it yourself? Check out our [complete example](https://github.com/codegen-sh/codegen-examples/tree/main/examples/removing_import_loops_in_pytorch) of fixing import loops using Codegen.