Skip to content

Commit 60ea5c6

Browse files
mcr229 authored and facebook-github-bot committed
always partition static attr and addmm op is supported (#354)
Summary: Pull Request resolved: #354. This is to enable lowering the ViT model. ViT's MultiheadAttention is decomposed into many linears, and addmm is not delegatable if it is not derived from torch.nn.Linear. There are some addmms in ViT which are derived from MultiheadAttention. As a result, to improve performance we need to partition addmms via the operator list rather than by module. These changes are merged from D49129703, as both changes are required to keep OD tests working. For supported operators that use static data, the data should always be partitioned along with that operator. This is required for adding addmm to the supported-operator set because it allows us to partition in the weight and bias data. Reviewed By: digantdesai Differential Revision: D49129705 fbshipit-source-id: b7bec3e867d65328e4022d60d8c8f204998bc887
1 parent 2b7eb62 commit 60ea5c6

File tree

4 files changed

+14
-1
lines changed

4 files changed

+14
-1
lines changed

backends/xnnpack/partition/configs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
exir_ops.edge.aten.elu.default,
6363
exir_ops.edge.aten.avg_pool2d.default,
6464
exir_ops.edge.aten.leaky_relu.default,
65+
exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm
6566
]
6667

6768
SUPPORTED_MODULES = [
@@ -95,7 +96,9 @@
9596
exir_ops.edge.aten.max_pool2d.default,
9697
exir_ops.edge.aten.constant_pad_nd.default,
9798
exir_ops.edge.aten.elu.default,
99+
exir_ops.edge.aten.t_copy.default,
98100
exir_ops.edge.aten.leaky_relu.default,
101+
exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm
99102
]
100103

101104
SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = {

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ def check_constraint(node, ep) -> bool:
109109
return _OP_SUPPORT_CONSTRAINTS.get(node.target, lambda node, ep: True)(node, ep)
110110

111111
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
112+
# Parameters are supported if any of their users are supported
113+
if is_param_node(self.ep, node):
114+
return any(
115+
self.is_node_supported(submodules, user) for user in node.users.keys()
116+
)
112117
# TODO - other ops?
113118
if node.op != "call_function":
114119
return False

backends/xnnpack/passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ python_library(
1515
],
1616
deps = [
1717
"//caffe2:torch",
18+
"//executorch/backends/transforms:addmm_mm_to_linear",
1819
"//executorch/backends/transforms:lib",
1920
"//executorch/backends/xnnpack/partition:configs",
2021
"//executorch/backends/xnnpack/utils:xnnpack_utils",

backends/xnnpack/passes/convert_to_linear.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
import torch
1111

1212
from executorch.backends.transforms import get_shape
13+
from executorch.backends.transforms.addmm_mm_to_linear import (
14+
apply_addmm_mm_to_linear_transform,
15+
)
1316
from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass
1417
from executorch.exir.dialects._ops import ops as exir_ops
1518

@@ -180,13 +183,14 @@ def call(self, graph_module: torch.fx.GraphModule):
180183
logger.debug(
181184
"Did not find any [add]mm target in source partitions, skipping the pass."
182185
)
183-
return PassResult(graph_module, False)
184186

185187
logger.debug("Converting [add]mm into Linear")
186188

187189
for node in src_node_dict.keys():
188190
self.create_linear(graph_module, node, src_node_dict[node])
189191

192+
graph_module.graph = apply_addmm_mm_to_linear_transform(graph_module.graph)
193+
190194
graph_module.recompile()
191195

192196
# Propagate metadata and retrace module

0 commit comments

Comments (0)