@@ -198,13 +198,14 @@ def mod_partition(node: torch.fx.Node):
198
198
199
199
split = split_module (mod_traced , module , mod_partition )
200
200
201
- const_gm , non_const_gm = split .submod_0 , split .submod_1
202
201
const_mod_name , non_const_mod_name = "submod_0" , "submod_1"
202
+ # Safely get submod_1 in case there are no non-const nodes
203
+ const_gm , non_const_gm = split .submod_0 , getattr (split , non_const_mod_name , None )
203
204
204
205
# The module that a call_module node refers to gets copied to submodules during split.
205
206
# The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
206
207
# attach inlined modules to `split` as it's the owning module now.
207
- for node in non_const_gm .graph .nodes :
208
+ for node in non_const_gm .graph .nodes if non_const_gm else [] :
208
209
if node .op == "call_module" :
209
210
setattr (split , node .target , getattr (non_const_gm , node .target ))
210
211
for node in const_gm .graph .nodes :
@@ -276,10 +277,11 @@ def mod_partition(node: torch.fx.Node):
276
277
277
278
split .graph .eliminate_dead_code ()
278
279
279
- # Finally, inline the non-constant submod into the split submod. This is so that the
280
- # original caller who may have passed in a graph module will get back out a graph
281
- # module whose graph is traced to the same granularity.
282
- _inline_module (split , non_const_mod_name )
280
+ # Finally, inline the non-constant submod (if it exists) into the split submod.
281
+ # This is so that the original caller who may have passed in a graph module will
282
+ # get back out a graph module whose graph is traced to the same granularity.
283
+ if hasattr (split , non_const_mod_name ):
284
+ _inline_module (split , non_const_mod_name )
283
285
284
286
return FoldedGraphModule (
285
287
split ,
0 commit comments