Skip to content

Qualcomm AI Engine Direct - Add rewrite function of observer #10093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,14 @@ def forward(self, x):
return x.repeat(1, 2, 3, 4)


class ReWriteObs(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.nn.functional.relu(x).expand(3, 4)


class Reshape(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
31 changes: 31 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
generate_qnn_executorch_compiler_spec,
PyQnnManagerAdaptor,
QnnPartitioner,
rewrite_prepared_observer,
skip_annotation,
to_edge_transform_and_lower_to_qnn,
update_spill_fill_size,
Expand Down Expand Up @@ -3058,6 +3059,36 @@ def test_qnn_backend_dynamic_shape(self):
check_io_shape=True,
)

def test_qnn_backend_rewrite_prepared_observer(self):
from torchao.quantization.pt2e import FixedQParamsObserver

module = ReWriteObs() # noqa: F405
sample_input = (torch.randn([3, 1]),)
module = torch.export.export(module, sample_input, strict=True).module()

quantizer = make_quantizer()

prepared = prepare_pt2e(module, quantizer)
prepared(*sample_input)

new_obs = FixedQParamsObserver(
scale=0.004,
zero_point=0,
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
)

rewrite_prepared_observer(prepared, {"activation_post_process_2": new_obs})
Copy link
Contributor

@jerryzh168 jerryzh168 May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sxu does your callback work for this? maybe you can share your example

self.assertTrue(
prepared.activation_post_process_1
== prepared.activation_post_process_2
== new_obs
)
quantized_module = convert_pt2e(prepared)
self.lower_module_and_test_output(quantized_module, sample_input)

def test_qnn_backend_skip_node_id_partitioner(self):
module = SimpleModel() # noqa: F405
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
Expand Down
52 changes: 51 additions & 1 deletion backends/qualcomm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import operator
import warnings
from collections import OrderedDict
from collections import defaultdict, OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
Expand Down Expand Up @@ -1038,3 +1038,53 @@ def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
for node in gm.graph.nodes:
if dtype := get_quant_io_dtype_fn(node):
node.meta[QCOM_QUANTIZED_IO] = dtype


def rewrite_prepared_observer(
graph_module: torch.fx.GraphModule, name_obs_dict: Dict[str, torch.nn.Module]
):
"""
Rewrite the observer of the specified observer module name in the graph_module.

Example:
Consider the following graph_module after prepare_pt2e:
gm = prepare_pt2e(gm)
print(gm)

GraphModule(
(activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf)
(activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf)
(activation_post_process_2): MinMaxObserver(min_val=inf, max_val=-inf)
(activation_post_process_3): MinMaxObserver(min_val=inf, max_val=-inf)
)

new_observer = observer.FixedQParamsObserver(
scale=0.125,
zero_point=42,
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
)

Calling rewrite_prepared_observer(gm, {"activation_post_process_0": new_observer})
is equivalent to:
gm.activation_post_process_0 = new_observer

Note:
If the rewritten observer is a SharedQuantizationSpec, all other shared observers will also be rewritten.
"""
module_name_list = defaultdict(list)
for name, module in graph_module.named_modules(remove_duplicate=False):
module_name_list[module].append(name)

for name, new_observer in name_obs_dict.items():
old_module = getattr(graph_module, name, None)

if not old_module:
print(
f"[WARNING], No observer named as {name} found, please check the moudle name"
)
continue
for target_name in module_name_list[old_module]:
setattr(graph_module, target_name, new_observer)
Loading