
Qualcomm AI Engine Direct - Add submodule quant config setting #9355


Merged · 4 commits · Apr 10, 2025
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/decompose_einsum.py
@@ -8,6 +8,8 @@
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx

from .utils import copy_nn_module_stack


class DecomposeEinsum(ExportPass):
"""
@@ -36,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
remap[f"arg1_{i+1}"] = arg

for decomposed_node in decomposed_module.graph.nodes:
copy_nn_module_stack(node, decomposed_node)
# This is the arg[0] equation string, which is not required anymore after decomposition
if "arg0" in decomposed_node.name:
continue
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/decompose_linalg_vector_norm.py
@@ -8,6 +8,8 @@
from executorch.exir import to_edge
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_nn_module_stack


class LinalgVectorNorm(torch.nn.Module):
def __init__(self, exp, dim, keepdim):
@@ -62,6 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
remap = {"x": node.args[0]}

for decomposed_node in decomposed_module.graph.nodes:
copy_nn_module_stack(node, decomposed_node)
# no need to copy existent 'output'
if decomposed_node.op == "output":
for user in node.users.copy():
8 changes: 8 additions & 0 deletions backends/qualcomm/_passes/utils.py
@@ -121,6 +121,14 @@ def get_passes_dependency_for_capture_program():
}


def copy_nn_module_stack(src, target):
"""
Copy meta["nn_module_stack"] from the src node to the target node if it exists.
"""
if value := src.meta.get("nn_module_stack"):
target.meta["nn_module_stack"] = value


def is_float_tensor(node: torch.fx.Node) -> bool:
if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
return False
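For context, here is a minimal sketch (not part of this diff) of what the helper does on torch.fx nodes; the traced function, node names, and stack entry are made up for illustration, and the import path assumes the repo's usual module layout:

import torch
from executorch.backends.qualcomm._passes.utils import copy_nn_module_stack

# Trace a trivial function just to obtain two fx nodes to work with.
def fn(x):
    return x + 1

gm = torch.fx.symbolic_trace(fn)
src, target = list(gm.graph.nodes)[:2]

# Pretend `src` was traced from a submodule, so it carries nn_module_stack.
src.meta["nn_module_stack"] = {
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add")
}

copy_nn_module_stack(src, target)
assert target.meta["nn_module_stack"] == src.meta["nn_module_stack"]

This is what the two decomposition passes above rely on: each node produced during decomposition inherits the nn_module_stack of the original node, so the submodule predicates added below can still match decomposed nodes.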
219 changes: 143 additions & 76 deletions backends/qualcomm/quantizer/quantizer.py
@@ -3,9 +3,10 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import IntEnum, unique
from functools import partial
from typing import Callable, Dict, Optional, Sequence, Set, Tuple
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

import torch
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
@@ -58,7 +59,7 @@ class QuantDtype(IntEnum):
use_8a8w = 4


quant_config_dict = {
QUANT_CONFIG_DICT = {
# PTQ
(QuantDtype.use_16a16w, False): (
get_16a16w_qnn_ptq_config,
@@ -123,21 +124,71 @@ class QuantDtype(IntEnum):
}


@dataclass
class ModuleQConfig:
quant_dtype: QuantDtype = QuantDtype.use_8a8w
is_qat: bool = False
is_conv_per_channel: bool = False
is_linear_per_channel: bool = False
act_observer: Optional[
torch.ao.quantization.observer.UniformQuantizationObserverBase
] = None

def __post_init__(self):
if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
raise RuntimeError(
f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support"
)
(
quant_config_func,
per_channel_quant_config_func,
per_block_quant_config_func,
) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
self.quant_config = (
quant_config_func(act_observer=self.act_observer)
if self.act_observer
else quant_config_func()
)
self.per_channel_quant_config = (
per_channel_quant_config_func(act_observer=self.act_observer)
if self.act_observer
else per_channel_quant_config_func()
)
self.use_per_channel_weight_quant_ops = set()
if self.is_conv_per_channel:
self.use_per_channel_weight_quant_ops.update(
{
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv_transpose2d.input,
}
)
if self.is_linear_per_channel:
self.use_per_channel_weight_quant_ops.update(
{
torch.ops.aten.linear.default,
}
)
if per_block_quant_config_func:
self.per_block_quant_config = (
per_block_quant_config_func(act_observer=self.act_observer)
if self.act_observer
else per_block_quant_config_func()
)


class QnnQuantizer(Quantizer):
SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

def __init__(self):
super().__init__()
self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

self.is_qat = False
self.quant_dtype = QuantDtype.use_8a8w
self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
self.per_channel_quant_config = get_ptq_per_channel_quant_config()
self.per_block_quant_config = get_ptq_per_block_quant_config()
self.default_quant_config = ModuleQConfig()
self.submodule_qconfig_list: List[
Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
] = []
self.block_size_map = {}
self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
self.use_per_block_weight_quant_ops: Set[OpOverload] = set()

self.custom_quant_annotations: Sequence[Callable] = []
self.discard_nodes: Set[str] = set()
@@ -155,41 +206,38 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
for annotation_func in self.custom_quant_annotations:
annotation_func(gm)

def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
def _get_submodule_qconfig(self, node: torch.fx.Node):
for func, qconfig in self.submodule_qconfig_list:
if func(node):
return qconfig
return self.default_quant_config

def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
"""
Priority:
1. is one of use_per_block_weight_quant_ops
2. is one of use_per_channel_weight_quant_ops
3. quant config
How to pick:
1. Use per_block_quant_config if the node has an entry in block_size_map.
2. Use the submodule-specific config if a submodule callback matches the node.
3. Use per_channel_quant_config if the op is in use_per_channel_weight_quant_ops.
4. Otherwise, fall back to the default quant config.
"""
target = op.target
if isinstance(target, str):
op = node.target
if isinstance(op, str):
return

if target in self.use_per_block_weight_quant_ops:
if block_size := self.block_size_map.get(op.name):
self.per_block_quant_config.block_size = block_size
return self.per_block_quant_config
if block_size := self.block_size_map.get(node.name):
config = self.default_quant_config.per_block_quant_config
config.block_size = block_size
return config

if target in self.use_per_channel_weight_quant_ops:
return self.per_channel_quant_config
config = self._get_submodule_qconfig(node)

if target in self.quant_ops:
return self.quant_config
if op in config.use_per_channel_weight_quant_ops:
return config.per_channel_quant_config

print(f"No quant config is implemented for op, {op}")

def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
if enable:
self.use_per_block_weight_quant_ops.update(ops)
else:
self.use_per_block_weight_quant_ops.difference_update(ops)
if op in self.quant_ops:
return config.quant_config

def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
if enable:
self.use_per_channel_weight_quant_ops.update(ops)
else:
self.use_per_channel_weight_quant_ops.difference_update(ops)
print(f"No quant config is implemented for op, {op}")

def add_custom_quant_annotations(
self, custom_quant_annotations: Sequence[Callable]
@@ -212,55 +260,74 @@ def annotate(self, model: GraphModule) -> GraphModule:
def get_supported_ops(self) -> Set[OpOverload]:
return self.SUPPORTED_OPS

def set_quant_config(
self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
def set_default_quant_config(
self,
quant_dtype: QuantDtype,
is_qat=False,
is_conv_per_channel=False,
is_linear_per_channel=False,
act_observer=None,
) -> None:
self.quant_dtype = quant_dtype
self.is_qat = is_qat
if (quant_dtype, is_qat) not in quant_config_dict:
raise RuntimeError(
f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
)

quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
quant_config_dict[(quant_dtype, is_qat)]
)
self.quant_config = (
quant_config_fuc(act_observer=act_observer)
if act_observer
else quant_config_fuc()
self.default_quant_config = ModuleQConfig(
quant_dtype,
is_qat,
is_conv_per_channel,
is_linear_per_channel,
act_observer,
)
self.per_channel_quant_config = (
per_channel_quant_config_fuc(act_observer=act_observer)
if act_observer
else per_channel_quant_config_fuc()
)
if per_block_quant_config_fuc is not None:
self.per_block_quant_config = (
per_block_quant_config_fuc(act_observer=act_observer)
if act_observer
else per_block_quant_config_fuc()
)

def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
self.block_size_map = block_size_map

def set_per_block_conv_quant(self, enable: bool) -> None:
conv_ops = {torch.ops.aten.conv2d.default}
self._update_per_block_weight_quant_ops(conv_ops, enable)

def set_per_channel_conv_quant(self, enable: bool) -> None:
conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
self._update_per_channel_weight_quant_ops(conv_ops, enable)

def set_per_channel_linear_quant(self, enable: bool) -> None:
linear_ops = {
torch.ops.aten.linear.default,
}
self._update_per_channel_weight_quant_ops(linear_ops, enable)
def set_submodule_qconfig_list(
self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
) -> None:
"""
Set submodule-specific quant configs as (predicate, ModuleQConfig) pairs.
If a node matches more than one predicate, only the first match is applied.
"""
self.submodule_qconfig_list = submodule_qconfig_list

def transform_for_annotation(self, model: GraphModule) -> GraphModule:
return QnnPassManager().transform_for_annotation_pipeline(model)

def validate(self, model: GraphModule) -> None:
pass


def get_submodule_type_predicate(module_type_str):
"""
An example of nn_module_stack
{
'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
}
"""

def predicate(node):
if nn_module_stack := node.meta.get("nn_module_stack"):
for _, type_name in nn_module_stack.values():
if module_type_str in type_name:
return True
return False

return predicate


def get_submodule_name_predicate(module_name_str):
"""
An example of nn_module_stack
{
'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
}
"""

def predicate(node):
if nn_module_stack := node.meta.get("nn_module_stack"):
for name in nn_module_stack.keys():
if module_name_str in name:
return True
return False

return predicate
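As a usage sketch (illustrative, not part of this diff), the new API could be wired up roughly as follows; the choice of dtypes and the "Add" type-name filter are assumptions for the example:

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
    get_submodule_type_predicate,
)

quantizer = QnnQuantizer()

# Default config for the whole graph: 8a8w PTQ with per-channel conv/linear weights.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_qat=False,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)

# Override: nodes whose nn_module_stack contains an `Add` submodule get 16a16w instead.
quantizer.set_submodule_qconfig_list(
    [
        (
            get_submodule_type_predicate("Add"),
            ModuleQConfig(quant_dtype=QuantDtype.use_16a16w),
        )
    ]
)

The quantizer is then used as before (transform_for_annotation / annotate), and _get_quant_config resolves the per-node config in the order described in its docstring.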
12 changes: 12 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -1450,6 +1450,18 @@ def forward(self, x):
return 10 - x


class SimpleSubModules(torch.nn.Module):
def __init__(self):
super().__init__()
self.add = Add()
self.sub = Sub()

def forward(self, a, b, c, d):
lhs = self.add(a, b)
rhs = self.sub(c, d)
return torch.mul(lhs, rhs)


class SumIntList(torch.nn.Module):
def __init__(self):
super().__init__()
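For illustration only (not part of this diff), the new test model could be combined with the name-based predicate like this; the export call and import paths are assumptions about the surrounding test setup:

import torch
from executorch.backends.qualcomm.quantizer.quantizer import get_submodule_name_predicate
from executorch.backends.qualcomm.tests.models import SimpleSubModules

model = SimpleSubModules()
inputs = tuple(torch.randn(2, 3) for _ in range(4))
exported = torch.export.export(model, inputs)

# nn_module_stack keys look like 'L__self___add', so "add" matches nodes traced from self.add.
predicate = get_submodule_name_predicate("add")
matched = [n.name for n in exported.graph_module.graph.nodes if predicate(n)]
print(matched)  # expected: only the node(s) originating from the Add submodule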