@@ -295,6 +295,37 @@ def create_constant_nodes_and_return_specs(
295
295
return name_to_spec_dict
296
296
297
297
298
def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
    """
    Re-point graph outputs that are constant placeholders at a clone.

    Each entry of the output node's argument list that is found in the
    mapping returned by ``get_constant_placeholder_dict`` is replaced by a
    new ``aten.clone`` call taking that entry as its input, and the
    corresponding output spec is rewritten to reference the clone by name.
    Entries not found in that mapping are left untouched.

    The exported program is modified in place; nothing is returned.
    """
    graph = exported_program.graph
    # NOTE(review): the membership test below is done with the node object
    # itself -- confirm what get_constant_placeholder_dict uses as keys.
    constant_placeholders = get_constant_placeholder_dict(exported_program)

    output_node = graph.find_nodes(op="output")[0]
    # Work on a mutable copy of the returned values; it is written back to
    # the output node once all replacements are done.
    returned = cast(list[torch.fx.Node], list(output_node.args[0]))
    specs = exported_program.graph_signature.output_specs
    assert len(returned) == len(specs)

    for idx, spec in enumerate(specs):
        source = returned[idx]
        if source in constant_placeholders:
            # Create the clone immediately after the node it copies.
            with graph.inserting_after(source):
                clone = graph.call_function(
                    exir_ops.edge.aten.clone.default, (source,)
                )
            assert "val" in source.meta
            clone.meta["val"] = source.meta["val"]
            returned[idx] = clone

            # The spec must name the clone, not the original placeholder.
            spec.arg = TensorArgument(name=clone.name)

    output_node.args = (returned,)
327
+
328
+
298
329
def constant_prop_pass (
299
330
exported_program : ExportedProgram ,
300
331
custom_skip_targets : Optional [set [EdgeOpOverload ]] = None ,
@@ -341,12 +372,12 @@ def constant_prop_pass(
341
372
342
373
# Generate new input spec.
343
374
new_input_specs = []
344
- for node in exported_program .graph .nodes :
345
- if node .op != "placeholder" :
346
- continue
375
+ for node in exported_program .graph .find_nodes (op = "placeholder" ):
347
376
new_input_specs .append (name_to_spec_dict [node .name ])
348
377
exported_program .graph_signature .input_specs = new_input_specs
349
378
379
+ _update_output_node_and_specs (exported_program )
380
+
350
381
# Cleanup the graph.
351
382
exported_program .graph .eliminate_dead_code ()
352
383
exported_program .graph_module .recompile ()
0 commit comments