
Qualcomm AI Engine Direct - support skip quantization #5070

Merged: 1 commit, Sep 10, 2024
15 changes: 15 additions & 0 deletions backends/qualcomm/builders/op_batch_norm.py
@@ -8,6 +8,11 @@
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

 import torch
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_SCALE,
+)

 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -21,6 +26,14 @@ class BatchNorm(NodeVisitor):
     def __init__(self, *args) -> None:
         super().__init__(*args)

+    def update_encoding(self, node: torch.fx.Node, tensor: torch.Tensor):
+        if isinstance(tensor, torch._subclasses.FakeTensor):
+            return
+
+        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
+            diff = max(abs(tensor.max()), abs(tensor.min()))
+            quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX]
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -48,6 +61,7 @@ def define_node(

         amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
         bias_tensor = bias_tensor - amount
+        self.update_encoding(bias_node, bias_tensor)
         bias_tensor_wrapper = self.define_tensor(
             bias_node,
             bias_tensor,
@@ -57,6 +71,7 @@
         )

         filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps)
+        self.update_encoding(filter_node, filter_tensor)
         filter_tensor_wrapper = self.define_tensor(
             filter_node,
             filter_tensor,
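For context on the new update_encoding hook: folding batch-norm statistics into the affine parameters can push values outside the range the original encodings were calibrated for, so the scale is refreshed as max(|min|, |max|) / quant_max. A minimal standalone sketch of the same math, with illustrative tensors rather than the builder's actual inputs:

```python
# Hedged sketch: recomputing a symmetric per-tensor scale after batch-norm
# folding, mirroring update_encoding above. All values here are illustrative.
import torch

def recompute_scale(tensor: torch.Tensor, quant_max: int) -> float:
    # Symmetric encoding: the scale must cover the largest absolute value.
    max_abs = max(abs(tensor.max().item()), abs(tensor.min().item()))
    return max_abs / quant_max

gamma = torch.tensor([1.0, 2.0, 0.5])   # BN weight ("filter" in the diff)
beta = torch.tensor([0.1, -0.3, 0.2])   # BN bias
mean = torch.tensor([0.0, 1.0, -1.0])
var = torch.tensor([1.0, 4.0, 0.25])
eps = 1e-5

# Same folding as define_node above.
amount = (gamma * mean) / torch.sqrt(var + eps)
folded_bias = beta - amount
folded_gamma = gamma / torch.sqrt(var + eps)

# The folded tensors may exceed the calibrated range, hence the refreshed
# scale (quant_max = 127 for int8 in this example).
print(recompute_scale(folded_bias, 127), recompute_scale(folded_gamma, 127))
```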
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_softmax.py
@@ -17,7 +17,7 @@

 @register_node_visitor
 class Softmax(NodeVisitor):
-    target = ["aten._softmax.default"]
+    target = ["aten._softmax.default", "aten._safe_softmax.default"]

     def __init__(self, *args) -> None:
         super().__init__(*args)
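Why the extra target: recent PyTorch exports can emit aten._safe_softmax.default (a numerically safe softmax variant used notably in attention paths) in place of aten._softmax.default, so the visitor now accepts both. A hedged sketch, assuming a PyTorch version with torch.export; which variant appears depends on the release:

```python
# Hedged sketch: export a softmax module and inspect which aten softmax
# variant the exported graph contains.
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)

ep = torch.export.export(M(), (torch.randn(2, 8),))
softmax_ops = [
    n.target for n in ep.graph.nodes
    if n.op == "call_function" and "softmax" in str(n.target)
]
print(softmax_ops)  # aten._softmax.default or aten._safe_softmax.default
```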
55 changes: 34 additions & 21 deletions backends/qualcomm/passes/annotate_and_quant_scalar.py
@@ -14,7 +14,7 @@
 from executorch.exir.passes import dead_code_elimination_pass
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

-from .utils import get_quant_attrs
+from .utils import dq_ops, get_quant_attrs


 class AnnotateAndQuantScalar(ExportPass):
@@ -89,30 +89,43 @@ def _traverse_binary_node(self, graph_module: torch.fx.GraphModule):
             graph_module.graph, self.binary_op_sources
         )
         src_partitions = list(itertools.chain(*src_partitions.values()))
+        processed = set()
         for src_partition in src_partitions:
-            output = src_partition.output_nodes[0]
-            if (
-                output.meta.get(QCOM_QUANT_ATTRS)
-                and len(src_partition.input_nodes) == 1
-            ):
-                dq_node = src_partition.input_nodes[0]
-                q_node = dq_node.args[0]
-                q_node_attrs = get_quant_attrs(graph_module, q_node)
-
-                scalar_nodes = [n for n in output.args if n != dq_node]
-                if len(scalar_nodes) == 0:
+            # need post process here to identify partitioned nodes:
+            src_fn_dict = {}
+            for n in src_partition.nodes:
+                # e.g.
+                # meta["source_fn_stack"]: [('mul', <built-in function mul>)]
+                # we'll use <built-in function mul> as grouping key
+                node_list = src_fn_dict.setdefault(n.meta["source_fn_stack"][-1][1], [])
+                node_list.append(n)
+
+            for nodes in src_fn_dict.values():
+                output = [n for n in nodes if n in src_partition.output_nodes][0]
+                # if all args have been annotated, it shouldn't be a scalar operation
+                if all(arg.target in dq_ops for arg in output.args):
                     continue

-                scalar_node = scalar_nodes[0]
-                source_scalar_node = self._get_source_scalar_node(scalar_node)
-                # we'll abandon cast op here, since the constant scalar will
-                # be pre-loaded into QNN context binary
-                output.replace_input_with(scalar_node, source_scalar_node)
+                if output not in processed and QCOM_QUANT_ATTRS in output.meta:
+                    dq_node = [n for n in output.args if n.target in dq_ops][0]
+                    q_node = dq_node.args[0]
+                    q_node_attrs = get_quant_attrs(graph_module, q_node)
+
+                    scalar_nodes = [n for n in output.args if n != dq_node]
+                    if len(scalar_nodes) == 0:
+                        continue
+
+                    scalar_node = scalar_nodes[0]
+                    source_scalar_node = self._get_source_scalar_node(scalar_node)
+                    # we'll abandon cast op here, since the constant scalar will
+                    # be pre-loaded into QNN context binary
+                    output.replace_input_with(scalar_node, source_scalar_node)

-                scalar_quant_attrs = self._update_scalar_node_attrs(
-                    source_scalar_node, q_node_attrs
-                )
-                self._annotate_scalar_node(source_scalar_node, scalar_quant_attrs)
+                    scalar_quant_attrs = self._update_scalar_node_attrs(
+                        source_scalar_node, q_node_attrs
+                    )
+                    self._annotate_scalar_node(source_scalar_node, scalar_quant_attrs)
+                    processed.add(output)

     def call(self, graph_module: torch.fx.GraphModule):
         self._traverse_binary_node(graph_module)
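The rewritten pass groups a partition's nodes by the callable at the top of meta["source_fn_stack"] before scanning each group's output for scalar operands. A minimal standalone sketch of that grouping, using stand-in tuples instead of real fx nodes:

```python
# Hedged sketch: grouping nodes by the function recorded in source_fn_stack,
# as the pass above does. fake_nodes stands in for src_partition.nodes.
import operator

fake_nodes = [
    ("mul_1", [("mul", operator.mul)]),
    ("mul_2", [("mul", operator.mul)]),
    ("add_1", [("add", operator.add)]),
]

src_fn_dict = {}
for name, stack in fake_nodes:
    # the callable (e.g. operator.mul) is the grouping key
    src_fn_dict.setdefault(stack[-1][1], []).append(name)

print(src_fn_dict)
# {<built-in function mul>: ['mul_1', 'mul_2'], <built-in function add>: ['add_1']}
```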
46 changes: 0 additions & 46 deletions backends/qualcomm/passes/recompose_pixel_shuffle.py

This file was deleted.

25 changes: 0 additions & 25 deletions backends/qualcomm/passes/recompose_pixel_unshuffle.py
@@ -6,7 +6,6 @@
 import torch
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


 class RecomposePixelUnshuffle(ExportPass):
@@ -85,30 +84,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 # copy metadata
                 pixel_unshuffle_node.meta = node.meta

-        # decomposed core aten ops
-        if not self.quantization_capture:
-            partitions = get_source_partitions(graph, [torch.nn.PixelUnshuffle])
-            for _, src_partitions in partitions.items():
-                for src_partition in src_partitions:
-                    input_node = src_partition.input_nodes[0]
-                    output_node = src_partition.output_nodes[0]
-                    with graph.inserting_after(input_node):
-                        h_in_shape = input_node.meta["val"].shape[2]
-                        h_out_shape = output_node.meta["val"].shape[2]
-                        downscale_factor = h_in_shape / h_out_shape
-
-                        op = self.op
-                        pixel_unshuffle_node = graph.create_node(
-                            "call_function",
-                            op,
-                            (input_node, int(downscale_factor)),
-                        )
-                        users = output_node.users.copy()
-                        for user in users:
-                            user.replace_input_with(output_node, pixel_unshuffle_node)
-                        # copy metadata
-                        pixel_unshuffle_node.meta = output_node.meta
-
         graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
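The deleted branch recovered pixel_unshuffle's downscale factor from the ratio of input to output height. For reference, that relationship in plain PyTorch:

```python
# The downscale factor of pixel_unshuffle is recoverable from the input and
# output spatial shapes, which is what the removed branch relied on.
import torch

x = torch.randn(1, 3, 8, 8)
y = torch.pixel_unshuffle(x, downscale_factor=2)

# H_in / H_out gives back the factor: 8 / 4 == 2
assert x.shape[2] // y.shape[2] == 2
print(y.shape)  # torch.Size([1, 12, 4, 4]) -- channels scale by factor**2
```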
47 changes: 43 additions & 4 deletions backends/qualcomm/quantizer/utils.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import numbers
+import operator
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
@@ -77,7 +78,7 @@ def _derive_bias_qparams_fn(


 def get_default_8bit_qnn_ptq_config(
-    act_symmetric: bool = False, act_observer=MinMaxObserver
+    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}

@@ -96,15 +97,15 @@
         quant_max=torch.iinfo(torch.int8).max,
         qscheme=torch.per_tensor_symmetric,
         ch_axis=0,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )

     bias_quantization_spec = QuantizationSpec(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.int32).min,
         quant_max=torch.iinfo(torch.int32).max,
         qscheme=torch.per_tensor_symmetric,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )

     quantization_config = QuantizationConfig(
@@ -619,7 +620,13 @@ def annotate_upsample_nearest2d(
     annotate_single_in_single_out(node, quantization_config)


-@register_annotator([torch.ops.aten.softmax.int, torch.ops.aten._softmax.default])
+@register_annotator(
+    [
+        torch.ops.aten.softmax.int,
+        torch.ops.aten._softmax.default,
+        torch.ops.aten._safe_softmax.default,
+    ]
+)
 def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)

@@ -1000,6 +1007,38 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
     node.meta["source_fn_stack"] = [(node, torch.nn.Linear)]


+@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default])
+def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None:
+    act, weight, bias = node.args[0:3]
+    if _is_annotated([node]):
+        return
+
+    _annotate_input_qspec_map(
+        node,
+        act,
+        quantization_config.input_activation,
+    )
+    # QNN requires uint8 instead of int8 in 'weight' config
+    _annotate_input_qspec_map(
+        node,
+        weight,
+        quantization_config.input_activation,
+    )
+    _annotate_input_qspec_map(
+        node,
+        bias,
+        quantization_config.bias,
+    )
+    _annotate_output_qspec(node, quantization_config.output_activation)
+    _mark_nodes_as_annotated([node, *node.args[0:3]])
+
+
+@register_annotator([operator.getitem])
+def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
+    _annotate_output_qspec(node, quantization_config.output_activation)
+    _mark_nodes_as_annotated([node])
+
+
 @register_annotator([torch.ops.aten.layer_norm.default])
 def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     act_node = node.args[0]
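On the observer swap in get_default_8bit_qnn_ptq_config: activations now default to MovingAverageMinMaxObserver, while the weight and bias specs are pinned to MinMaxObserver. A hedged standalone sketch of the behavioral difference, using torch.ao observers on synthetic calibration data:

```python
# Hedged sketch: MinMaxObserver keeps the absolute extremes ever seen, while
# MovingAverageMinMaxObserver smooths the range across calibration batches,
# damping outlier batches. Data below is synthetic.
import torch
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver

mm = MinMaxObserver(eps=2**-12)
mavg = MovingAverageMinMaxObserver(eps=2**-12)

for _ in range(8):
    batch = torch.randn(16) * 3.0
    mm(batch)
    mavg(batch)
# one outlier batch with a much wider range
outlier = torch.randn(16) * 50.0
mm(outlier)
mavg(outlier)

# MinMax's scale jumps to cover the outlier; the moving average barely moves.
print(mm.calculate_qparams()[0], mavg.calculate_qparams()[0])
```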
10 changes: 10 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -55,6 +55,16 @@ def forward(self, x):
         return self.avgPool(x)


+class BatchNorm(torch.nn.Module):
+    def __init__(self, n_features):
+        super().__init__()
+        self.native_batchnorm = torch.nn.BatchNorm2d(n_features)
+        self.eval()
+
+    def forward(self, x):
+        return self.native_batchnorm(x)
+
+
 class Bmm(torch.nn.Module):
     def __init__(self):
         super().__init__()
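A quick usage sketch for the new test module, assuming the BatchNorm class above is in scope; calling eval() in the constructor keeps the running statistics frozen, so the parameters folded by the builder are deterministic:

```python
# Hedged sketch: exercising the new BatchNorm test model with an NCHW input
# whose channel count matches n_features.
import torch

bn = BatchNorm(n_features=4)        # the test model defined above
x = torch.randn(1, 4, 16, 16)
with torch.no_grad():
    y = bn(x)
print(y.shape)  # torch.Size([1, 4, 16, 16])
```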