
Commit 4ed40aa

angelayi authored and facebook-github-bot committed
Refactor delegation code (#4566)
Summary:
X-link: pytorch/pytorch#132773
Pull Request resolved: #4566

Refactoring partitioner-based delegation to prepare for allowing buffer mutations in the delegate (following diff).

Differential Revision: D60813405
1 parent f52d8ab commit 4ed40aa
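
The summary above describes partitioner-based delegation: to_backend(edge_program, partitioner) asks the partitioner to tag nodes, groups each tag into a submodule, lowers that submodule to the chosen backend, and splices an executorch_call_delegate call back into the top-level program. A minimal usage sketch of that entry point (the toy module and the XnnpackPartitioner choice are illustrative assumptions, not part of this commit):

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend


class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return x * y + y


# Export, convert to the Edge dialect, then delegate tagged partitions.
ep = torch.export.export(AddMul(), (torch.randn(2, 2), torch.randn(2, 2)))
edge = to_edge(ep)
delegated = to_backend(edge.exported_program(), XnnpackPartitioner())
print(delegated.graph_module.code)  # should show executorch_call_delegate call(s)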

File tree

4 files changed: +256 -180 lines changed

exir/backend/backend_api.py

Lines changed: 65 additions & 59 deletions
@@ -6,11 +6,12 @@
 
 import copy
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import singledispatch
 from typing import Generator, List
 
 import torch
+import torch.utils._pytree as pytree
 
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -25,12 +26,11 @@
 
 from executorch.exir.graph_module import get_control_flow_submodules
 from executorch.exir.lowered_backend_module import (
-    _get_new_signature,
+    _unsafe_adjust_original_program,
     create_exported_program_from_submodule,
     create_submodule_from_nodes,
     LoweredBackendModule,
 )
-from executorch.exir.pass_base import ExportPass
 from executorch.exir.program._fake_program import (
     get_fake_program,
     update_to_real_program,
@@ -193,6 +193,7 @@ def _partition_and_lower_one_graph_module(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool,
 ) -> torch.fx.GraphModule:
     """
     Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
@@ -210,21 +211,44 @@ def _partition_and_lower_one_graph_module(
 
         logging.debug(f"For tag {tag}, found nodes {node_list}")
         # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)
-        submodule, call_module_node = create_submodule_from_nodes(
-            tagged_graph_module, node_list, tag
+
+        replace_ctx = (
+            tagged_graph_module._set_replace_hook(
+                owning_program.graph_signature.get_replace_hook()
+            )
+            if not is_submodule
+            else nullcontext()
         )
+        with replace_ctx:
+            submodule, call_module_node = create_submodule_from_nodes(
+                tagged_graph_module, node_list, tag
+            )
+
         tagged_graph_module_output_node = [
             node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ]
+        ][0]
         submodule_output_node = [
             node for node in submodule.graph.nodes if node.op == "output"
-        ]
-        # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
-        submodule_output_node[0].meta = tagged_graph_module_output_node[0].meta
+        ][0]
+        # Copy the output node meta from the original output node, because
+        # create_submodule_from_nodes doesn't cover the meta field
+        submodule_output_node.meta = tagged_graph_module_output_node.meta
+        submodule_output_node.meta["val"] = pytree.tree_map(
+            lambda arg: arg.meta.get("val") if isinstance(arg, torch.fx.Node) else arg,
+            submodule_output_node.args,
+        )
         logging.debug(f"Partitioned graph module: {tagged_graph_module}")
 
-        submodule_program = create_exported_program_from_submodule(
-            submodule, owning_program, tag
+        (
+            submodule_program,
+            toplevel_input_specs_to_delete,
+            toplevel_output_specs_to_delete,
+        ) = create_exported_program_from_submodule(
+            submodule,
+            owning_program,
+            tag,
+            call_module_node,
+            is_submodule,
         )
 
         lowered_submodule = to_backend(
@@ -257,64 +281,48 @@ def _partition_and_lower_one_graph_module(
             call_delegate_node.meta["debug_handle"] = len(
                 tagged_graph_module.graph.nodes
            )
+            call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
             call_module_node.replace_all_uses_with(call_delegate_node)
             tagged_graph_module.graph.erase_node(call_module_node)
 
-        # Delete all parameters/buffers consumed by the created exported program
-        toplevel_signature = owning_program.graph_signature
-        for node in tagged_graph_module.graph.nodes:
-            # Find placeholders consumed by the delegate
-            if node.op != "placeholder" or len(node.users) != 0:
-                continue
-
-            if node.name in toplevel_signature.inputs_to_buffers:
-                # Delete the consumed buffers
-                buffer_name = toplevel_signature.inputs_to_buffers.get(node.name)
-                if buffer_name in owning_program.state_dict:
-                    owning_program.state_dict.pop(buffer_name)
-                else:
-                    owning_program.constants.pop(buffer_name)
-                tagged_graph_module.graph.erase_node(node)
-            elif node.name in toplevel_signature.inputs_to_parameters:
-                # Delete the consumed parameters
-                param_name = toplevel_signature.inputs_to_parameters.get(node.name)
-                owning_program.state_dict.pop(param_name)
-                tagged_graph_module.graph.erase_node(node)
-
-        tagged_graph_module.recompile()
+        if is_submodule:
+            assert len(toplevel_input_specs_to_delete) == 0
+            assert len(toplevel_output_specs_to_delete) == 0
+        elif (
+            len(toplevel_input_specs_to_delete) > 0
+            or len(toplevel_output_specs_to_delete) > 0
+        ):
+            _unsafe_adjust_original_program(
+                owning_program,
+                call_delegate_node,
+                toplevel_input_specs_to_delete,
+                toplevel_output_specs_to_delete,
+            )
+
     return tagged_graph_module
 
 
 def _partition_and_lower(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow
     """
 
     partitioned_module = _partition_and_lower_one_graph_module(
-        tagged_graph_module, partition_result, owning_program
+        tagged_graph_module, partition_result, owning_program, is_submodule
     )
 
     # Recursively partition and lower for submodules
     for name, submod, _node in get_control_flow_submodules(partitioned_module):
         partitioned_submodule = _partition_and_lower(
-            submod, partition_result, owning_program
+            submod, partition_result, owning_program, is_submodule=True
         )
         tagged_graph_module.add_module(name, partitioned_submodule)
 
-    # Run the export pass over the graph module so that the call delegate
-    # nodes will match Edge dialect
-    # TODO(angelayi): ExportPass will rerun the graph, however all we need
-    # here is to add metadata to the call delegate nodes to preserve Edge
-    # dialect. There's work going on to generate a random tensor from a
-    # fake tensor and possibly it can help to address the issue.
-    res = ExportPass()(tagged_graph_module)
-    assert res is not None
-    tagged_graph_module = res.graph_module
-
     return tagged_graph_module
 
 
@@ -349,6 +357,8 @@ def to_backend(
     Returns:
         ExportedProgram: The input program, with some portions targeted for delegation.
     """
+    edge_program._validate()
+
     # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
     # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
     try:
@@ -377,26 +387,22 @@
     update_to_real_program(tagged_exported_program, edge_program)
 
     for tag, _ in partitioner_result.partition_tags.items():
-        _maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
+        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)
 
     tagged_graph_module = _partition_and_lower(
-        tagged_exported_program.graph_module, partitioner_result, edge_program
+        tagged_exported_program.graph_module,
+        partitioner_result,
+        tagged_exported_program,
     )
 
-    # TODO(angelayi): Update this signature in a less manual way (maybe through
-    # retracing)
-    new_signature, new_state_dict, new_constants = _get_new_signature(
-        edge_program,
-        tagged_graph_module,
-    )
     return ExportedProgram(
         root=tagged_graph_module,
         graph=tagged_graph_module.graph,
-        graph_signature=new_signature,
-        state_dict=new_state_dict,
-        range_constraints=copy.deepcopy(edge_program.range_constraints),
-        module_call_graph=copy.deepcopy(edge_program.module_call_graph),
+        graph_signature=tagged_exported_program.graph_signature,
+        state_dict=tagged_exported_program.state_dict,
+        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
+        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
         example_inputs=None,
-        constants=new_constants,
-        verifiers=[edge_program.verifier],
+        constants=tagged_exported_program.constants,
+        verifiers=[tagged_exported_program.verifier],
     )
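
The replace_ctx block added above uses a small standard-library pattern: choose either a real context manager or contextlib.nullcontext() at runtime, so the with-body is written only once (here, the graph-signature replace hook is installed only for the top-level module, not for control-flow submodules). A standalone sketch of the same pattern, with a made-up function and lock argument for illustration:

import threading
from contextlib import nullcontext


def run_guarded(fn, lock=None):
    # Same shape as replace_ctx: a real context manager when one applies,
    # contextlib.nullcontext() as the no-op fallback, a single `with` body.
    ctx = lock if lock is not None else nullcontext()
    with ctx:
        return fn()


# Works identically with or without the lock.
run_guarded(lambda: print("no lock"))
run_guarded(lambda: print("with lock"), threading.Lock())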

exir/backend/test/test_backends.py

Lines changed: 0 additions & 18 deletions
@@ -1270,21 +1270,3 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
-
-    def test_get_new_signature(self):
-        class MyModule(torch.nn.Module):
-            def forward(self, x, y, z):
-                return x + y, y - z, z * x
-
-        ep = torch.export.export(
-            MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
-        )
-        sig, *_ = _get_new_signature(ep, ep.graph_module)
-        output_names = set()
-        self.assertEqual(len(sig.output_specs), 3)
-        for s in sig.output_specs:
-            self.assertEqual(s.kind, OutputKind.USER_OUTPUT)
-            self.assertIsInstance(s.arg, TensorArgument)
-            name = s.arg.name
-            self.assertNotIn(name, output_names)
-            output_names.add(name)

exir/backend/utils.py

Lines changed: 0 additions & 2 deletions
@@ -208,7 +208,6 @@ def _assign_new_tag(
 def _maybe_duplicate_constant_nodes(
     tagged_exported_program: ExportedProgram,
     tag: str,
-    owning_program: ExportedProgram,
 ) -> None:
     """
     If the constants node is shared by different tagged nodes, like
@@ -241,7 +240,6 @@ def _maybe_duplicate_constant_nodes(
         copied_nodes = copied_nodes.union(
             duplicate_constant_node(tagged_exported_program, candidate_node)
         )
-        duplicate_constant_node(owning_program, candidate_node)
     candidate_node_with_copies = candidate_nodes.union(copied_nodes)
     _assign_new_tag(tagged_exported_program, candidate_node_with_copies)

0 commit comments
