
Commit d174637

haowhsu-quic authored and facebook-github-bot committed
Qualcomm AI Engine Direct - oss model enablement (fastvit) (#5543)
Summary:
- e2e script for https://github.com/apple/ml-fastvit (fastvit_s18)
- add pass to handle mismatched tensor shape for broadcast ops when doing layout transform
- add ParamObserver for params with lots of outliers
- refactor & breakage fix

Pull Request resolved: #5543
Reviewed By: kimishpatel
Differential Revision: D63965451
Pulled By: cccclai
fbshipit-source-id: 40cd85f60a8a539e6600cac1bfe16cdac4bb0465
1 parent c06a708 commit d174637

30 files changed: +649 −469 lines

backends/qualcomm/aot/wrappers/TensorWrapper.cpp

Lines changed: 3 additions & 1 deletion
@@ -91,7 +91,9 @@ TensorWrapper::TensorWrapper(
   if (data != nullptr) {
     QNN_VER_PTR(tensor_)->clientBuf.dataSize = bytes;
 
-    if (copy_data) {
+    if (tensor_type != QNN_TENSOR_TYPE_STATIC) {
+      QNN_VER_PTR(tensor_)->clientBuf.data = nullptr;
+    } else if (copy_data) {
       owned_data_ = std::make_unique<char[]>(bytes);
       const char* src_data = static_cast<const char*>(data);
       std::memcpy(owned_data_.get(), src_data, bytes);

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 2 additions & 2 deletions
@@ -51,8 +51,8 @@ def define_node(
             filter_size = filter_size + filter_size
         filter_size_shape = [len(filter_size)]
 
-        # stride info
-        stride = cast(List[int], node.args[2])
+        # stride info - default to kernel_size if not given
+        stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
         if len(stride) == 1:
            stride = stride + stride
         stride_shape = [len(stride)]
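
For context, PyTorch's pooling API treats a missing stride as equal to the kernel size, which is why the builder can safely fall back to filter_size when node.args carries no stride. A minimal sketch of that default (shapes here are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

# With stride omitted, avg_pool2d defaults stride to kernel_size,
# so these two calls produce identical results.
implicit = F.avg_pool2d(x, kernel_size=2)
explicit = F.avg_pool2d(x, kernel_size=2, stride=2)
assert torch.equal(implicit, explicit)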

backends/qualcomm/passes/convert_to_linear.py

Lines changed: 24 additions & 23 deletions
@@ -109,49 +109,50 @@ def _convert_to_linear(
 
         # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node
         # TODO: Find a more general conditional statement.
-        if (
-            fn_node.target == self.add
-            and linear_node.meta["val"].dim() == 3
-            and linear_node.meta["val"].shape[0] == 1
-        ):
-            squeeze_dim = linear_node.meta["val"].shape[1:]
-            linear_node.meta["val"] = torch.squeeze(linear_node.meta["val"], 0)
+        linear_output = linear_node.meta["val"]
+        if linear_output.dim() == 3 and linear_output.shape[0] == 1:
             with gm.graph.inserting_after(input_node):
                 input_users = list(input_node.users.keys())
-                squeeze_dim = linear_node.meta["val"].shape
-                squeeze_view_copy_node = gm.graph.create_node(
+                input_tensor = input_node.meta["val"]
+                squeeze_dim = input_tensor.shape[-2:]
+                squeeze_node = gm.graph.create_node(
                     "call_function",
                     self.view_copy,
                     (
                         input_node,
                         squeeze_dim,
                     ),
                 )
-                squeeze_view_copy_node.meta = linear_node.meta
+                # meta needs to be copied elementwisely for fake-tensor
+                # to be updated correctly and not affect meta of input_node
+                for k, v in input_node.meta.items():
+                    squeeze_node.meta[k] = v
+                squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
                 for user in input_users:
                     if user == linear_node:
-                        user.replace_input_with(input_node, squeeze_view_copy_node)
-            with gm.graph.inserting_after(output):
+                        user.replace_input_with(input_node, squeeze_node)
+
+            with gm.graph.inserting_after(linear_node):
                 output_users = list(linear_node.users.keys())
-                unsqueeze_dim = output.args[0].meta["val"].shape
-                unsqueeze_view_copy_node = gm.graph.create_node(
+                unsqueeze_dim = linear_output.shape
+                unsqueeze_node = gm.graph.create_node(
                     "call_function",
                     self.view_copy,
                     (
                         linear_node,
                         unsqueeze_dim,
                     ),
                 )
-                unsqueeze_view_copy_node.meta = output.args[0].meta
+                # meta needs to be copied elementwisely for fake-tensor
+                # to be updated correctly and not affect meta of unsqueeze_node
+                for k, v in linear_node.meta.items():
+                    unsqueeze_node.meta[k] = v
+                # update linear node's shape
+                linear_node.meta["val"] = linear_output.reshape(
+                    linear_output.shape[-2:]
+                )
                 for user in output_users:
-                    user.replace_input_with(linear_node, unsqueeze_view_copy_node)
-            if QCOM_QUANT_ATTRS in linear_node.meta:
-                squeeze_view_copy_node.meta[QCOM_QUANT_ATTRS] = linear_node.meta[
-                    QCOM_QUANT_ATTRS
-                ]
-                unsqueeze_view_copy_node.meta[QCOM_QUANT_ATTRS] = linear_node.meta[
-                    QCOM_QUANT_ATTRS
-                ]
+                    user.replace_input_with(linear_node, unsqueeze_node)
 
     def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
         mm_node = [n for n in partitioned_nodes if n.target == self.mm][0]
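
The rewrite relies on an algebraic fact: a rank-3 input whose leading dimension is 1 can be viewed down to rank 2, run through linear, then viewed back, without changing the result. A standalone sketch of that equivalence (shapes here are illustrative, not taken from the PR):

import torch

x = torch.randn(1, 4, 8)  # rank 3, leading dim of 1
linear = torch.nn.Linear(8, 16)

direct = linear(x)  # shape (1, 4, 16)
# Squeeze to rank 2, apply linear, then restore the original rank.
squeezed = linear(x.reshape(4, 8))
restored = squeezed.reshape(1, 4, 16)
assert torch.allclose(direct, restored)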
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+
+class ExpandBroadcastTensorShape(ExportPass):
+    """
+    Make tensors have same rank for layout-transform to work properly.
+    """
+
+    def __init__(self):
+        super(ExpandBroadcastTensorShape, self).__init__()
+        self.broadcast_op_targets = [
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.div.Tensor,
+        ]
+
+    def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            if node.target in self.broadcast_op_targets:
+                for arg in node.args:
+                    input_rank = len(arg.meta["val"].shape)
+                    output_rank = len(node.meta["val"].shape)
+                    if input_rank != output_rank:
+                        with graph_module.graph.inserting_after(arg):
+                            new_rank = [1] * (output_rank - input_rank) + list(
+                                arg.meta["val"].shape
+                            )
+                            users = list(arg.users.keys())
+                            reshape_node = graph_module.graph.create_node(
+                                "call_function",
+                                exir_ops.edge.aten.view_copy.default,
+                                (arg, tuple(new_rank)),
+                            )
+                            # meta needs to be copied elementwisely for fake-tensor
+                            # to be updated correctly and not affect meta of arg
+                            for k, v in arg.meta.items():
+                                reshape_node.meta[k] = v
+                            reshape_node.meta["val"] = reshape_node.meta["val"].reshape(
+                                new_rank
+                            )
+                            for user in users:
+                                user.replace_input_with(arg, reshape_node)
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self.traverse_broadcast_node(graph_module)
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        return PassResult(graph_module, True)
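
In tensor terms, the pass pads the lower-rank operand's shape with leading 1s so both inputs of a broadcast op share the same rank; broadcasting semantics are unchanged, but the layout-transform pass then sees rank-matched tensors. A standalone sketch of the same rule (the expand_rank helper is illustrative, not part of the PR):

import torch

def expand_rank(shape, target_rank):
    # Pad with leading 1s, mirroring the pass's new_rank computation.
    return [1] * (target_rank - len(shape)) + list(shape)

a = torch.randn(8, 8)        # rank 2
b = torch.randn(1, 3, 8, 8)  # rank 4

# Explicitly reshaping `a` to rank 4 keeps the broadcast result identical.
a4 = a.reshape(expand_rank(a.shape, b.dim()))  # shape (1, 1, 8, 8)
assert torch.equal(a + b, a4 + b)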

backends/qualcomm/quantizer/quantizer.py

Lines changed: 12 additions & 7 deletions
@@ -26,7 +26,7 @@
     get_16a8w_qnn_ptq_config,
     get_default_16bit_qnn_ptq_config,
     get_default_8bit_qnn_ptq_config,
-    get_ptq_per_channel_weight_config,
+    get_ptq_per_channel_quant_config,
     OP_ANNOTATOR,
     QuantizationConfig,
 )
@@ -72,6 +72,7 @@ def __init__(self):
             "8bit_act": torch.int8,
             "16bit_act": torch.int16,
         }
+        self.per_channel_quant_config = None
 
     def _annotate(self, gm: GraphModule) -> None:
         for node in gm.graph.nodes:
@@ -96,13 +97,17 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
             return
 
         if op in self.use_per_channel_weight_quant_ops:
-            if op in self.bit16_quant_ops:
-                return get_ptq_per_channel_weight_config(
-                    torch.uint16, self.per_channel_weight_dtype["16bit_act"]
+            if self.per_channel_quant_config is None:
+                if op in self.bit16_quant_ops:
+                    return get_ptq_per_channel_quant_config(
+                        act_dtype=torch.uint16,
+                        weight_dtype=self.per_channel_weight_dtype["16bit_act"],
+                    )
+                return get_ptq_per_channel_quant_config(
+                    act_dtype=torch.uint8,
+                    weight_dtype=self.per_channel_weight_dtype["8bit_act"],
                 )
-            return get_ptq_per_channel_weight_config(
-                weight_dtype=self.per_channel_weight_dtype["8bit_act"]
-            )
+            return self.per_channel_quant_config
 
         if op in self.bit8_quant_ops:
             return self.bit8_quant_config
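
As a refresher on what the per-channel config buys: per-channel weight quantization keeps one scale per output channel instead of one per tensor, which preserves accuracy when channel magnitudes differ widely. A minimal sketch with plain PyTorch (ranges and shapes are illustrative):

import torch

w = torch.randn(16, 8)  # (out_channels, in_channels)

# One symmetric int8 scale per output channel (ch_axis=0).
w_min, w_max = torch.aminmax(w, dim=1)
scale = torch.max(w_min.abs(), w_max.abs()) / 127.0
zero_point = torch.zeros_like(scale, dtype=torch.int32)

w_q = torch.fake_quantize_per_channel_affine(
    w, scale, zero_point, axis=0, quant_min=-127, quant_max=127
)
assert w_q.shape == w.shape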

backends/qualcomm/quantizer/utils.py

Lines changed: 104 additions & 2 deletions
@@ -20,6 +20,7 @@
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     PerChannelMinMaxObserver,
+    UniformQuantizationObserverBase,
 )
 
 from torch.ao.quantization.quantizer import (
@@ -35,6 +36,107 @@
 from torch.fx import Node
 
 
+class ParamObserver(UniformQuantizationObserverBase):
+    def __init__(
+        self,
+        ch_axis=0,
+        use_mse=True,
+        steps=100,
+        dtype=torch.int8,
+        qscheme=torch.per_channel_symmetric,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+        eps=torch.finfo(torch.float32).eps,  # noqa: B008
+        is_dynamic=False,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            dtype=dtype,
+            qscheme=qscheme,
+            reduce_range=reduce_range,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            factory_kwargs=factory_kwargs,
+            eps=eps,
+            is_dynamic=is_dynamic,
+            **kwargs,
+        )
+
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
+        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
+        self.ch_axis = ch_axis
+        self.use_mse = use_mse
+        self.steps = steps
+        self.calibrated = False
+
+    def to_ch_axis(self, x):
+        axis_order = list(range(len(x.size())))
+        axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
+        return torch.flatten(x.permute(axis_order), start_dim=1)
+
+    def mse(self, pred, expect):
+        loss = (pred - expect).abs().pow(2)
+        return self.to_ch_axis(loss).mean(1)
+
+    def cosine(self, pred, expect):
+        target = torch.ones(pred.shape[self.ch_axis])
+        pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1)
+        expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1)
+        return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)
+
+    def loss_fn(self, x, new_min, new_max):
+        scale, offset = self._calculate_qparams(new_min, new_max)
+        x_q = torch.fake_quantize_per_channel_affine(
+            x,
+            scale.data,
+            offset.data.int(),
+            self.ch_axis,
+            self.quant_min,
+            self.quant_max,
+        )
+        return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)
+
+    def line_search(self, x):
+        x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
+        x_range = torch.max(x_min.abs(), x_max)
+        optimal_loss = torch.zeros_like(x_min) + 1e9
+
+        # check which clip range could produce smallest loss
+        for i in range(1, self.steps + 1):
+            thres = x_range / self.steps * i
+            current_loss = self.loss_fn(x, -thres, thres)
+            x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
+            x_max = torch.where(current_loss < optimal_loss, thres, x_max)
+            optimal_loss = torch.min(current_loss, optimal_loss)
+
+        return x_min, x_max
+
+    def forward(self, x_orig):
+        # since params are static, one calibration is enough
+        if not self.calibrated:
+            x = x_orig.detach().to(self.min_val.dtype)
+            self.min_val, self.max_val = self.line_search(x)
+            self.calibrated = True
+
+        # return fake-quant result for saturating outliers
+        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
+        return torch.fake_quantize_per_channel_affine(
+            x_orig,
+            scale.data,
+            zero_point.data.int(),
+            self.ch_axis,
+            self.quant_min,
+            self.quant_max,
+        )
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self._calculate_qparams(self.min_val, self.max_val)
+
+
 @dataclass(eq=True, frozen=True)
 class QuantizationConfig:
     input_activation: Optional[QuantizationSpec]
@@ -235,7 +337,7 @@ def get_default_16bit_qnn_ptq_config(
     return quantization_config
 
 
-def get_ptq_per_channel_weight_config(
+def get_ptq_per_channel_quant_config(
     act_dtype=torch.uint8, weight_dtype=torch.int8
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
@@ -585,7 +687,7 @@ def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.view.default])
+@register_annotator([torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default])
 def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
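
The core of ParamObserver is the clip-range line search: rather than taking the raw per-channel min/max, which a few outliers can stretch into a huge quantization step, it sweeps symmetric thresholds from small to full range and keeps whichever minimizes reconstruction error after fake quantization. A toy standalone version of the same idea (the function name and step count are illustrative, not the PR's API):

import torch

def search_symmetric_clip(w, steps=100, qmin=-128, qmax=127):
    # Per-channel line search for the symmetric clip range with lowest MSE.
    w2d = w.flatten(start_dim=1)          # (channels, elements)
    full = w2d.abs().amax(dim=1)          # outlier-driven full range
    best_thres = full.clone()
    best_loss = torch.full_like(full, float("inf"))
    for i in range(1, steps + 1):
        thres = full / steps * i
        scale = thres / qmax
        # Fake-quantize at this candidate range, then measure per-channel MSE.
        w_q = (w2d / scale[:, None]).round().clamp(qmin, qmax) * scale[:, None]
        loss = (w_q - w2d).pow(2).mean(dim=1)
        better = loss < best_loss
        best_thres = torch.where(better, thres, best_thres)
        best_loss = torch.min(loss, best_loss)
    return -best_thres, best_thres  # per-channel (min_val, max_val)

# A weight with one extreme outlier: the search clips it instead of
# letting it inflate the quantization step for the whole channel.
w = torch.randn(4, 64)
w[0, 0] = 50.0
min_val, max_val = search_symmetric_clip(w)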
