
Commit 0b0abf6

mcremon-meta authored and facebook-github-bot committed
Enable quantized add
Summary: Add a pattern for add ops, and call the existing quantized add kernel. Note that the kernel is not capable of broadcasting yet, so we add an explicit check on sizes. Since this change is not guaranteed to improve performance and can reduce accuracy, we don't include it in the default quantizer; instead, we introduce a dedicated quantizer for WakeWord. In the future, we will probably want to maintain multiple custom quantizers for models in the same way we do here. If we want them to live somewhere other than quantizer.py (which would otherwise bloat), we can extract them in a later diff. Note that the WakeWord (WW) stateful tests are already broken, so they show as red in the test box.

Reviewed By: zonglinpeng

Differential Revision: D69441041
1 parent b6ffe1a commit 0b0abf6
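
As a rough sketch of how the new WakeWord quantizer might be driven through the standard PT2E flow (the model class, input shape, and export entry point below are placeholders, not part of this commit):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWakeWordQuantizer,
)

model = MyWakeWordModel().eval()           # placeholder model
example_inputs = (torch.randn(1, 16000),)  # placeholder input shape

# Export, annotate with the WakeWord quantizer (defaults + add),
# calibrate, then convert.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, CadenceWakeWordQuantizer())
prepared(*example_inputs)
converted = convert_pt2e(prepared)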

File tree

5 files changed: +144, -3 lines

backends/cadence/aot/ops_registrations.py
backends/cadence/aot/quantizer/fusion_pass.py
backends/cadence/aot/quantizer/patterns.py
backends/cadence/aot/quantizer/quantizer.py
backends/cadence/aot/replace_ops.py


backends/cadence/aot/ops_registrations.py

Lines changed: 44 additions & 0 deletions

@@ -99,6 +99,10 @@
     "quantized_add(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
 )
+lib.define(
+    "quantized_add.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
+)
 lib.define(
     "quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
@@ -175,6 +179,10 @@
     "quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
@@ -290,6 +298,42 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)


+@register_fake("cadence::quantized_add")
+def quantized_add_meta(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    out_size = X.size()
+    if list(X.size()) == [1]:
+        out_size = Y.size()
+
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
+@register_fake("cadence::quantized_add.per_tensor")
+def quantized_add_per_tensor_meta(
+    X: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_scale: float,
+    Y_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    out_size = X.size()
+    if list(X.size()) == [1]:
+        out_size = Y.size()
+
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
 @register_fake("cadence::quantized_linear")
 def quantized_linear_meta(
     src: torch.Tensor,
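
The two meta kernels above share one shape rule: no general broadcasting, with a single carve-out for a one-element X, whose output takes Y's shape (note the check is on X only, not Y). A standalone restatement of that rule, not part of the diff:

import torch

def expected_out_size(X: torch.Tensor, Y: torch.Tensor) -> torch.Size:
    # One-element X takes Y's shape; otherwise the output matches X.
    return Y.size() if list(X.size()) == [1] else X.size()

assert expected_out_size(torch.zeros(1), torch.zeros(2, 3)) == torch.Size([2, 3])
assert expected_out_size(torch.zeros(2, 3), torch.zeros(2, 3)) == torch.Size([2, 3])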

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 51 additions & 2 deletions

@@ -10,6 +10,7 @@

 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
+    AddPattern,
     AddmmPattern,
     BmmPattern,
     Conv1dPattern,
@@ -41,6 +42,47 @@
 ReluPatterns = (ReluPattern0, ReluPattern1)


+def get_args_and_kwargs_add(
+    graph_module: GraphModule,
+    inputs_inputs: List[fx.Node],
+    dequants_inputs: List[fx.Node],
+    quant_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    X_scale_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[0].args[1]),
+        {"dtype": torch.float},
+    )
+    X_zero_point_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[0].args[2]),
+        {"dtype": torch.int32},
+    )
+    Y_scale_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[1].args[1]),
+        {"dtype": torch.float},
+    )
+    Y_zero_point_ = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[1].args[2]),
+        {"dtype": torch.int32},
+    )
+    args = (
+        inputs_inputs[0],
+        X_scale_,
+        X_zero_point_,
+        inputs_inputs[1],
+        Y_scale_,
+        Y_zero_point_,
+        quant_node.args[1],
+        quant_node.args[2],
+    )
+
+    kwargs = {}
+    return args, kwargs
+
+
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
     graph_module: GraphModule,
@@ -339,7 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         )
         for fused_partition in fused_partitions:
             anchors = pattern.get_anchors(graph_module, fused_partition)
-            if not anchors:
+            if not anchors or anchors.empty:
                 continue
             if any(self.is_fused(p.nodes) for p in fused_partition):
                 continue
@@ -385,7 +427,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 inputs_inputs + weights_inputs + other_inputs + bias_inputs
             )
             kwargs = {}
-            if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
+            if isinstance(pattern, AddPattern):
+                args, kwargs = get_args_and_kwargs_add(
+                    graph_module,
+                    inputs_inputs,
+                    dequants_inputs,
+                    quant_node,
+                )
+            elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                 args, kwargs = get_args_and_kwargs_conv(
                     graph_module,
                     inputs_inputs,
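
get_args_and_kwargs_add recovers the scalar scales and zero points from the dequantize nodes, wraps each in a one-element tensor via aten.full, and passes them to the fused op. For intuition, a hedged eager-mode restatement of the dequantize -> add -> requantize arithmetic the fused op stands in for (an approximation for illustration, not the Cadence kernel itself):

import torch

def quantized_add_ref(X, x_scale, x_zp, Y, y_scale, y_zp, out_scale, out_zp):
    x_fp = (X.to(torch.float) - x_zp) * x_scale   # dequantize X
    y_fp = (Y.to(torch.float) - y_zp) * y_scale   # dequantize Y
    z_fp = x_fp + y_fp                            # add in float
    z_q = torch.round(z_fp / out_scale) + out_zp  # requantize
    info = torch.iinfo(X.dtype)                   # clamp to the output dtype range
    return torch.clamp(z_q, info.min, info.max).to(X.dtype)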

backends/cadence/aot/quantizer/patterns.py

Lines changed: 31 additions & 0 deletions

@@ -43,6 +43,7 @@ class PartitionAnchors:
     output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
         default_factory=list
     )
+    empty: bool = False


 class QuantizationPattern(ABC):
@@ -101,6 +102,36 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear


+class AddPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.add.Tensor]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        add_node = fused_partition[0].nodes[-1]
+
+        # Bail if:
+        # - the add node is not a tensor add
+        # - the add node has kwargs (e.g. alpha)
+        is_tensor_add = isinstance(add_node.args[0], fx.Node) and isinstance(add_node.args[1], fx.Node)
+        if not is_tensor_add or len(add_node.kwargs) > 0:
+            return PartitionAnchors(
+                empty=True,
+            )
+
+        return PartitionAnchors(
+            inputs=[(add_node, 0), (add_node, 1)],
+            weights=[],
+            biases=[],
+            output=[(add_node,)],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_add.default
+
+
 class BmmPattern(QuantizationPattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.bmm.default]
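
For illustration (these functions are not from the commit), here is which adds AddPattern anchors and which ones hit the empty=True bail-out:

import torch

def anchored(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y  # aten.add.Tensor, both args are fx.Nodes -> anchored

def bails_on_kwargs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.add(x, y, alpha=2)  # kwargs present -> PartitionAnchors(empty=True)

def bails_on_scalar(x: torch.Tensor) -> torch.Tensor:
    return x + 1.0  # second arg is a constant, not an fx.Node -> empty anchors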

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 14 additions & 1 deletion

@@ -12,6 +12,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    AddPattern,
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
@@ -109,7 +110,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 continue

             anchors = self.pattern.get_anchors(model, fused_partition)
-            if not anchors:
+            if not anchors or anchors.empty:
                 continue
             if is_annotated(
                 [
@@ -211,3 +212,15 @@ def __init__(
         self,
     ) -> None:
         super().__init__([])
+
+
+class CadenceWakeWordQuantizer(CadenceQuantizer):
+    """
+    Quantizer for WakeWord, including add
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
+        super().__init__(quantizers)
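
The summary notes that more model-specific quantizers may follow this same recipe: start from the default stack and append extra (pattern, qconfig) pairs. A hypothetical sketch of what that could look like (the class name and pattern choice are invented, and it assumes the same module-level imports as quantizer.py):

class CadenceSomeModelQuantizer(CadenceQuantizer):
    """Hypothetical example of another model-specific quantizer."""

    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
        if quantizers is None:
            quantizers = get_cadence_default_quantizers()
        # Invented choice: also quantize bmm for this model.
        quantizers.append(CadenceAtenQuantizer(BmmPattern(), qconfig_A8uW8u))
        super().__init__(quantizers)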

backends/cadence/aot/replace_ops.py

Lines changed: 4 additions & 0 deletions

@@ -1839,6 +1839,10 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
     replaced_scalar_args: dict[
         EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]]
     ] = {
+        exir_ops.edge.cadence.quantized_add: (
+            exir_ops.edge.cadence.quantized_add.per_tensor,
+            [1, 2, 4, 5],
+        ),
         exir_ops.edge.cadence.quantized_conv: (
             exir_ops.edge.cadence.quantized_conv.per_tensor,
             [8, 9, 12, 13],
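
The indices [1, 2, 4, 5] are the positions of X_scale, X_zero_point, Y_scale, and Y_zero_point in the quantized_add signature; those are exactly the one-element tensors that get_args_and_kwargs_add builds with aten.full, which lets the pass rewrite the call to the per_tensor overload with plain scalars. Schematically (values invented for illustration):

# Before: scales/zero points are one-element tensors from aten.full.
#   z = cadence.quantized_add(x, full([1], 0.1), full([1], 0),
#                             y, full([1], 0.15), full([1], 3), 0.2, 0)
# After ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass:
#   z = cadence.quantized_add.per_tensor(x, 0.1, 0, y, 0.15, 3, 0.2, 0)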
