
Commit a83f95e

mcremon-meta authored and facebook-github-bot committed
Add support for quantized bmm (#4047)
Summary: The current quantizer only captures "fake" bmm from matmuls with specific shapes. Add support for `torch.bmm` as well. Use a decomposition for SDPA to make sure the LLaMa bmms get quantized.

Reviewed By: zonglinpengmeta, hsharma35

Differential Revision: D58959269
1 parent 074a81e commit a83f95e
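
For context on why an SDPA decomposition exposes these ops: scaled dot-product attention is two batched matmuls around a softmax. The sketch below is illustrative only (the decomposition used by the actual export flow may differ in detail); it simply shows where the torch.bmm calls targeted by this change come from.

import torch

# Illustrative only: a hand-written SDPA decomposition showing the two torch.bmm
# calls that become visible to the quantizer once SDPA is decomposed.
def sdpa_decomposed(q, k, v):
    scale = q.shape[-1] ** -0.5
    scores = torch.bmm(q, k.transpose(-2, -1)) * scale  # first batched matmul
    probs = torch.softmax(scores, dim=-1)
    return torch.bmm(probs, v)  # second batched matmul

q = k = v = torch.randn(2, 8, 16)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(torch.allclose(sdpa_decomposed(q, k, v), ref, atol=1e-5))  # True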

File tree

  backends/cadence/aot/TARGETS
  backends/cadence/aot/compiler.py
  backends/cadence/aot/quantizer/fusion_pass.py
  backends/cadence/aot/quantizer/patterns.py
  backends/cadence/aot/quantizer/quantizer.py

5 files changed: +35 −7 lines

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ python_library(
         "compiler.py",
     ],
     deps = [
+        "fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
         ":passes",
         ":utils",
         "//caffe2:torch",

backends/cadence/aot/compiler.py

Lines changed: 9 additions & 3 deletions
@@ -18,9 +18,13 @@
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceGenericQuantizer,
+    CadenceQuantizer,
+)
 from executorch.backends.cadence.aot.utils import model_is_quantized
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
+from pyre_extensions import assert_is_instance
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -53,8 +57,10 @@ def quantize_pt2(
     converted_model = convert_pt2e(prepared_model)

     # Get patterns and apply fusion of dq -> op -> q to qop
-    # pyre-fixme[16]: Pyre doesn't get that CadenceQuantizer has a patterns attribute
-    patterns = [q.pattern for q in quantizer.quantizers]
+    patterns = [
+        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        for q in quantizer.quantizers
+    ]
     QuantFusion(patterns)(converted_model)

     return converted_model
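
The pyre-fixme suppression is replaced by pyre_extensions.assert_is_instance, which checks the type at runtime and narrows it for the type checker. A minimal sketch of that behavior, assuming the standard semantics (this is not the pyre_extensions implementation):

from typing import Type, TypeVar

T = TypeVar("T")

# Hypothetical stand-in: raise if obj is not an instance of cls,
# otherwise return it with the narrowed static type T.
def assert_is_instance_sketch(obj: object, cls: Type[T]) -> T:
    if not isinstance(obj, cls):
        raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
    return obj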

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 3 additions & 4 deletions
@@ -11,6 +11,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,

@@ -361,9 +362,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 inputs_inputs + weights_inputs + other_inputs + bias_inputs
             )
             kwargs = {}
-            if isinstance(pattern, Conv1dPattern) or isinstance(
-                pattern, Conv2dPattern
-            ):
+            if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                 args, kwargs = get_args_and_kwargs_conv(
                     graph_module,
                     inputs_inputs,

@@ -396,7 +395,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     other_inputs,
                     quant_node,
                 )
-            elif isinstance(pattern, MatmulPattern):
+            elif isinstance(pattern, (BmmPattern, MatmulPattern)):
                 args, kwargs = get_args_and_kwargs_matmul(
                     inputs_inputs,
                     dequants_inputs,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 20 additions & 0 deletions
@@ -95,6 +95,26 @@ def replacement_op(self):
         return torch.ops.cadence.quantized_linear


+class BmmPattern(QuantizationPattern):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
+        return [torch.bmm]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        bmm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(bmm_node, 0), (bmm_node, 1)],
+            weights=[],
+            biases=[],
+            output=[(bmm_node,)],
+        )
+
+    def replacement_op(self):
+        return torch.ops.cadence.quantized_matmul.default
+
+
 class Conv1dPattern(QuantizationPattern):
     def partition_types(self) -> List[Type[torch.nn.Module]]:
         return [torch.nn.Conv1d]
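
Unlike the conv and linear patterns, BmmPattern declares no weight or bias anchors: both operands of torch.bmm are runtime activations, so both are listed as inputs and the node itself is the output anchor. A toy module matching this shape (names are illustrative only):

import torch

class ToyBmm(torch.nn.Module):
    # Both bmm operands are activations, which is why BmmPattern anchors
    # (bmm_node, 0) and (bmm_node, 1) as inputs and declares no weights/biases.
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.bmm(a, b)

out = ToyBmm()(torch.randn(2, 3, 4), torch.randn(2, 4, 5))
print(out.shape)  # torch.Size([2, 3, 5])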

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,

@@ -133,6 +134,7 @@ def __init__(self):
         super().__init__(
             [
                 CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
+                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
                 CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
