
Commit aa50879

mcremon-meta authored and facebook-github-bot committed
Add support for quantized bmm (#4047)
Summary: The current quantizer only captures "fake" bmm from matmuls with specific shapes. Add support for `torch.bmm` as well.

Reviewed By: dulinriley, zonglinpengmeta, hsharma35

Differential Revision: D58959269
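
To make the summary concrete, here is a hypothetical illustration (not code from this commit): `torch.matmul` on 3-D inputs is the "fake" bmm the quantizer already matched through `MatmulPattern`, while an explicit `torch.bmm` call was previously left unquantized and is now captured by the new `BmmPattern`.

# Hypothetical example (not part of this commit): both modules compute the
# same batched matmul, but only the first was matched by the existing
# MatmulPattern; the second needs the new BmmPattern.
import torch

class FakeBmm(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # 3-D matmul: the "fake" bmm shape the quantizer already captured
        return torch.matmul(x, y)

class RealBmm(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # explicit batched matmul: newly captured as a quantizable pattern
        return torch.bmm(x, y)

x, y = torch.randn(4, 8, 16), torch.randn(4, 16, 32)
assert torch.allclose(FakeBmm()(x, y), RealBmm()(x, y))
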
1 parent e9aa542 commit aa50879

File tree: 6 files changed, +102 −39 lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ python_library(
         "compiler.py",
     ],
     deps = [
+        "fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
         ":passes",
         ":utils",
         "//caffe2:torch",

backends/cadence/aot/compiler.py

Lines changed: 9 additions & 3 deletions

@@ -18,9 +18,13 @@
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceGenericQuantizer,
+    CadenceQuantizer,
+)
 from executorch.backends.cadence.aot.utils import model_is_quantized
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
+from pyre_extensions import assert_is_instance
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -53,8 +57,10 @@ def quantize_pt2(
     converted_model = convert_pt2e(prepared_model)
 
     # Get patterns and apply fusion of dq -> op -> q to qop
-    # pyre-fixme[16]: Pyre doesn't get that CadenceQuantizer has a patterns attribute
-    patterns = [q.pattern for q in quantizer.quantizers]
+    patterns = [
+        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        for q in quantizer.quantizers
+    ]
     QuantFusion(patterns)(converted_model)
 
     return converted_model
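
For clarity, the new list comprehension above does roughly the following (a sketch, assuming `quantizer` is a freshly constructed `CadenceQuantizer`): `assert_is_instance` raises at runtime if a composed quantizer is not a `CadenceGenericQuantizer` and narrows its static type, which is what lets the old pyre-fixme comment go away.

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceGenericQuantizer,
    CadenceQuantizer,
)
from pyre_extensions import assert_is_instance

quantizer = CadenceQuantizer()
patterns = []
for q in quantizer.quantizers:
    # Fails loudly if q is not a CadenceGenericQuantizer; otherwise returns q
    # with its type narrowed so that .pattern is known under pyre-strict.
    generic_q = assert_is_instance(q, CadenceGenericQuantizer)
    patterns.append(generic_q.pattern)
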

backends/cadence/aot/quantizer/TARGETS

Lines changed: 3 additions & 0 deletions

@@ -17,6 +17,7 @@ python_library(
     srcs = [
         "patterns.py",
     ],
+    typing = True,
     deps = [
         ":utils",
         "//caffe2:torch",
@@ -28,7 +29,9 @@ python_library(
     srcs = [
         "quantizer.py",
     ],
+    typing = True,
     deps = [
+        "fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
        ":patterns",
        ":utils",
        "//caffe2:torch",

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 3 additions & 4 deletions

@@ -11,6 +11,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,
@@ -361,9 +362,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 inputs_inputs + weights_inputs + other_inputs + bias_inputs
             )
             kwargs = {}
-            if isinstance(pattern, Conv1dPattern) or isinstance(
-                pattern, Conv2dPattern
-            ):
+            if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                 args, kwargs = get_args_and_kwargs_conv(
                     graph_module,
                     inputs_inputs,
@@ -396,7 +395,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     other_inputs,
                     quant_node,
                 )
-            elif isinstance(pattern, MatmulPattern):
+            elif isinstance(pattern, (BmmPattern, MatmulPattern)):
                 args, kwargs = get_args_and_kwargs_matmul(
                     inputs_inputs,
                     dequants_inputs,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 50 additions & 19 deletions

@@ -4,14 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import Callable, List, Optional, Tuple, Type, Union
 
 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
 
 from torch import fx
+from torch._ops import OpOverload
 from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
     SharedQuantizationSpec,
@@ -44,18 +47,22 @@ class PartitionAnchors:
 
 class QuantizationPattern(ABC):
     @abstractmethod
-    def partition_types(self):
+    def partition_types(
+        self,
+    ) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
         """
         List of types to be passed to find_sequential_partitions.
         """
         pass
 
     @abstractmethod
-    def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Optional[PartitionAnchors]:
         pass
 
     @abstractmethod
-    def replacement_op(self) -> Callable[..., Any]:
+    def replacement_op(self) -> OpOverload:
         """
         Operator (most likely a custom one) that this partition should be fused into in
         the backend. Refer to the QuantFusion pass for examples.
@@ -91,10 +98,30 @@ def get_anchors(
             output=[(addmm_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear
 
 
+class BmmPattern(QuantizationPattern):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
+        return [torch.bmm]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        bmm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(bmm_node, 0), (bmm_node, 1)],
+            weights=[],
+            biases=[],
+            output=[(bmm_node,)],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_matmul.default
+
+
 class Conv1dPattern(QuantizationPattern):
     def partition_types(self) -> List[Type[torch.nn.Module]]:
         return [torch.nn.Conv1d]
@@ -129,7 +156,7 @@ def get_anchors(
             output=[(conv1d_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default
 
 
@@ -167,15 +194,17 @@ def get_anchors(
             output=[(conv2d_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default
 
 
 class LayerNormPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Type[torch.nn.Module]]:
         return [torch.nn.LayerNorm]
 
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
         layer_norm_node = fused_partition[0].nodes[-1]
 
         # Weights and biases are used as fp32 by our kernel, so they are
@@ -189,15 +218,17 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
             output=[(layer_norm_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default
 
 
 class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
         return [torch.nn.functional.layer_norm]
 
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
         layer_norm_node = fused_partition[0].nodes[-1]
 
         others = [(layer_norm_node, 1)]
@@ -221,7 +252,7 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
             output=[(layer_norm_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default
 
 
@@ -259,12 +290,12 @@ def get_anchors(
             output=[(linear_node,)],
        )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
 
 
 class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
         return [torch.nn.functional.linear]
 
     def get_anchors(
@@ -297,12 +328,12 @@ def get_anchors(
             output=[(linear_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
 
 
 class MatmulPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
         return [torch.matmul]
 
     def get_anchors(
@@ -317,7 +348,7 @@ def get_anchors(
             output=[(matmul_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
 
 
@@ -339,5 +370,5 @@ def get_anchors(
             ],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
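
Pulled out of the diff above for readability, the complete new pattern (imports and base types as defined earlier in patterns.py; the inline comment is added here as explanation, not taken from the original file). Both bmm operands are plain activations, so only the two inputs and the output are annotated, and the partition is fused into the cadence quantized_matmul op.

class BmmPattern(QuantizationPattern):
    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
        return [torch.bmm]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        bmm_node = fused_partition[0].nodes[-1]

        # bmm has no weights or biases: quantize both operands and the output.
        return PartitionAnchors(
            inputs=[(bmm_node, 0), (bmm_node, 1)],
            weights=[],
            biases=[],
            output=[(bmm_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_matmul.default
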

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 36 additions & 13 deletions

@@ -4,30 +4,35 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List
+# pyre-strict
+
+from typing import List, Optional, Tuple, Union
 
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,
     LayerNormPattern,
     LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
+    QuantizationPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     is_annotated,
     no_outside_users,
 )
+from pyre_extensions import assert_is_instance
 
 from torch import fx
 
 from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization.quantizer import Quantizer
+from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
     OperatorConfig,
@@ -55,16 +60,18 @@
     observer_or_fake_quant_ctr=MinMaxObserver,
 )
 
-bias_qspec = None
+bias_qspec: Optional[QuantizationSpec] = None
 
 
 class CadenceGenericQuantizer(Quantizer):
-    def __init__(self, pattern, quantization_config):
+    def __init__(
+        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
+    ) -> None:
         super().__init__()
         self.pattern = pattern
         self.quantization_config = quantization_config
 
-    def annotate(self, model):
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         fused_partitions = find_sequential_partitions(
             model,
             self.pattern.partition_types(),
@@ -94,25 +101,40 @@ def annotate(self, model):
                 continue
 
             for output, *custom_spec in anchors.output:
-                output.meta["quantization_annotation"] = QuantizationAnnotation(
-                    output_qspec=custom_spec[0] if custom_spec else output_act_qspec,
-                    _annotated=True,
+                assert_is_instance(output, fx.Node).meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        # pyre-ignore[6]: incompatible parameter type
+                        output_qspec=(
+                            custom_spec[0] if custom_spec else output_act_qspec
+                        ),
+                        _annotated=True,
+                    )
                 )
 
-            def annotate_inputs(inputs, spec):
+            def annotate_inputs(
+                inputs: Union[
+                    List[Tuple[fx.Node, int]],
+                    List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
+                ],
+                spec: Optional[QuantizationSpec],
+            ) -> None:
                 for node, idx, *custom_spec in inputs:
-                    annotation = node.meta.get(
+                    _node = assert_is_instance(node, fx.Node)
+                    annotation = _node.meta.get(
                         "quantization_annotation",
                         QuantizationAnnotation(_annotated=True),
                     )
-                    annotation.input_qspec_map[node.args[idx]] = (
+                    # pyre-ignore[6]: incompatible parameter type
+                    annotation.input_qspec_map[_node.args[idx]] = (
                         custom_spec[0] if custom_spec else spec
                     )
-                    node.meta["quantization_annotation"] = annotation
+                    _node.meta["quantization_annotation"] = annotation
 
             annotate_inputs(anchors.inputs, input_act_qspec)
             annotate_inputs(anchors.weights, weight_qspec)
+            # pyre-ignore[6]: incompatible parameter type
             annotate_inputs(anchors.biases, bias_qspec)
+        return model
 
     def validate(self, model: fx.GraphModule) -> None:
         pass
@@ -123,7 +145,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
 
 
 class CadenceQuantizer(ComposableQuantizer):
-    def __init__(self):
+    def __init__(self) -> None:
         static_qconfig = QuantizationConfig(
             act_qspec,
             act_qspec,
@@ -133,6 +155,7 @@ def __init__(self):
         super().__init__(
             [
                 CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
+                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
                 CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
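
Putting the pieces together, a minimal end-to-end sketch of how a model that calls `torch.bmm` would now flow through the Cadence PT2E quantization path. This mirrors what `quantize_pt2` in compiler.py appears to do; the calibration call and the exact invocation details are assumptions, not code from this commit.

import torch
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceGenericQuantizer,
    CadenceQuantizer,
)
from pyre_extensions import assert_is_instance
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class BmmModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.bmm(x, y)

inputs = (torch.randn(2, 4, 8), torch.randn(2, 8, 4))
quantizer = CadenceQuantizer()

# Capture, annotate (BmmPattern now matches torch.bmm), calibrate, convert.
captured = capture_pre_autograd_graph(BmmModel(), inputs)
prepared = prepare_pt2e(captured, quantizer)
prepared(*inputs)  # one calibration pass (assumed; any representative data works)
converted = convert_pt2e(prepared)

# Fuse dq -> bmm -> q into the cadence quantized_matmul op, as quantize_pt2 does.
patterns = [
    assert_is_instance(q, CadenceGenericQuantizer).pattern
    for q in quantizer.quantizers
]
QuantFusion(patterns)(converted)
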
