Skip to content

Commit ebfb955

Browse files
angelayi authored and facebook-github-bot committed
Refactor delegation code (#4566)
Summary:
X-link: pytorch/pytorch#132773
Pull Request resolved: #4566
Refactoring partitioner-based delegation to prepare for allowing buffer mutations in the delegate (following diff).
Reviewed By: cccclai
Differential Revision: D60813405
1 parent caadd81 commit ebfb955

File tree

4 files changed

+252
-180
lines changed

4 files changed

+252
-180
lines changed

exir/backend/backend_api.py

Lines changed: 61 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66

77
import copy
88
import logging
9-
from contextlib import contextmanager
9+
from contextlib import contextmanager, nullcontext
1010
from functools import singledispatch
1111
from typing import Generator, List
1212

1313
import torch
14+
import torch.utils._pytree as pytree
1415

1516
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
1617
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -25,12 +26,11 @@
2526

2627
from executorch.exir.graph_module import get_control_flow_submodules
2728
from executorch.exir.lowered_backend_module import (
28-
_get_new_signature,
29+
_unsafe_adjust_original_program,
2930
create_exported_program_from_submodule,
3031
create_submodule_from_nodes,
3132
LoweredBackendModule,
3233
)
33-
from executorch.exir.pass_base import ExportPass
3434
from executorch.exir.program._fake_program import (
3535
get_fake_program,
3636
update_to_real_program,
@@ -193,6 +193,7 @@ def _partition_and_lower_one_graph_module(
193193
tagged_graph_module: torch.fx.GraphModule,
194194
partition_result: PartitionResult,
195195
owning_program: ExportedProgram,
196+
is_submodule: bool,
196197
) -> torch.fx.GraphModule:
197198
"""
198199
Partitions and lowers the graph module based on the partition tags; this handles a single graph module.
@@ -210,21 +211,40 @@ def _partition_and_lower_one_graph_module(
210211

211212
logging.debug(f"For tag {tag}, found nodes {node_list}")
212213
# Tag the nodes that are params as buffers, so we can order the submodule as (Params + Buffers) (User Inputs)
213-
submodule, call_module_node = create_submodule_from_nodes(
214-
tagged_graph_module, node_list, tag
214+
215+
replace_ctx = (
216+
tagged_graph_module._set_replace_hook(
217+
owning_program.graph_signature.get_replace_hook()
218+
)
219+
if not is_submodule
220+
else nullcontext()
215221
)
222+
with replace_ctx:
223+
submodule, call_module_node = create_submodule_from_nodes(
224+
tagged_graph_module, node_list, tag
225+
)
226+
216227
tagged_graph_module_output_node = [
217228
node for node in tagged_graph_module.graph.nodes if node.op == "output"
218-
]
229+
][0]
219230
submodule_output_node = [
220231
node for node in submodule.graph.nodes if node.op == "output"
221-
]
222-
# Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
223-
submodule_output_node[0].meta = tagged_graph_module_output_node[0].meta
232+
][0]
233+
# Copy the output node meta from the original output node, because
234+
# create_submodule_from_nodes doesn't cover the meta field
235+
submodule_output_node.meta = tagged_graph_module_output_node.meta
224236
logging.debug(f"Partitioned graph module: {tagged_graph_module}")
225237

226-
submodule_program = create_exported_program_from_submodule(
227-
submodule, owning_program, tag
238+
(
239+
submodule_program,
240+
toplevel_input_specs_to_delete,
241+
toplevel_output_specs_to_delete,
242+
) = create_exported_program_from_submodule(
243+
submodule,
244+
owning_program,
245+
tag,
246+
call_module_node,
247+
is_submodule,
228248
)
229249

230250
lowered_submodule = to_backend(
@@ -257,64 +277,48 @@ def _partition_and_lower_one_graph_module(
257277
call_delegate_node.meta["debug_handle"] = len(
258278
tagged_graph_module.graph.nodes
259279
)
280+
call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
260281
call_module_node.replace_all_uses_with(call_delegate_node)
261282
tagged_graph_module.graph.erase_node(call_module_node)
262283

263-
# Delete all parameters/buffers consumed by the created exported program
264-
toplevel_signature = owning_program.graph_signature
265-
for node in tagged_graph_module.graph.nodes:
266-
# Find placeholders consumed by the delegate
267-
if node.op != "placeholder" or len(node.users) != 0:
268-
continue
269-
270-
if node.name in toplevel_signature.inputs_to_buffers:
271-
# Delete the consumed buffers
272-
buffer_name = toplevel_signature.inputs_to_buffers.get(node.name)
273-
if buffer_name in owning_program.state_dict:
274-
owning_program.state_dict.pop(buffer_name)
275-
else:
276-
owning_program.constants.pop(buffer_name)
277-
tagged_graph_module.graph.erase_node(node)
278-
elif node.name in toplevel_signature.inputs_to_parameters:
279-
# Delete the consumed parameters
280-
param_name = toplevel_signature.inputs_to_parameters.get(node.name)
281-
owning_program.state_dict.pop(param_name)
282-
tagged_graph_module.graph.erase_node(node)
283-
284-
tagged_graph_module.recompile()
284+
if is_submodule:
285+
assert len(toplevel_input_specs_to_delete) == 0
286+
assert len(toplevel_output_specs_to_delete) == 0
287+
elif (
288+
len(toplevel_input_specs_to_delete) > 0
289+
or len(toplevel_output_specs_to_delete) > 0
290+
):
291+
_unsafe_adjust_original_program(
292+
owning_program,
293+
call_delegate_node,
294+
toplevel_input_specs_to_delete,
295+
toplevel_output_specs_to_delete,
296+
)
297+
285298
return tagged_graph_module
286299

287300

288301
def _partition_and_lower(
289302
tagged_graph_module: torch.fx.GraphModule,
290303
partition_result: PartitionResult,
291304
owning_program: ExportedProgram,
305+
is_submodule: bool = False,
292306
) -> torch.fx.GraphModule:
293307
"""
294308
Partitions the graph module into submodules based on tags, and then lowers the nodes with the same tag as one lowered module, including the submodules from control flow.
295309
"""
296310

297311
partitioned_module = _partition_and_lower_one_graph_module(
298-
tagged_graph_module, partition_result, owning_program
312+
tagged_graph_module, partition_result, owning_program, is_submodule
299313
)
300314

301315
# Recursively partition and lower for submodules
302316
for name, submod, _node in get_control_flow_submodules(partitioned_module):
303317
partitioned_submodule = _partition_and_lower(
304-
submod, partition_result, owning_program
318+
submod, partition_result, owning_program, is_submodule=True
305319
)
306320
tagged_graph_module.add_module(name, partitioned_submodule)
307321

308-
# Run the export pass over the graph module so that the call delegate
309-
# nodes will match Edge dialect
310-
# TODO(angelayi): ExportPass will rerun the graph, however all we need
311-
# here is to add metadata to the call delegate nodes to preserve Edge
312-
# dialect. There's work going on to generate a random tensor from a
313-
# fake tensor and possibly it can help to address the issue.
314-
res = ExportPass()(tagged_graph_module)
315-
assert res is not None
316-
tagged_graph_module = res.graph_module
317-
318322
return tagged_graph_module
319323

320324

@@ -349,6 +353,8 @@ def to_backend(
349353
Returns:
350354
ExportedProgram: The input program, with some portions targeted for delegation.
351355
"""
356+
edge_program._validate()
357+
352358
# Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
353359
# Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
354360
try:
@@ -377,26 +383,22 @@ def to_backend(
377383
update_to_real_program(tagged_exported_program, edge_program)
378384

379385
for tag, _ in partitioner_result.partition_tags.items():
380-
_maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
386+
_maybe_duplicate_constant_nodes(tagged_exported_program, tag)
381387

382388
tagged_graph_module = _partition_and_lower(
383-
tagged_exported_program.graph_module, partitioner_result, edge_program
389+
tagged_exported_program.graph_module,
390+
partitioner_result,
391+
tagged_exported_program,
384392
)
385393

386-
# TODO(angelayi): Update this signature in a less manual way (maybe through
387-
# retracing)
388-
new_signature, new_state_dict, new_constants = _get_new_signature(
389-
edge_program,
390-
tagged_graph_module,
391-
)
392394
return ExportedProgram(
393395
root=tagged_graph_module,
394396
graph=tagged_graph_module.graph,
395-
graph_signature=new_signature,
396-
state_dict=new_state_dict,
397-
range_constraints=copy.deepcopy(edge_program.range_constraints),
398-
module_call_graph=copy.deepcopy(edge_program.module_call_graph),
397+
graph_signature=tagged_exported_program.graph_signature,
398+
state_dict=tagged_exported_program.state_dict,
399+
range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
400+
module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
399401
example_inputs=None,
400-
constants=new_constants,
401-
verifiers=[edge_program.verifier],
402+
constants=tagged_exported_program.constants,
403+
verifiers=[tagged_exported_program.verifier],
402404
)

exir/backend/test/test_backends.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,21 +1270,3 @@ def forward(self, x: List[torch.Tensor]):
12701270

12711271
gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
12721272
gm(*inputs)
1273-
1274-
def test_get_new_signature(self):
1275-
class MyModule(torch.nn.Module):
1276-
def forward(self, x, y, z):
1277-
return x + y, y - z, z * x
1278-
1279-
ep = torch.export.export(
1280-
MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
1281-
)
1282-
sig, *_ = _get_new_signature(ep, ep.graph_module)
1283-
output_names = set()
1284-
self.assertEqual(len(sig.output_specs), 3)
1285-
for s in sig.output_specs:
1286-
self.assertEqual(s.kind, OutputKind.USER_OUTPUT)
1287-
self.assertIsInstance(s.arg, TensorArgument)
1288-
name = s.arg.name
1289-
self.assertNotIn(name, output_names)
1290-
output_names.add(name)

exir/backend/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ def _assign_new_tag(
208208
def _maybe_duplicate_constant_nodes(
209209
tagged_exported_program: ExportedProgram,
210210
tag: str,
211-
owning_program: ExportedProgram,
212211
) -> None:
213212
"""
214213
If the constants node is shared by different tagged nodes, like
@@ -241,7 +240,6 @@ def _maybe_duplicate_constant_nodes(
241240
copied_nodes = copied_nodes.union(
242241
duplicate_constant_node(tagged_exported_program, candidate_node)
243242
)
244-
duplicate_constant_node(owning_program, candidate_node)
245243
candidate_node_with_copies = candidate_nodes.union(copied_nodes)
246244
_assign_new_tag(tagged_exported_program, candidate_node_with_copies)
247245

0 commit comments

Comments
 (0)