Skip to content

Commit b92c102

Browse files
author
Chun-I Tsai
committed
Qualcomm AI Engine Direct - Add rewrite function of observer
- Add function to rewrite observer after prepare_pt2e - Add corresponding test case
1 parent ba79cb6 commit b92c102

File tree

3 files changed

+91
-1
lines changed

3 files changed

+91
-1
lines changed

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,14 @@ def forward(self, x):
12211221
return x.repeat(1, 2, 3, 4)
12221222

12231223

1224+
class ReWriteObs(torch.nn.Module):
    """Minimal model used to exercise observer rewriting after prepare_pt2e.

    Applies a ReLU and then broadcasts the (3, 1) activation to (3, 4) via
    ``expand``, producing two observed activations that share quantization
    parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # ReLU first, then broadcast the singleton trailing dim to width 4.
        activated = torch.nn.functional.relu(x)
        return activated.expand(3, 4)
12241232
class Reshape(torch.nn.Module):
12251233
def __init__(self):
12261234
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
generate_qnn_executorch_compiler_spec,
5050
PyQnnManagerAdaptor,
5151
QnnPartitioner,
52+
rewrite_prepared_observer,
5253
skip_annotation,
5354
to_edge_transform_and_lower_to_qnn,
5455
update_spill_fill_size,
@@ -2913,6 +2914,37 @@ def test_qnn_backend_dynamic_shape(self):
29132914
check_io_shape=True,
29142915
)
29152916

2917+
def test_qnn_backend_rewrite_prepared_observer(self):
    """Swap an observer after prepare_pt2e and check shared aliases follow."""
    from torch.ao.quantization import observer
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    sample_input = (torch.randn([3, 1]),)
    module = ReWriteObs()  # noqa: F405
    module = torch.export.export(module, sample_input, strict=True).module()

    quantizer = make_quantizer()

    # Insert observers, then run one calibration pass.
    prepared = prepare_pt2e(module, quantizer)
    prepared(*sample_input)

    new_obs = observer.FixedQParamsObserver(
        scale=0.004,
        zero_point=0,
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
    )

    # Rewriting one shared observer must rewrite every alias of it.
    rewrite_prepared_observer(prepared, {"activation_post_process_2": new_obs})
    self.assertTrue(prepared.activation_post_process_1 == new_obs)
    self.assertTrue(prepared.activation_post_process_2 == new_obs)
    quantized_module = convert_pt2e(prepared)
    self.lower_module_and_test_output(quantized_module, sample_input)
29162948
def test_qnn_backend_skip_node_id_partitioner(self):
29172949
module = SimpleModel() # noqa: F405
29182950
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))

backends/qualcomm/utils/utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66
import operator
77
import warnings
8-
from collections import OrderedDict
8+
from collections import defaultdict, OrderedDict
99
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1010

1111
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
@@ -1038,3 +1038,53 @@ def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
10381038
for node in gm.graph.nodes:
10391039
if dtype := get_quant_io_dtype_fn(node):
10401040
node.meta[QCOM_QUANTIZED_IO] = dtype
1041+
1042+
1043+
def rewrite_prepared_observer(
    graph_module: torch.fx.GraphModule, name_obs_dict: Dict[str, torch.nn.Module]
):
    """
    Rewrite the observer of the specified observer module name in the graph_module.

    Args:
        graph_module: Module produced by prepare_pt2e, holding observers as
            named submodules (e.g. ``activation_post_process_0``).
        name_obs_dict: Maps an existing observer attribute name to the new
            observer module that should replace it.

    Example:
        Consider the following graph_module after prepare_pt2e:
        gm = prepare_pt2e(gm)
        print(gm)

        GraphModule(
            (activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf)
            (activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf)
            (activation_post_process_2): MinMaxObserver(min_val=inf, max_val=-inf)
            (activation_post_process_3): MinMaxObserver(min_val=inf, max_val=-inf)
        )

        new_observer = observer.FixedQParamsObserver(
            scale=0.125,
            zero_point=42,
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
        )

        Calling rewrite_prepared_observer(gm, {"activation_post_process_0": new_observer})
        is equivalent to:
        gm.activation_post_process_0 = new_observer

    Note:
        If the rewritten observer is a SharedQuantizationSpec, all other shared observers will also be rewritten.
    """
    # Map each observer instance to every attribute name it is registered
    # under; a shared observer appears once per alias (remove_duplicate=False).
    module_name_list = defaultdict(list)
    for name, module in graph_module.named_modules(remove_duplicate=False):
        module_name_list[module].append(name)

    for name, new_observer in name_obs_dict.items():
        old_module = getattr(graph_module, name, None)

        # Compare against the sentinel explicitly: `not old_module` would also
        # skip real modules whose __bool__/__len__ is falsy (e.g. an empty
        # nn.Sequential).
        if old_module is None:
            warnings.warn(
                f"[WARNING], No observer named as {name} found, please check the module name",
                stacklevel=1,
            )
            continue
        # Replace every alias of the old observer so sharing is preserved.
        for target_name in module_name_list[old_module]:
            setattr(graph_module, target_name, new_observer)

0 commit comments

Comments
 (0)