Commit 78688b7

mcremon-meta authored and facebook-github-bot committed
VAD perf improvements - Improve aten::full performance (#3974)
Summary:
Pull Request resolved: #3974

Reviewed By: dulinriley, zonglinpengmeta

Differential Revision: D56524303

fbshipit-source-id: 5bdeb29d473446d0350d397e48eb70eeeaed16d3
1 parent 5e22836 commit 78688b7
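
What changed: both files now pin an explicit dtype on the aten.full nodes they create, instead of letting full infer one from its fill value. A minimal sketch of the dtype-inference behavior this sidesteps (not part of the commit; the zero-point value is hypothetical):

    import torch

    zero_point = 7  # hypothetical quantization zero-point
    inferred = torch.ops.aten.full.default([1], zero_point)
    pinned = torch.ops.aten.full.default([1], zero_point, dtype=torch.int32)

    print(inferred.dtype)  # torch.int64 -- inferred from the Python int
    print(pinned.dtype)    # torch.int32 -- matches the consuming kernel

When the dtype is inferred, a downstream kernel expecting int32 (or float32) operands has to convert at runtime; pinning the dtype at graph-construction time removes that conversion, which appears to be the source of the aten::full speedup.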

2 files changed: +68 −16

backends/cadence/aot/passes.py

Lines changed: 48 additions & 9 deletions
@@ -4,20 +4,34 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, Dict, Tuple
+
 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, ProxyValue
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
 from torch._subclasses import FakeTensor
 from torch.utils._pytree import tree_map_only
 
 
+# pyre-strict
+
+# Similar to what's done in executorch/exir/pass_base.py
+Argument = Any  # pyre-ignore
+
+
 class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
     """
     Replace the pt2 quantization ops with custom cadence quantization ops.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -34,7 +48,13 @@ class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
     Replace the pt2 dequantization ops with custom cadence dequantization ops.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -51,7 +71,13 @@ class ReplaceScalarTensorWithFullPass(ExportPass):
     aten.scalar_tensor can be replaced by aten.full with a shape of [1].
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {
             exir_ops.edge.aten.scalar_tensor.default,
             torch.ops.aten.scalar_tensor.default,
@@ -64,7 +90,7 @@ def call_operator(self, op, args, kwargs, meta):
                 [1],
                 args[0],
             ),
-            {},
+            {"dtype": torch.float32},
             meta,
         )
 
@@ -75,7 +101,13 @@ class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
     view_copy op
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
         # which allows us to cover all overloads.
         if get_edge_overload_packet(op) not in {
@@ -99,7 +131,13 @@ def call_operator(self, op, args, kwargs, meta):
 
 
 class RemoveZeroSizedCatArgsPass(ExportPass):
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op != exir_ops.edge.aten.cat.default:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -122,6 +160,7 @@ def call_operator(self, op, args, kwargs, meta):
         # TODO(matthiascremon): confirm this is the best way to do this.
         if isinstance(result, FakeTensor):
             result.constant = result
+            # pyre-ignore[7]: Incompatible return type.
             return torch.empty_like(result)
 
         # If there was only one tensor in the new_args list,
@@ -130,7 +169,7 @@ def call_operator(self, op, args, kwargs, meta):
             return new_args[0]
 
         # Otherwise, we replace args[0] with new_args.
-        args = list(args)
-        args[0] = new_args
+        init_args = list(args)
+        init_args[0] = new_args
         args = tuple(args)
         return super().call_operator(op, args, kwargs, meta)
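
For reference, these ExportPass subclasses run like any other pass over an fx graph. A hedged usage sketch, assuming graph_module is the fx.GraphModule of an exported edge program (names outside the diff are assumptions):

    from executorch.backends.cadence.aot.passes import ReplaceScalarTensorWithFullPass

    # graph_module: torch.fx.GraphModule of an exported edge program (assumed)
    result = ReplaceScalarTensorWithFullPass()(graph_module)
    if result is not None and result.modified:
        graph_module = result.graph_module  # scalar_tensor nodes now lowered to aten.full

Note that the pass now emits aten.full with {"dtype": torch.float32}, matching aten.scalar_tensor's default float dtype rather than letting full infer an integer dtype from an integer fill value.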

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 20 additions & 7 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from typing import Any, Dict, List, Tuple
 
 import torch
@@ -31,6 +33,11 @@
 from torch.fx.passes.utils.fuser_utils import legalize_graph
 
 
+# Use this to avoid pyre errors
+# pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
+ArgsType = Any
+
+
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
     graph_module: GraphModule,
@@ -40,7 +47,7 @@ def get_args_and_kwargs_linear(
     dequants_weights: List[fx.Node],
     bias_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     """
     Returns the args and kwargs for the linear replacement op.
     """
@@ -98,7 +105,7 @@ def get_args_and_kwargs_layer_norm(
     dequants_inputs: List[fx.Node],
     other_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     """
     Returns the args and kwargs for the layer norm replacement op.
     """
@@ -167,7 +174,7 @@ def get_args_and_kwargs_matmul(
     inputs_inputs: List[fx.Node],
     dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
     requantize_scale = (
         # pyre-ignore[58]: Unsupported operand
         dequants_inputs[0].args[1]
@@ -203,7 +210,7 @@ def get_args_and_kwargs_conv(
     bias_inputs: List[fx.Node],
     quant_node: fx.Node,
     op_node: fx.Node,
-):
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     weight_scale = dequants_weights[0].args[1]
     weight_zero_point = dequants_weights[0].args[2]
     # pyre-fixme[58]: Unsupported operand types
@@ -277,12 +284,14 @@ def get_args_and_kwargs_relu(
     graph_module: GraphModule,
     inputs_inputs: List[fx.Node],
     dequants_inputs: List[fx.Node],
-):
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)
 
     X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default, ([1], dequants_inputs[0].args[2])
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[0].args[2]),
+        {"dtype": torch.int32},
     )
 
     kwargs = {
@@ -292,8 +301,10 @@ def get_args_and_kwargs_relu(
 
 
 class QuantFusion(ExportPass):
-    def __init__(self, patterns):
+    # pyre-ignore[2]: Parameter `patterns` has no type specified
+    def __init__(self, patterns) -> None:
         super().__init__()
+        # pyre-ignore[4]: Parameter `patterns` of class `QuantFusion` has no type specified
         self.patterns = patterns
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
@@ -427,10 +438,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         graph_module.recompile()
 
     @classmethod
+    # pyre-ignore[2]: Parameter `nodes` has no type specified
     def is_fused(cls, nodes) -> bool:
         return any(cls.__qualname__ in n.meta for n in nodes)
 
     @classmethod
+    # pyre-ignore[2]: Parameter `nodes` has no type specified
     def mark_fused(cls, nodes) -> bool:
         for n in nodes:
             # pyre-fixme[7]: Incompatible return type
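
For illustration, the new X_zero_point construction uses the standard fx graph-building API, passing kwargs as a third positional argument to Graph.call_function. A standalone sketch under assumed inputs (the module and zero-point are hypothetical, not the backend's real fusion flow):

    import torch
    from torch import fx

    class Identity(torch.nn.Module):
        def forward(self, x):
            return x

    gm = fx.symbolic_trace(Identity())

    # Append a constant node the way get_args_and_kwargs_relu now does;
    # Graph.call_function(target, args, kwargs) returns the new fx.Node.
    zp_node = gm.graph.call_function(
        torch.ops.aten.full.default,
        ([1], 7),                # shape [1], hypothetical zero-point 7
        {"dtype": torch.int32},  # dtype pinned at graph-construction time
    )

In real use the node would be inserted next to its consumer (e.g. under gm.graph.inserting_before(user)) and wired into that op's kwargs, as the fusion pass does here.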
