
Commit 92cab3c

mcremon-meta authored and facebook-github-bot committed
Add support for quantized bmm (#4047)
Summary: Pull Request resolved: #4047

The current quantizer only captures "fake" bmm, i.e. torch.matmul calls whose input shapes make them batched matrix multiplications. Add support for `torch.bmm` as well.

Differential Revision: D58959269
1 parent 38046ba commit 92cab3c
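
As context for the change (not part of the commit): the new pattern targets graphs where the batched matmul is written directly as torch.bmm rather than as a torch.matmul with bmm-like shapes. A hypothetical minimal module of that kind:

# Hypothetical example (not from the commit): torch.bmm takes two 3-D tensors
# of shapes (B, M, K) and (B, K, N) and returns (B, M, N). Previously only
# matmuls with such shapes were matched by the quantizer.
import torch

class BmmModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.bmm(x, y)

out = BmmModule()(torch.randn(2, 4, 8), torch.randn(2, 8, 16))
print(out.shape)  # torch.Size([2, 4, 16])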

File tree

3 files changed: +26 −1 lines


backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 4 additions & 1 deletion
@@ -11,6 +11,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,
@@ -396,7 +397,9 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     other_inputs,
                     quant_node,
                 )
-            elif isinstance(pattern, MatmulPattern):
+            elif isinstance(pattern, BmmPattern) or isinstance(
+                pattern, MatmulPattern
+            ):
                 args, kwargs = get_args_and_kwargs_matmul(
                     inputs_inputs,
                     dequants_inputs,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 20 additions & 0 deletions
@@ -95,6 +95,26 @@ def replacement_op(self):
         return torch.ops.cadence.quantized_linear


+class BmmPattern(QuantizationPattern):
+    def partition_types(self):
+        return [torch.bmm]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        bmm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(bmm_node, 0), (bmm_node, 1)],
+            weights=[],
+            biases=[],
+            output=[(bmm_node,)],
+        )
+
+    def replacement_op(self):
+        return torch.ops.cadence.quantized_matmul.default
+
+
 class Conv1dPattern(QuantizationPattern):
     def partition_types(self) -> List[Type[torch.nn.Module]]:
         return [torch.nn.Conv1d]
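
Not from the commit: a minimal FX sketch of what BmmPattern anchors. The traced graph contains a single call_function node targeting torch.bmm; its two positional args (indices 0 and 1) become the quantized-input anchors, there are no weight or bias anchors, and the node itself is the output anchor.

# Illustrative sketch (assumption, not part of the commit).
import torch
import torch.fx

def batched_matmul(x, y):
    return torch.bmm(x, y)

gm = torch.fx.symbolic_trace(batched_matmul)
bmm_node = next(n for n in gm.graph.nodes if n.target is torch.bmm)
print(bmm_node.op, bmm_node.args)  # call_function (x, y): the two anchored inputs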

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,
@@ -133,6 +134,7 @@ def __init__(self):
         super().__init__(
             [
                 CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
+                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
                 CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
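
Not part of the commit: a sketch of how the updated quantizer might be exercised end-to-end. It assumes the composed quantizer defined in this file is exported as CadenceQuantizer and uses the generic PT2E prepare/convert flow; the export entry point has changed names across PyTorch releases, so treat the capture call as illustrative.

# Illustrative only; class and entry-point names are assumptions, not
# guaranteed by this commit.
import torch
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class BmmModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.bmm(x, y)

inputs = (torch.randn(2, 4, 8), torch.randn(2, 8, 16))

# Capture a pre-autograd graph (API location/name varies by PyTorch version).
exported = capture_pre_autograd_graph(BmmModule(), inputs)

quantizer = CadenceQuantizer()
prepared = prepare_pt2e(exported, quantizer)  # annotates the torch.bmm node via BmmPattern
prepared(*inputs)                             # calibration pass to collect observer stats
converted = convert_pt2e(prepared)            # inserts quant/dequant pairs around bmm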
