Skip to content

Commit 3ffd24e

Browse files
authored
Enable quantized cat
Differential Revision: D69499329 Pull Request resolved: #8757
1 parent 17d4f04 commit 3ffd24e

File tree

3 files changed

+106
-7
lines changed

3 files changed

+106
-7
lines changed

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
AddmmPattern,
1414
AddPattern,
1515
BmmPattern,
16+
CatPattern,
1617
Conv1dPattern,
1718
Conv2dPattern,
1819
LayerNormPattern,
@@ -246,6 +247,16 @@ def get_args_and_kwargs_matmul(
246247
return args, kwargs
247248

248249

250+
def get_args_and_kwargs_cat(
    inputs_inputs: List[fx.Node], other_inputs: List[fx.Node], op_node: fx.Node
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """Assemble (args, kwargs) for a fused quantized cat op.

    The input nodes are packed as one list in the first positional slot,
    followed by any remaining non-input args; the concatenation dimension
    is forwarded as the ``dim`` keyword, defaulting to 0 when the original
    node did not supply one.
    """
    packed: List[ArgsType] = [inputs_inputs]
    packed.extend(other_inputs)
    cat_dim = 0 if len(op_node.args) <= 1 else op_node.args[1]
    # pyre-fixme[6]: Incompatible parameter type
    return tuple(packed), {"dim": int(cat_dim)}
258+
259+
249260
def get_args_and_kwargs_conv(
250261
graph_module: GraphModule,
251262
inputs_inputs: List[fx.Node],
@@ -390,12 +401,17 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
390401
self.mark_fused(p.nodes)
391402

392403
dequants_inputs = []
393-
for node, idx in anchors.inputs:
404+
for node, idx, *_spec in anchors.inputs:
405+
arg = (
406+
node.args[idx]
407+
if isinstance(idx, int)
408+
else node.args[idx[0]][idx[1]]
409+
)
394410
if (
395-
node.args[idx].target
411+
arg.target
396412
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
397413
):
398-
dequants_inputs.append(node.args[idx])
414+
dequants_inputs.append(arg)
399415
dequants_weights = []
400416
for node, idx in anchors.weights:
401417
if (
@@ -434,6 +450,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
434450
dequants_inputs,
435451
quant_node,
436452
)
453+
elif isinstance(pattern, CatPattern):
454+
args, kwargs = get_args_and_kwargs_cat(
455+
inputs_inputs, other_inputs, op_node
456+
)
437457
elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
438458
args, kwargs = get_args_and_kwargs_conv(
439459
graph_module,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,17 @@ class PartitionAnchors:
3333
is used for other types of input values as well as handling default parameters.
3434
"""
3535

36-
inputs: List[Tuple[fx.Node, int]] = field(default_factory=list)
36+
# Inputs can share quantization parameters
37+
inputs: List[
38+
Union[
39+
Tuple[fx.Node, Union[int, Tuple[int, int]]],
40+
Tuple[
41+
fx.Node,
42+
Union[int, Tuple[int, int]],
43+
SharedQuantizationSpec,
44+
],
45+
]
46+
] = field(default_factory=list)
3747
weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
3848
biases: List[
3949
Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
@@ -155,6 +165,52 @@ def replacement_op(self) -> OpOverload:
155165
return torch.ops.cadence.quantized_matmul.default
156166

157167

168+
class CatPattern(QuantizationPattern):
    """Quantization pattern for ``aten.cat``.

    All concatenated inputs and the output are tied to a single set of
    quantization parameters: the first input inherits the overall spec,
    and every other input — as well as the output — shares that first
    input's spec through ``SharedQuantizationSpec``.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.cat.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        cat_node = fused_partition[0].nodes[-1]

        # Anchor the first list element without a spec so it inherits the
        # overall quant spec; later elements each get a SharedQuantizationSpec
        # pointing at the (first input, cat) edge so they reuse its params.
        anchors: List[
            Union[
                Tuple[fx.Node, Union[int, Tuple[int, int]]],
                Tuple[
                    fx.Node,
                    Union[int, Tuple[int, int]],
                    SharedQuantizationSpec,
                ],
            ]
        ] = [(cat_node, (0, 0))]
        num_cat_inputs = len(cat_node.args[0])
        for pos in range(1, num_cat_inputs):
            anchors.append(
                (
                    cat_node,
                    (0, pos),
                    SharedQuantizationSpec((cat_node.args[0][0], cat_node)),
                )
            )

        # The output shares the same spec as the inputs.
        return PartitionAnchors(
            inputs=anchors,
            weights=[],
            biases=[],
            output=[
                (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
            ],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.aten.cat.default
212+
213+
158214
class Conv1dPattern(QuantizationPattern):
159215
def partition_types(self) -> List[OpOverload]:
160216
return [torch.ops.aten.conv1d.default]

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
AddmmPattern,
1515
AddPattern,
1616
BmmPattern,
17+
CatPattern,
1718
Conv1dPattern,
1819
Conv2dPattern,
1920
LayerNormPattern,
@@ -144,17 +145,38 @@ def annotate_inputs(
144145
"quantization_annotation",
145146
QuantizationAnnotation(_annotated=True),
146147
)
148+
arg = (
149+
# pyre-ignore[16]: no attribute
150+
node.args[idx]
151+
if isinstance(idx, int)
152+
# pyre-ignore[16]: no attribute
153+
else node.args[idx[0]][idx[1]]
154+
)
155+
annotation.input_qspec_map[arg] = (
156+
custom_spec[0] if custom_spec else spec
157+
)
147158
# pyre-ignore[16]: no attribute
159+
node.meta["quantization_annotation"] = annotation
160+
161+
def annotate_weights_or_biases(
162+
weights_or_biases: List[Tuple[fx.Node, int]],
163+
spec: Optional[QuantizationSpec],
164+
) -> None:
165+
for node, idx, *custom_spec in weights_or_biases:
166+
annotation = node.meta.get(
167+
"quantization_annotation",
168+
QuantizationAnnotation(_annotated=True),
169+
)
148170
annotation.input_qspec_map[node.args[idx]] = (
149171
custom_spec[0] if custom_spec else spec
150172
)
151-
# pyre-ignore[16]: no attribute
152173
node.meta["quantization_annotation"] = annotation
153174

175+
# pyre-ignore[6]: incompatible parameter type
154176
annotate_inputs(anchors.inputs, input_act_qspec)
155-
annotate_inputs(anchors.weights, weight_qspec)
177+
annotate_weights_or_biases(anchors.weights, weight_qspec)
156178
# pyre-ignore[6]: incompatible parameter type
157-
annotate_inputs(anchors.biases, bias_qspec)
179+
annotate_weights_or_biases(anchors.biases, bias_qspec)
158180
return model
159181

160182
def validate(self, model: fx.GraphModule) -> None:
@@ -223,4 +245,5 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
223245
if quantizers is None:
224246
quantizers = get_cadence_default_quantizers()
225247
quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
248+
quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8uW8u))
226249
super().__init__(quantizers)

0 commit comments

Comments
 (0)