Skip to content

Commit 914f642

Browse files
angelayi authored and facebook-github-bot committed
Fix creating graph signature in delegation (#2580)
Summary: Pull Request resolved: #2580 Reviewed By: JacobSzwejbka Differential Revision: D55225522 fbshipit-source-id: 3f665ecebaa555879dc0adb4e2295f6522f4c6af
1 parent 001cc5f commit 914f642

File tree

1 file changed

+91
-58
lines changed

1 file changed

+91
-58
lines changed

exir/lowered_backend_module.py

Lines changed: 91 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from torch._subclasses import FakeTensor
2929
from torch.export.exported_program import (
30+
ConstantArgument,
3031
ExportedProgram,
3132
ExportGraphSignature,
3233
InputKind,
@@ -422,7 +423,7 @@ def arrange_graph_placeholders(
422423

423424

424425
# TODO Don't regenerate new signature manually.
425-
def _get_new_signature(
426+
def _get_new_signature( # noqa: C901
426427
original_program: ExportedProgram,
427428
gm: torch.fx.GraphModule,
428429
tag: Optional[str] = None,
@@ -431,6 +432,18 @@ def _get_new_signature(
431432
Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
432433
Dict[str, Union[torch.Tensor, torch.ScriptObject]],
433434
]:
435+
"""
436+
Args:
437+
tag: If tag is None, this means that we are constructing the graph
438+
signature for the toplevel graph, after delegation. We need to do this
439+
because sometimes delegates will swallow some parameters/buffers, so we
440+
need to update the graph signature/state dict to reflect these changes.
441+
Otherwise, if tag is not None, this means we are constructing the graph
442+
signature for the delegated modules. In this case, we need to look
443+
through the input nodes and see which ones were originally
444+
parameters/buffers, and lower them down to the delegate.
445+
"""
446+
434447
old_signature = original_program.graph_signature
435448

436449
input_specs = []
@@ -441,84 +454,104 @@ def _get_new_signature(
441454
new_state_dict = {}
442455
new_constants = {}
443456

444-
non_persistent_buffers = set(old_signature.non_persistent_buffers)
457+
input_tensor_node_to_sig = {
458+
input_spec.arg.name: input_spec
459+
for input_spec in old_signature.input_specs
460+
if isinstance(input_spec.arg, TensorArgument)
461+
}
445462

446463
for node in gm.graph.nodes:
447-
is_tagged = node.meta.get("delegation_tag", None) == tag
464+
is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
448465
if node.op == "placeholder":
449-
if node.name in old_signature.inputs_to_parameters and is_tagged:
450-
parameter_name = old_signature.inputs_to_parameters[node.name]
451-
# add param to graph signature
452-
input_specs.append(
453-
InputSpec(
454-
kind=InputKind.PARAMETER,
455-
arg=TensorArgument(name=node.name),
456-
target=parameter_name,
457-
)
458-
)
459466

460-
# add param to state_dict
461-
new_state_dict[parameter_name] = original_program.state_dict[
462-
parameter_name
463-
]
464-
elif node.name in old_signature.inputs_to_buffers and is_tagged:
465-
buffer_name = old_signature.inputs_to_buffers[node.name]
466-
persistent = buffer_name not in non_persistent_buffers
467-
# add buffer to graph signature
467+
if node.name not in input_tensor_node_to_sig:
468+
assert tag is not None
468469
input_specs.append(
469470
InputSpec(
470-
kind=InputKind.BUFFER,
471+
kind=InputKind.USER_INPUT,
471472
arg=TensorArgument(name=node.name),
472-
target=buffer_name,
473-
persistent=persistent,
473+
target=None,
474474
)
475475
)
476+
continue
476477

477-
# add param to new_state_dict
478-
if persistent:
479-
new_state_dict[buffer_name] = original_program.state_dict[
480-
buffer_name
481-
]
482-
else:
483-
new_constants[buffer_name] = original_program.constants[buffer_name]
484-
elif (
485-
node.name in old_signature.inputs_to_lifted_tensor_constants
486-
and is_tagged
487-
):
488-
constant_name = old_signature.inputs_to_lifted_tensor_constants[
489-
node.name
490-
]
491-
# add constant to graph signature
492-
input_specs.append(
493-
InputSpec(
494-
kind=InputKind.CONSTANT_TENSOR,
495-
arg=TensorArgument(name=node.name),
496-
target=constant_name,
478+
orig_input_spec = input_tensor_node_to_sig[node.name]
479+
480+
if not isinstance(orig_input_spec.arg, TensorArgument):
481+
input_specs.append(orig_input_spec)
482+
483+
elif is_tagged:
484+
input_specs.append(orig_input_spec)
485+
486+
if orig_input_spec.kind in (InputKind.PARAMETER, InputKind.BUFFER):
487+
new_state_dict[orig_input_spec.target] = (
488+
original_program.state_dict[orig_input_spec.target]
497489
)
498-
)
490+
elif orig_input_spec.kind in (
491+
InputKind.CONSTANT_TENSOR,
492+
InputKind.CUSTOM_OBJ,
493+
):
494+
new_constants[orig_input_spec.target] = original_program.constants[
495+
orig_input_spec.target
496+
]
499497

500-
# add constant to new_constants
501-
new_constants[constant_name] = original_program.constants[constant_name]
502498
else:
503-
# not param, buffer, or lifted_tensor_constant then user input
504499
input_specs.append(
505500
InputSpec(
506501
kind=InputKind.USER_INPUT,
507502
arg=TensorArgument(name=node.name),
508503
target=None,
509504
)
510505
)
506+
511507
if node.op == "output":
512-
for output in pytree.tree_leaves((node.args, node.kwargs)):
513-
if not isinstance(output, torch.fx.Node):
514-
continue
515-
output_specs.append(
516-
OutputSpec(
517-
kind=OutputKind.USER_OUTPUT,
518-
arg=TensorArgument(name=output.name),
519-
target=None,
520-
)
521-
)
508+
output_nodes = pytree.tree_leaves((node.args, node.kwargs))
509+
510+
if tag is not None:
511+
# We are constructing output_specs for the delegate outputs.
512+
# These don't have any buffer mutations.
513+
514+
for output_node in output_nodes:
515+
if not isinstance(output_node, torch.fx.Node):
516+
output_specs.append(
517+
OutputSpec(
518+
kind=OutputKind.USER_OUTPUT,
519+
arg=ConstantArgument(output_node),
520+
target=None,
521+
)
522+
)
523+
else:
524+
output_specs.append(
525+
OutputSpec(
526+
kind=OutputKind.USER_OUTPUT,
527+
arg=TensorArgument(name=output_node.name),
528+
target=None,
529+
)
530+
)
531+
532+
else:
533+
# We are reconstructing the toplevel module which contains
534+
# delegates. Delegation should not change the number of outputs
535+
# in the toplevel module, and it does not touch the mutated buffers
536+
537+
assert len(old_signature.output_specs) == len(output_nodes)
538+
for prev_output_spec, output_node in zip(
539+
old_signature.output_specs, output_nodes
540+
):
541+
if not isinstance(output_node, torch.fx.Node):
542+
assert isinstance(prev_output_spec.arg, ConstantArgument)
543+
output_specs.append(
544+
OutputSpec(
545+
kind=OutputKind.USER_OUTPUT,
546+
arg=ConstantArgument(output_node),
547+
target=None,
548+
)
549+
)
550+
551+
else:
552+
new_output_spec = copy.deepcopy(prev_output_spec)
553+
new_output_spec.arg.name = output_node.name
554+
output_specs.append(new_output_spec)
522555

523556
return new_signature, new_state_dict, new_constants
524557

0 commit comments

Comments
 (0)