
Commit 2eb6b04

shewu-quic authored and facebook-github-bot committed
Qualcomm AI Engine Direct - Fixed layer norm quantization annotation for 16bit (#5927)
Summary:
- Fixed quantization annotation for layer norm in 16bit.
- Add a unit test for 16a4w layer norm.

Pull Request resolved: #5927

Reviewed By: kirklandsign

Differential Revision: D63985345

Pulled By: cccclai

fbshipit-source-id: 53fb7959b323d142aa06b3827c53cfd2c94e358d
1 parent ac2ae07 commit 2eb6b04

2 files changed, +24 -6 lines changed


backends/qualcomm/quantizer/utils.py

Lines changed: 14 additions & 6 deletions
@@ -1157,17 +1157,25 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
 
     if _is_annotated([node]):
         return
+    input_act_qspec = quantization_config.input_activation
 
     _annotate_input_qspec_map(
         node,
         act_node,
-        quantization_config.input_activation,
-    )
-    _annotate_input_qspec_map(
-        node,
-        weight_node,
-        quantization_config.input_activation,
+        input_act_qspec,
     )
+    if input_act_qspec.dtype == torch.int32:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            get_default_16bit_qnn_ptq_config().weight,
+        )
+    else:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            input_act_qspec,
+        )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
         _annotate_input_qspec_map(
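In short, the fix annotates the layer-norm weight with the default 16-bit PTQ config's weight spec whenever the input activation spec is 16-bit (represented with torch.int32 in the quantizer), instead of reusing the activation spec for the weight. The sketch below illustrates that selection logic in isolation; QSpec and pick_layer_norm_weight_qspec are simplified stand-ins for illustration only, not the QuantizationConfig or get_default_16bit_qnn_ptq_config types used in utils.py.

# Illustrative sketch only -- simplified stand-ins, not the executorch quantizer types.
from dataclasses import dataclass

import torch


@dataclass
class QSpec:
    # Stand-in for a quantization spec; the real code passes QuantizationSpec objects.
    dtype: torch.dtype
    label: str


def pick_layer_norm_weight_qspec(
    input_act_qspec: QSpec, default_16bit_weight_qspec: QSpec
) -> QSpec:
    # 16-bit activations carry a torch.int32 dtype in the quantizer config,
    # so that dtype is the signal for switching the weight annotation.
    if input_act_qspec.dtype == torch.int32:
        # 16-bit case: use the default 16-bit PTQ config's weight spec.
        return default_16bit_weight_qspec
    # Otherwise keep the previous behavior and reuse the activation spec.
    return input_act_qspec


act_qspec = QSpec(dtype=torch.int32, label="16-bit activation")
weight_qspec = QSpec(dtype=torch.int16, label="16-bit weight")
print(pick_layer_norm_weight_qspec(act_qspec, weight_qspec).label)  # -> 16-bit weight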

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 10 additions & 0 deletions
@@ -634,6 +634,16 @@ def test_qnn_backend_16a4w_conv2d(self):
         )
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_16a4w_layer_norm(self):
+        module = LayerNorm()  # noqa: F405
+        sample_input = (torch.randn(196, 768),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
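The LayerNorm and Linear test modules referenced here come from the backend's shared test model definitions and are not part of this diff. A minimal stand-in consistent with the (196, 768) sample input might look like the sketch below; the hidden size of 768 is inferred from the sample shape, not taken from the actual test model.

# Hypothetical stand-in for the LayerNorm test module (not the actual definition).
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self, hidden_dim: int = 768):
        super().__init__()
        # Normalize over the last dimension, matching a (196, 768) sample input.
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer_norm(x)


module = LayerNorm()
sample_input = (torch.randn(196, 768),)
print(module(*sample_input).shape)  # torch.Size([196, 768])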
