Skip to content

Commit 5719d28

Browse files
angelayi authored and facebook-github-bot committed
Fix delegates
Summary: With Yidi's migration in D46729844, the typing for the delegates is incorrect/inconsistent, with to_backend and preprocess taking in ExirExportedProgram sometimes, and other times just taking in an fx.GraphModule (pyre-strict is not enabled for the backend/ folder, so it is not caught by typechecking). This diff updates the delegate APIs to take in ExportedProgram. https://docs.google.com/document/d/1vpvjnGk1TWdnVbgrzuTBJTOFx2D06GTaMoZIrSipWAc/edit Reviewed By: mergennachin Differential Revision: D47252888 fbshipit-source-id: 647e26d49dcafd2d08b1a91c1565515c31159ef8
1 parent cc28f49 commit 5719d28

26 files changed

+320
-421
lines changed

backends/backend_api.py

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@
1212
from executorch.backends.partitioner import Partitioner, TPartitioner
1313
from executorch.backends.utils import is_identical_graph
1414
from executorch.exir import (
15-
attach_export_graph_metadata,
1615
CallSpec,
17-
ExirExportedProgram,
1816
ExportGraphSignature,
19-
get_exir_meta,
2017
MultiMethodExirExportedProgram,
21-
pytree,
2218
)
2319

2420
from executorch.exir.delegate import (
2521
create_submodule_from_nodes,
2622
executorch_call_delegate,
2723
get_lowered_module_name,
2824
LoweredBackendModule,
29-
patch_lowered_functions,
3025
)
31-
from executorch.exir.graph_module import get_control_flow_submodules
26+
from executorch.exir.graph_module import (
27+
attach_export_graph_metadata,
28+
ExirMetadata,
29+
get_control_flow_submodules,
30+
)
3231
from executorch.exir.pass_base import ExportPass
32+
from torch._export.exported_program import ExportedProgram
3333

3434

3535
@singledispatch
@@ -39,7 +39,7 @@ def to_backend(args):
3939
4040
def to_backend(
4141
backend_id: str,
42-
edge_graph_module: torch.fx.GraphModule,
42+
edge_graph_module: ExportedProgram,
4343
compile_specs: List[CompileSpec],
4444
) -> LoweredBackendModule:
4545
@@ -58,23 +58,24 @@ def to_backend(
5858
@to_backend.register
5959
def _(
6060
backend_id: str,
61-
edge_graph_module: torch.fx.GraphModule,
61+
edge_program: ExportedProgram,
6262
compile_specs: List[CompileSpec],
6363
) -> LoweredBackendModule:
6464
"""
6565
Add overloaded implementations for to_backend:
6666
def to_backend(
6767
backend_id: str,
68-
edge_graph_module: torch.fx.GraphModule,
68+
edge_program: ExportedProgram,
6969
compile_specs: List[CompileSpec],
7070
) -> LoweredBackendModule:
71-
Requires the passed in Module in Edge dialect to be executed in the backend identified
72-
by backend_id. The forward method of the given edge_graph_module will be
73-
targeted for execution.
71+
Requires the passed in exported program in Edge dialect to be executed in
72+
the backend identified by backend_id. The forward method of the given
73+
edge_graph_module will be targeted for execution.
7474
7575
Args:
7676
backend_id: The backend identifier.
77-
edge_graph_module: A module in Edge dialect to target for lowering to the backend.
77+
exported_program: An exported program in Edge dialect to target for
78+
lowering to the backend.
7879
compile_specs: A list of backend-specific objects with static
7980
metadata to configure the "compilation" process (e.g. it could be
8081
another dictionary itself).
@@ -83,7 +84,7 @@ def to_backend(
8384
LoweredBackendModule: A Module that has been lowered to the target backend.
8485
Internally, the lowered Module contains these special attributes:
8586
backend_id (str: backend id), __processed_module__ (str: a compiled module)
86-
compile_spec, original_module (original exported module)
87+
compile_spec, original_module (original exported program)
8788
8889
Raises:
8990
NotImplementedError: The backend is not implemented (e.g. it was not found).
@@ -93,18 +94,17 @@ def to_backend(
9394
# All backend implementation are final, so we don't need to consider nested subclasses.
9495
for cls in BackendDetails.__subclasses__():
9596
if backend_id == cls.__name__:
96-
copied_graph_module = copy.deepcopy(edge_graph_module)
97+
copied_edge_program = copy.deepcopy(edge_program)
9798
processed_bytes = cls.preprocess(
98-
copied_graph_module,
99+
copied_edge_program,
99100
compile_specs,
100101
)
101102
lowered_module = LoweredBackendModule(
102-
edge_graph_module,
103+
edge_program,
103104
backend_id,
104105
processed_bytes,
105106
compile_specs,
106107
)
107-
patch_lowered_functions(lowered_module)
108108
return lowered_module
109109
raise NotImplementedError(f"Backend {backend_id} was not found.")
110110

@@ -156,9 +156,26 @@ def _partition_and_lower(
156156
)
157157
logging.debug(f"Partitioned graph module: {tagged_graph_module}")
158158

159+
# TODO(T158558782): Update the metadata once we migrate to torch.export
160+
submodule_program = ExportedProgram(
161+
submodule,
162+
submodule.graph,
163+
ExportGraphSignature([], [], [], [], {}, {}, {}, None),
164+
CallSpec(None, None),
165+
{},
166+
{},
167+
[],
168+
)
169+
meta = ExirMetadata(
170+
in_spec=None,
171+
out_spec=None,
172+
update_spec=0,
173+
)
174+
attach_export_graph_metadata(submodule_program.graph_module, meta)
175+
159176
lowered_submodule = to_backend(
160177
delegation_spec.backend_id,
161-
submodule,
178+
submodule_program,
162179
delegation_spec.compile_specs,
163180
)
164181

@@ -199,22 +216,22 @@ def _partition_and_lower(
199216

200217
@to_backend.register
201218
def _(
202-
edge_graph_module: torch.fx.GraphModule,
219+
edge_program: ExportedProgram,
203220
partitioner: Type[TPartitioner],
204-
) -> torch.fx.GraphModule:
221+
) -> ExportedProgram:
205222
"""
206223
Add overloaded implementations for to_backend:
207224
def to_backend(
208-
edge_graph_module: torch.fx.GraphModule,
225+
edge_program: ExportedProgram,
209226
partitioner: Type[TPartitioner],
210-
) -> torch.fx.GraphModule
227+
) -> ExportedProgram:
211228
212229
Returns a semantically-equivalent program to the one given as input (represented
213230
as a graph module in Edge dialect), but with portions of the program targeted for
214231
delegation as determined by the partitioner.
215232
216233
Args:
217-
torch.fx.GraphModule: Program in Edge dialect.
234+
ExportedProgram: Program in Edge dialect.
218235
219236
partitioner: An instance of the Partitioner class type, in charge with tagging
220237
portions of the input program for delegation. A valid partitioner must have
@@ -224,8 +241,9 @@ def to_backend(
224241
225242
226243
Returns:
227-
torch.fx.GraphModule: The input program, with some portions targeted for delegation.
244+
ExportedProgram: The input program, with some portions targeted for delegation.
228245
"""
246+
edge_graph_module = edge_program.graph_module
229247
copied_graph_module = copy.deepcopy(edge_graph_module)
230248
# Call the partitioner on the given graph module
231249
partitioner_instance: Partitioner = partitioner()
@@ -249,7 +267,8 @@ def to_backend(
249267
tagged_graph_module, partitioner_instance
250268
)
251269

252-
return tagged_graph_module
270+
edge_program.graph_module = tagged_graph_module
271+
return edge_program
253272

254273

255274
def to_backend_multiple(
@@ -287,35 +306,18 @@ def to_backend_multiple(
287306
+ "partitioner subclass, or a partitioner subclass."
288307
)
289308

290-
method_name_to_delegated_gm = {}
309+
method_name_to_delegated_program = {}
291310
for method_name, prog in multi_method_program.methods().items():
292-
gm = prog.graph_module
293311
if isinstance(partitioner, dict):
294312
if method_name in partitioner:
295-
method_name_to_delegated_gm[method_name] = to_backend(
296-
gm, partitioner[method_name]
313+
method_name_to_delegated_program[method_name] = to_backend(
314+
prog, partitioner[method_name]
297315
)
298316
else:
299-
method_name_to_delegated_gm[method_name] = gm
317+
method_name_to_delegated_program[method_name] = prog
300318
else:
301-
method_name_to_delegated_gm[method_name] = to_backend(gm, partitioner)
302-
303-
def gm_to_program(gm: torch.fx.GraphModule):
304-
ep = ExirExportedProgram(
305-
gm,
306-
gm.graph,
307-
ExportGraphSignature([], [], [], [], {}, {}, {}, None),
308-
CallSpec(None, None),
309-
{},
310-
{},
311-
[],
312-
True,
313-
)
314-
ep.graph_module.meta.update(gm.meta)
315-
attach_export_graph_metadata(ep.graph_module, get_exir_meta(gm))
316-
return ep
319+
method_name_to_delegated_program[method_name] = to_backend(
320+
prog, partitioner
321+
)
317322

318-
method_name_to_delegated_program = pytree.tree_map(
319-
gm_to_program, method_name_to_delegated_gm
320-
)
321323
return MultiMethodExirExportedProgram(method_name_to_delegated_program)

backends/backend_details.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from abc import ABC, abstractmethod
22

3-
from typing import Callable, Dict, List
4-
5-
import torch
3+
from typing import List
64

75
from executorch.backends.compile_spec_schema import CompileSpec
8-
from torch.fx.node import Node
6+
from torch._export.exported_program import ExportedProgram
97

108

119
def enforcedmethod(func):
@@ -31,7 +29,7 @@ class BackendDetails(ABC):
3129
enforced to implement this method.
3230
3331
Args:
34-
edge_ir_module: The original module. It will not be modified in place.
32+
edge_program: The original exported program. It will not be modified in place.
3533
backend_debug_handle_generator: A callable to map a graph to a dictionary (key is node, value is id)
3634
compile_specs: List of values needed for compilation
3735
@@ -45,7 +43,7 @@ class BackendDetails(ABC):
4543
# it's a virtual method and inheritant class needs to implement the actual function
4644
@abstractmethod
4745
def preprocess(
48-
edge_ir_module: torch.fx.GraphModule,
46+
edge_program: ExportedProgram,
4947
compile_specs: List[CompileSpec],
5048
) -> bytes:
5149
# Users should return a compiled blob - a binary that can run the desired

backends/qnnpack/qnnpack_preprocess.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from executorch.backends.transforms import get_shape
1717

1818
from executorch.exir.dialects._ops import ops as exir_ops
19+
from torch._export.exported_program import ExportedProgram
1920

2021
T_Mm = exir_ops.edge.aten.mm.default
2122
T_Addmm = exir_ops.edge.aten.addmm.default
@@ -35,11 +36,11 @@ def _copy_buffer(storage: torch.UntypedStorage) -> bytes:
3536
class QnnpackBackend(BackendDetails):
3637
@staticmethod
3738
def preprocess(
38-
edge_ir_module: torch.fx.GraphModule,
39+
edge_program: ExportedProgram,
3940
compile_specs: List[CompileSpec],
4041
) -> bytes:
4142

42-
for node in edge_ir_module.graph.nodes:
43+
for node in edge_program.graph.nodes:
4344
# TODO(maxren): Follow this up by removing addm and mm nodes
4445
if node.op == "call_function":
4546
# Finding the linear node
@@ -52,7 +53,7 @@ def preprocess(
5253
weight = node.args[2]
5354
# For linear node, bias is known
5455
bias_tensor = getattr(
55-
edge_ir_module, node.args[0].target
56+
edge_program.graph_module, node.args[0].target
5657
).contiguous()
5758
# t_defualt node -> dequant node
5859
weight_dequant = weight.args[0]
@@ -66,7 +67,7 @@ def preprocess(
6667
weight_dequant = node.args[1]
6768
if len(node.args) > 2:
6869
bias_tensor = getattr(
69-
edge_ir_module, node.args[2].target
70+
edge_program.graph_module, node.args[2].target
7071
).contiguous()
7172
else:
7273
raise RuntimeError(
@@ -89,17 +90,22 @@ def preprocess(
8990
# deqaunt node -> quant node
9091
weight_quant = weight_dequant.args[0]
9192
# quant node -> tensor_constant
92-
weight_const = getattr(edge_ir_module, weight_quant.args[0].target)
93+
weight_const = getattr(
94+
edge_program.graph_module, weight_quant.args[0].target
95+
)
9396
if (
9497
weight_quant.target.__name__
9598
== "quantized_decomposed.quantize_per_channel.default"
9699
):
97100
# scale and zero_point are tensors
98101
weight_scale = weight_quant.args[1]
99-
scale_tensor = getattr(edge_ir_module, weight_scale.target)
102+
scale_tensor = getattr(
103+
edge_program.graph_module, weight_scale.target
104+
)
100105
weight_zeropoint = weight_quant.args[2]
101106
zp_tensor = (
102-
getattr(edge_ir_module, weight_zeropoint.target) + 128
107+
getattr(edge_program.graph_module, weight_zeropoint.target)
108+
+ 128
103109
)
104110
axis = weight_quant.args[3]
105111
# requantize weight to uint8

0 commit comments

Comments (0)