Skip to content

Commit a28a111

Browse files
angelayifacebook-github-bot
authored andcommitted
Fix handling constant inputs when delegating (#3031)
Summary: Pull Request resolved: #3031 Differential Revision: D56089279
1 parent 21fdc4e commit a28a111

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

exir/lowered_backend_module.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,17 +454,28 @@ def _get_new_signature( # noqa: C901
454454
new_state_dict = {}
455455
new_constants = {}
456456

457-
input_tensor_node_to_sig = {
458-
input_spec.arg.name: input_spec
459-
for input_spec in old_signature.input_specs
460-
if isinstance(input_spec.arg, TensorArgument)
461-
}
457+
if tag is None:
458+
# This is only the case where we're reconstructing the graph signature
459+
# for the toplevel graph
460+
placeholder_nodes = [
461+
node.name
462+
for node in original_program.graph.nodes
463+
if node.op == "placeholder"
464+
]
465+
assert len(placeholder_nodes) == len(old_signature.input_specs)
466+
input_node_to_sig = dict(zip(placeholder_nodes, old_signature.input_specs))
467+
else:
468+
input_node_to_sig = {
469+
input_spec.arg.name: input_spec
470+
for input_spec in old_signature.input_specs
471+
if isinstance(input_spec.arg, TensorArgument)
472+
}
462473

463474
for node in gm.graph.nodes:
464475
is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
465476
if node.op == "placeholder":
466477

467-
if node.name not in input_tensor_node_to_sig:
478+
if node.name not in input_node_to_sig:
468479
assert tag is not None
469480
input_specs.append(
470481
InputSpec(
@@ -475,7 +486,7 @@ def _get_new_signature( # noqa: C901
475486
)
476487
continue
477488

478-
orig_input_spec = input_tensor_node_to_sig[node.name]
489+
orig_input_spec = input_node_to_sig[node.name]
479490

480491
if not isinstance(orig_input_spec.arg, TensorArgument):
481492
input_specs.append(orig_input_spec)

0 commit comments

Comments
 (0)