
Commit e3c8e0c

[Qualcomm AI Engine Direct - Enable per channel linear op]

- Add per-channel weight quantization for the linear op
- Bias quantization for the per-channel-weight linear op is not supported yet
1 parent: c06c89f

6 files changed (+69 −17 lines)


backends/qualcomm/builders/op_linear.py

Lines changed: 14 additions & 0 deletions
@@ -39,6 +39,14 @@ def define_node(
         linear_input_tensors.append(input_tensor_wrapper)
 
         weight_node = node.args[1]
+        if (
+            quant_attrs := weight_node.meta.get("quant_attrs")
+        ) and "scales" in quant_attrs:
+            # Dimension of weight is [m, n], per channel quant params is [m]
+            # Change to [m, 1] to fit the tensor.div(s).add(z)
+            quant_attrs["scales"] = quant_attrs["scales"].reshape([-1, 1])
+            quant_attrs["zero_points"] = quant_attrs["zero_points"].reshape([-1, 1])
+
         weight_tensor = get_parameter(weight_node, self.edge_program)
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
@@ -50,6 +58,12 @@ def define_node(
 
         if len(node.args) >= 3:
             bias_node = node.args[2]
+
+            # TODO remove this when qnn sdk support
+            if "scales" in bias_node.meta.get("quant_attrs"):
+                print(
+                    f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
+                )
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
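The comment in the hunk above explains the reshape: the per-channel parameters arrive with shape [m] while the weight is [m, n], and the quantization arithmetic follows a tensor.div(scales).add(zero_points) pattern. A minimal standalone sketch of the broadcasting involved (illustrative only, not backend code):

    import torch

    weight = torch.randn(8, 16)                      # [m, n] linear weight
    scales = torch.rand(8) + 0.01                    # one scale per output channel: [m]
    zero_points = torch.zeros(8, dtype=torch.int32)  # one zero point per output channel: [m]

    # With shape [m], weight.div(scales) would try to broadcast against the last
    # axis ([n]) and fail whenever m != n. Reshaping to [m, 1] broadcasts the
    # parameters across each row, i.e. per output channel.
    scales = scales.reshape(-1, 1)
    zero_points = zero_points.reshape(-1, 1)

    quantized = weight.div(scales).add(zero_points).round()
    dequantized = quantized.sub(zero_points).mul(scales)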

backends/qualcomm/quantizer/quantizer.py

Lines changed: 18 additions & 8 deletions
@@ -267,7 +267,7 @@ def __init__(self):
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
 
-        self.enable_per_channel_conv_quant: bool = True
+        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
         # the weight quantized for activation 8 bits and 16 bits
         self.per_channel_weight_dtype: Dict = {
             "8bit_act": torch.int8,
@@ -290,16 +290,13 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
     def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
         """
         Priority:
-            1. per channel config when enable_per_channel_conv_quant is True
+            1. is one of use_per_channel_weight_quant_ops
             2. int8 / int16 config
         """
         if type(op) == str:
             return
 
-        if self.enable_per_channel_conv_quant and op in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if op in self.use_per_channel_weight_quant_ops:
             if op in self.bit16_quant_ops:
                 return get_ptq_per_channel_weight_config(
                     torch.uint16, self.per_channel_weight_dtype["16bit_act"]
@@ -316,6 +313,12 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig
 
         print(f"No quant config is implemented for op, {op}")
 
+    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
+        if enable:
+            self.use_per_channel_weight_quant_ops.update(ops)
+        else:
+            self.use_per_channel_weight_quant_ops.difference(ops)
+
     def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None:
         for op in ops:
             assert (
@@ -368,8 +371,15 @@ def set_per_channel_weight_dtype(
         if weight_dtype_for_16bit_act:
             self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act
 
-    def set_per_channel_quant(self, enable: bool) -> None:
-        self.enable_per_channel_conv_quant = enable
+    def set_per_channel_conv_quant(self, enable: bool) -> None:
+        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
+        self._update_per_channel_weight_quant_ops(conv_ops, enable)
+
+    def set_per_channel_linear_quant(self, enable: bool) -> None:
+        linear_ops = {
+            torch.ops.aten.linear.default,
+        }
+        self._update_per_channel_weight_quant_ops(linear_ops, enable)
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
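The boolean conv-only switch is now an op set, so per-channel weight quantization is opted into per operator. A hedged usage sketch of the new toggles (the import path is inferred from this file's location and may differ across ExecuTorch versions):

    from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

    quantizer = QnnQuantizer()
    quantizer.set_per_channel_conv_quant(True)    # adds the aten.conv1d/conv2d overloads to the op set
    quantizer.set_per_channel_linear_quant(True)  # adds aten.linear.default to the op set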

backends/qualcomm/quantizer/utils.py

Lines changed: 5 additions & 5 deletions
@@ -520,11 +520,11 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
         )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
-            node,
-            bias_node,
-            quantization_config.bias,
-        )
+        if callable(quantization_config.bias):
+            bias_config = quantization_config.bias(node)
+        else:
+            bias_config = quantization_config.bias
+        _annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
     _annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
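After this change, quantization_config.bias can be either a fixed spec or a callable that takes the linear node and returns one, presumably so a per-channel weight config can derive the bias parameters per node. A minimal sketch of that dispatch in isolation (the names here are illustrative, not the backend's):

    def resolve_bias_qspec(quantization_config, node):
        # Callable bias: derive the spec from the node (per-channel case);
        # otherwise fall back to the statically configured spec.
        bias = quantization_config.bias
        return bias(node) if callable(bias) else bias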

backends/qualcomm/tests/models.py

Lines changed: 2 additions & 2 deletions
@@ -409,9 +409,9 @@ def forward(self, x):
 
 
 class Linear(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, use_bias: bool = True):
         super().__init__()
-        self.linear = torch.nn.Linear(4, 5).eval()
+        self.linear = torch.nn.Linear(4, 5, use_bias).eval()
 
     def forward(self, x):
         return self.linear(x)
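torch.nn.Linear accepts bias as its third positional argument, which is what lets the test model forward use_bias straight through. A quick standalone check:

    import torch

    with_bias = torch.nn.Linear(4, 5, True)
    without_bias = torch.nn.Linear(4, 5, False)
    assert with_bias.bias is not None
    assert without_bias.bias is None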

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 27 additions & 1 deletion
@@ -503,7 +503,33 @@ def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
         module = self.get_qdq_module(
-            module, sample_input, quant_dtype=QuantDtype.use_16a4w
+            module,
+            sample_input,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_16a4w_per_channel_linear(self):
+        module = Linear(use_bias=False)  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            is_linear_per_channel=True,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
+    # Is not enabled in the current qnn sdk release
+    @unittest.expectedFailure
+    def test_qnn_backend_16a4w_per_channel_linear_with_bias(self):
+        module = Linear()  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            is_linear_per_channel=True,
+            quant_dtype=QuantDtype.use_16a4w,
         )
         self.lower_module_and_test_output(module, sample_input)

backends/qualcomm/tests/utils.py

Lines changed: 3 additions & 1 deletion
@@ -225,14 +225,16 @@ def get_qdq_module(
         module: torch.nn.Module,
         inputs: Tuple[torch.Tensor],
         is_conv_per_channel: Optional[bool] = True,
+        is_linear_per_channel: Optional[bool] = False,
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
     ) -> torch.fx.GraphModule:
         m = torch._export.capture_pre_autograd_graph(module, inputs)
 
         quantizer = QnnQuantizer()
         quantizer.add_custom_quant_annotations(custom_quant_annotations)
-        quantizer.set_per_channel_quant(is_conv_per_channel)
+        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
+        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
 
         if quant_dtype == QuantDtype.use_8a8w:
             pass  # default setting
