
Commit 27ed99f

Fix quantized linear -> quantized fully connected replacement pass + add quantized fully connected per_tensor

Differential Revision: D66208417
Pull Request resolved: #6976
Parent: 9289b3f

2 files changed: +37 -8 lines

backends/cadence/aot/ops_registrations.py (+31, -2)
@@ -146,7 +146,10 @@
     "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
-
+lib.define(
+    "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
 
 # ------------------------------------ #
 # Migrated from custom_ops.ymal        #
@@ -192,6 +195,10 @@
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -595,6 +602,28 @@ def quantized_fully_connected_meta(
     bias: torch.Tensor,
     in_zero_point: int,
     weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    out_size = list(src.size())
+    weight_size = list(weight.size())
+    assert len(weight_size) == 2
+    out_size[-1] = weight_size[0]
+    return src.new_empty(out_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_fully_connected.per_tensor")
+def quantized_fully_connected_per_tensor_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: int,
     out_multiplier: int,
     out_shift: int,
     out_zero_point: int,
@@ -607,7 +636,7 @@ def quantized_fully_connected_meta(
     weight_size = list(weight.size())
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
-    return src.new_empty(out_size, dtype=torch.uint8)
+    return src.new_empty(out_size, dtype=src.dtype)
 
 
 @register_fake("cadence::convolution")
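
Two things happen in this file: the per_tensor overloads are registered in the cadence op library, and the fake (meta) kernels that drive shape propagation are updated so the output dtype follows src instead of being hard-coded to torch.uint8. The shape logic itself is simple; here is a minimal standalone sketch of the same computation (the concrete shapes and dtypes below are illustrative, not taken from the commit):

import torch

# Mirrors the contract asserted in quantized_fully_connected_meta:
# src is [leading_dims, in_dim], weight is [out_dim, in_dim],
# and the output is [leading_dims, out_dim] with src's dtype.
src = torch.randint(0, 255, (1, 16), dtype=torch.uint8)
weight = torch.randint(0, 255, (8, 16), dtype=torch.uint8)

out_size = list(src.size())        # [1, 16]
weight_size = list(weight.size())  # [8, 16]
assert len(weight_size) == 2       # weight must be 2-D
out_size[-1] = weight_size[0]      # -> [1, 8]

out = src.new_empty(out_size, dtype=src.dtype)
print(out.shape, out.dtype)        # torch.Size([1, 8]) torch.uint8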

backends/cadence/aot/replace_ops.py (+6, -6)
@@ -9,6 +9,8 @@
 # 3. functions that replace an ATen op with another semantically equivalent ATen op.
 # 4. functions that concretize optional args.
 
+# pyre-unsafe
+
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -1698,12 +1700,6 @@ def call_operator(self, op, args, kwargs, meta):
         if leading_dims != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        # If the op is quantized::linear, but per-channel quantized, bail.
-        if op == exir_ops.edge.cadence.quantized_linear.default:
-            weight = args[1].to_tensor() if isinstance(args[1], ProxyValue) else args[1]
-            if weight.shape != [1]:
-                return super().call_operator(op, args, kwargs, meta)
-
         # Replace the linear with fully connected op
         return super().call_operator(
             self.linear_to_fc_op[op],
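
The deleted guard appears to be the bug named in the commit title (an inference from the code, not stated in the commit message): weight.shape is a torch.Size, a tuple subclass, so comparing it against the list [1] with != evaluates to True for every weight, which made the pass bail out of the quantized_linear replacement unconditionally. And since the non-per_tensor quantized_fully_connected schema accepts tensor quantization parameters anyway, the replacement presumably no longer needs a per-channel escape hatch. A quick demonstration of the comparison pitfall:

import torch

weight = torch.zeros(1)             # weight.shape is torch.Size([1])
print(weight.shape == [1])          # False: a tuple never equals a list
print(weight.shape != [1])          # True for every possible weight,
                                    # so the old guard always bailed
print(tuple(weight.shape) == (1,))  # True: the comparison likely intended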
@@ -1893,6 +1889,10 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
             exir_ops.edge.cadence.quantized_conv.per_tensor,
             [8, 9, 12, 13],
         ),
+        exir_ops.edge.cadence.quantized_fully_connected: (
+            exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
+            [4, 5, 6],
+        ),
         exir_ops.edge.cadence.quantized_layer_norm: (
             exir_ops.edge.cadence.quantized_layer_norm.per_tensor,
             [1, 2],
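
The new dictionary entry registers quantized_fully_connected with ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass: when arguments 4, 5, and 6 (weight_zero_point, out_multiplier, out_shift, counting from 0 after src, weight, bias, src_zero_point) are single-element tensors, the call can be rewritten to the per_tensor overload with plain scalars. A minimal sketch of that idea, under the assumption that this is how the pass uses the index list; scalarize_args is a hypothetical helper, not the pass's real API:

from typing import Any, List, Sequence, Tuple

import torch

def scalarize_args(args: List[Any], indices: Sequence[int]) -> Tuple[List[Any], bool]:
    """Replace single-element tensors at the given positions with their
    Python scalar values; report False if any position is not a
    single-element tensor (the original overload must then be kept)."""
    new_args = list(args)
    for i in indices:
        arg = new_args[i]
        if isinstance(arg, torch.Tensor) and arg.numel() == 1:
            new_args[i] = arg.item()
        else:
            return args, False
    return new_args, True

# Args 0-3 stand in for src, weight, bias, src_zero_point; 4-6 are the
# quantization parameters this commit allows to become scalars.
args = ["src", "weight", "bias", 0,
        torch.tensor([128]), torch.tensor([1 << 30]), torch.tensor([-7])]
new_args, ok = scalarize_args(args, [4, 5, 6])
assert ok and new_args[4:7] == [128, 1 << 30, -7]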
