Commit 2c9c79c

Author: Chun-I Tsai (committed)
Qualcomm AI Engine Direct - Add rewrite function of observer

- Add a function to rewrite observers after prepare_pt2e
- Add a corresponding test case
1 parent 1d43b3b commit 2c9c79c
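For context, a minimal usage sketch of the new helper. This is illustrative, not part of the commit: exported_module, quantizer, and sample_input are placeholders, the observer settings are taken from the docstring example added below, and the import path follows the file the commit touches.

    import torch
    from torch.ao.quantization import observer
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    from executorch.backends.qualcomm.utils.utils import rewrite_prepared_observer

    # Prepare and calibrate as usual.
    prepared = prepare_pt2e(exported_module, quantizer)
    prepared(*sample_input)

    # Pin one activation observer to fixed quantization parameters,
    # then convert with the rewritten observer in place.
    new_observer = observer.FixedQParamsObserver(
        scale=0.125,
        zero_point=42,
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
    )
    rewrite_prepared_observer(prepared, {"activation_post_process_0": new_observer})
    quantized = convert_pt2e(prepared)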

File tree

3 files changed (+91, -1 lines)


backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
@@ -1194,6 +1194,14 @@ def forward(self, x):
         return x.repeat(1, 2, 3, 4)
 
 
+class ReWriteObs(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.nn.functional.relu(x).expand(3, 4)
+
+
 class Reshape(torch.nn.Module):
     def __init__(self):
         super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 32 additions & 0 deletions
@@ -47,6 +47,7 @@
     generate_multi_graph_program,
     generate_qnn_executorch_compiler_spec,
     PyQnnManagerAdaptor,
+    rewrite_prepared_observer,
     skip_annotation,
     to_edge_transform_and_lower_to_qnn,
     update_spill_fill_size,
@@ -2784,6 +2785,37 @@ def test_qnn_backend_dynamic_shape(self):
             check_io_shape=True,
         )
 
+    def test_qnn_backend_rewrite_prepared_observer(self):
+        from torch.ao.quantization import observer
+        from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+        module = ReWriteObs()  # noqa: F405
+        sample_input = (torch.randn([3, 1]),)
+        module = torch.export.export(module, sample_input, strict=True).module()
+
+        quantizer = make_quantizer()
+
+        prepared = prepare_pt2e(module, quantizer)
+        prepared(*sample_input)
+
+        new_obs = observer.FixedQParamsObserver(
+            scale=0.004,
+            zero_point=0,
+            dtype=torch.uint8,
+            quant_min=0,
+            quant_max=255,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        rewrite_prepared_observer(prepared, {"activation_post_process_2": new_obs})
+        self.assertTrue(
+            prepared.activation_post_process_1
+            == prepared.activation_post_process_2
+            == new_obs
+        )
+        quantized_module = convert_pt2e(prepared)
+        self.lower_module_and_test_output(quantized_module, sample_input)
+
     def test_qnn_backend_skip_node_id_partitioner(self):
         module = SimpleModel()  # noqa: F405
         sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
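The chained equality in the test holds because the prepared ReWriteObs graph appears to register a single observer instance under both activation_post_process_1 and activation_post_process_2 (expand presumably shares its input's quantization spec), so rewriting either name swaps both. A small inspection sketch, assuming a prepared module as in the test:

    # Group observer attribute names by instance identity; a shared
    # observer shows up as one id carrying several names.
    groups = {}
    for name, module in prepared.named_modules(remove_duplicate=False):
        if "activation_post_process" in name:
            groups.setdefault(id(module), []).append(name)
    print(groups)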

backends/qualcomm/utils/utils.py

Lines changed: 51 additions & 1 deletion
@@ -7,7 +7,7 @@
 import re
 import time
 import warnings
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
@@ -1257,3 +1257,53 @@ def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
     for node in gm.graph.nodes:
         if dtype := get_quant_io_dtype_fn(node):
             node.meta[QCOM_QUANTIZED_IO] = dtype
+
+
+def rewrite_prepared_observer(
+    graph_module: torch.fx.GraphModule, name_obs_dict: Dict[str, torch.nn.Module]
+):
+    """
+    Rewrite the observers with the specified module names in graph_module.
+
+    Example:
+    Consider the following graph_module after prepare_pt2e:
+        gm = prepare_pt2e(gm)
+        print(gm)
+
+        GraphModule(
+            (activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf)
+            (activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf)
+            (activation_post_process_2): MinMaxObserver(min_val=inf, max_val=-inf)
+            (activation_post_process_3): MinMaxObserver(min_val=inf, max_val=-inf)
+        )
+
+        new_observer = observer.FixedQParamsObserver(
+            scale=0.125,
+            zero_point=42,
+            dtype=torch.uint8,
+            quant_min=0,
+            quant_max=255,
+            qscheme=torch.per_tensor_affine,
+        )
+
+    Calling rewrite_prepared_observer(gm, {"activation_post_process_0": new_observer})
+    is equivalent to:
+        gm.activation_post_process_0 = new_observer
+
+    Note:
+    If the rewritten observer is shared via a SharedQuantizationSpec, all other
+    observers sharing it will also be rewritten.
+    """
+    module_name_list = defaultdict(list)
+    for name, module in graph_module.named_modules(remove_duplicate=False):
+        module_name_list[module].append(name)
+
+    for name, new_observer in name_obs_dict.items():
+        old_module = getattr(graph_module, name, None)
+
+        if not old_module:
+            print(
+                f"[WARNING] No observer named {name} found, please check the module name"
+            )
+            continue
+        for target_name in module_name_list[old_module]:
+            setattr(graph_module, target_name, new_observer)
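The helper's core trick is grouping attribute names by observer instance: nn.Module hashes by identity, and named_modules(remove_duplicate=False) yields every alias, so one shared observer maps to all of its registered names. A self-contained sketch of that mechanism (toy modules, not from the commit):

    import torch
    from collections import defaultdict

    root = torch.nn.Module()
    shared = torch.nn.Identity()
    root.a = shared
    root.b = shared  # same instance registered under two names

    alias_map = defaultdict(list)
    for name, module in root.named_modules(remove_duplicate=False):
        alias_map[module].append(name)

    print(alias_map[shared])  # ['a', 'b']: both aliases are found
    # setattr(root, "a", new_mod) would rebind only "a", which is why the
    # helper loops over every alias when swapping in the new observer.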
