Skip to content

Commit 82798df

Browse files
committed
Qualcomm AI Engine Direct - Fixed layer norm quantization annotation for 16-bit
- Fixed quantization annotation for layer norm in 16-bit mode. - Added a unit test for 16a4w layer norm.
1 parent c06a708 commit 82798df

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

backends/qualcomm/quantizer/utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,17 +1055,25 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
10551055

10561056
if _is_annotated([node]):
10571057
return
1058+
input_act_qspec = quantization_config.input_activation
10581059

10591060
_annotate_input_qspec_map(
10601061
node,
10611062
act_node,
1062-
quantization_config.input_activation,
1063-
)
1064-
_annotate_input_qspec_map(
1065-
node,
1066-
weight_node,
1067-
quantization_config.input_activation,
1063+
input_act_qspec,
10681064
)
1065+
if input_act_qspec.dtype == torch.int32:
1066+
_annotate_input_qspec_map(
1067+
node,
1068+
weight_node,
1069+
get_default_16bit_qnn_ptq_config().weight,
1070+
)
1071+
else:
1072+
_annotate_input_qspec_map(
1073+
node,
1074+
weight_node,
1075+
input_act_qspec,
1076+
)
10691077
nodes_to_mark_annotated = [node, weight_node]
10701078
if bias_node:
10711079
_annotate_input_qspec_map(

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,16 @@ def test_qnn_backend_16a4w_conv2d(self):
634634
)
635635
self.lower_module_and_test_output(module, sample_input)
636636

637+
def test_qnn_backend_16a4w_layer_norm(self):
    """Lower a LayerNorm module quantized with the 16a4w scheme and check
    that the delegated output matches the reference."""
    instance = LayerNorm()  # noqa: F405
    inputs = (torch.randn(196, 768),)
    quantized = self.get_qdq_module(
        instance,
        inputs,
        quant_dtype=QuantDtype.use_16a4w,
    )
    self.lower_module_and_test_output(quantized, inputs)
646+
637647
def test_qnn_backend_16a4w_linear(self):
638648
module = Linear() # noqa: F405
639649
sample_input = (torch.randn([3, 4]),)

0 commit comments

Comments
 (0)