@@ -454,17 +454,28 @@ def _get_new_signature( # noqa: C901
454
454
new_state_dict = {}
455
455
new_constants = {}
456
456
457
- input_tensor_node_to_sig = {
458
- input_spec .arg .name : input_spec
459
- for input_spec in old_signature .input_specs
460
- if isinstance (input_spec .arg , TensorArgument )
461
- }
457
+ if tag is None :
458
+ # This is only the case where we're reconstructing the graph signature
459
+ # for the toplevel graph
460
+ placeholder_nodes = [
461
+ node .name
462
+ for node in original_program .graph .nodes
463
+ if node .op == "placeholder"
464
+ ]
465
+ assert len (placeholder_nodes ) == len (old_signature .input_specs )
466
+ input_node_to_sig = dict (zip (placeholder_nodes , old_signature .input_specs ))
467
+ else :
468
+ input_node_to_sig = {
469
+ input_spec .arg .name : input_spec
470
+ for input_spec in old_signature .input_specs
471
+ if isinstance (input_spec .arg , TensorArgument )
472
+ }
462
473
463
474
for node in gm .graph .nodes :
464
475
is_tagged = tag is None or node .meta .get ("delegation_tag" , None ) == tag
465
476
if node .op == "placeholder" :
466
477
467
- if node .name not in input_tensor_node_to_sig :
478
+ if node .name not in input_node_to_sig :
468
479
assert tag is not None
469
480
input_specs .append (
470
481
InputSpec (
@@ -475,7 +486,7 @@ def _get_new_signature( # noqa: C901
475
486
)
476
487
continue
477
488
478
- orig_input_spec = input_tensor_node_to_sig [node .name ]
489
+ orig_input_spec = input_node_to_sig [node .name ]
479
490
480
491
if not isinstance (orig_input_spec .arg , TensorArgument ):
481
492
input_specs .append (orig_input_spec )
0 commit comments