Skip to content

Commit 403a775

Browse files
angelayifacebook-github-bot
authored andcommitted
Fix handling constant inputs when delegating (#3031)
Summary: Pull Request resolved: #3031 Reviewed By: mcr229 Differential Revision: D56089279
1 parent 21fdc4e commit 403a775

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

exir/lowered_backend_module.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,17 +454,17 @@ def _get_new_signature( # noqa: C901
454454
new_state_dict = {}
455455
new_constants = {}
456456

457-
input_tensor_node_to_sig = {
458-
input_spec.arg.name: input_spec
459-
for input_spec in old_signature.input_specs
460-
if isinstance(input_spec.arg, TensorArgument)
461-
}
457+
placeholder_nodes = [
458+
node.name for node in original_program.graph.nodes if node.op == "placeholder"
459+
]
460+
assert len(placeholder_nodes) == len(old_signature.input_specs)
461+
input_node_to_sig = dict(zip(placeholder_nodes, old_signature.input_specs))
462462

463463
for node in gm.graph.nodes:
464464
is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
465465
if node.op == "placeholder":
466466

467-
if node.name not in input_tensor_node_to_sig:
467+
if node.name not in input_node_to_sig:
468468
assert tag is not None
469469
input_specs.append(
470470
InputSpec(
@@ -475,7 +475,7 @@ def _get_new_signature( # noqa: C901
475475
)
476476
continue
477477

478-
orig_input_spec = input_tensor_node_to_sig[node.name]
478+
orig_input_spec = input_node_to_sig[node.name]
479479

480480
if not isinstance(orig_input_spec.arg, TensorArgument):
481481
input_specs.append(orig_input_spec)

0 commit comments

Comments
 (0)