
Commit 21baa73

tmp
1 parent 57bda67 commit 21baa73

17 files changed: 151 additions & 72 deletions

backends/qualcomm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -323,6 +323,7 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
     pybind11::module
     pybind11::lto
     wrappers
+    qnn_schema
     qnn_executorch_logging
     qnn_executorch_header
   )

backends/qualcomm/builders/node_visitor.py

Lines changed: 13 additions & 14 deletions
@@ -26,7 +26,7 @@
     # Note that there is no int64 tensor data type in Qnn.
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
@@ -35,7 +35,7 @@
     torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
     torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
     torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
-    QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
+    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
     float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
 }

@@ -164,31 +164,26 @@ def get_quant_encoding_conf(
             else node.meta["quant_attrs"]
         )
         if quant_attrs["encoding"] in PER_CHANNEL_ENCODING:
-            print(f"[Hutton define_tensor] {node.name} {quant_attrs['scales']}, {-quant_attrs['zero_points']}")
             return self.make_qnn_per_channel_config(node, quant_attrs)
-        print(f"[Hutton define_tensor] {node.name} {quant_attrs['scale']}, {-quant_attrs['zero_point']}")
         return self.make_qnn_per_tensor_config(quant_attrs)

     def get_quant_tensor_value(
-        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth
+        self, tensor: torch.Tensor, quant_attrs: Dict, dtype, quant_configs
     ) -> torch.Tensor:
         if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
             scale = quant_attrs["scale"]
             zero_point = quant_attrs["zero_point"]
+
         else: # per channel case
             scale = quant_attrs["scales"]
             zero_point = quant_attrs["zero_points"]

         # To bypass torch.uint16 quantization is not supported
-        dtype = (
-            torch.int32
-            if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16
-            else quant_attrs["dtype"]
-        )
-
+        dtype = quant_configs["dtype"]
+        print(f"[Hutton get_quant_tensor_value] tensor {tensor}")
         tensor = tensor.div(scale).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
-        if bitwidth == 4:
+        if quant_configs.get("bitwidth") == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
             tensor = torch.bitwise_and(mask, tensor)
         return tensor
@@ -237,11 +232,12 @@ def get_data_type(
             <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
         ):
             if unsigned:
-                quant_config["dtype"] = QNN_uint16
+                quant_config["dtype"] = torch.uint16
             else:
                 quant_config["dtype"] = torch.int16
             return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
         else:
+            print(f"[Hutton] QQ {tensor}")
             return QNN_TENSOR_TYPE_MAP[tensor.dtype]

     def define_custom_tensor_wrapper(
@@ -312,6 +308,7 @@ def define_tensor(
         )
         dtype = self.get_data_type(tensor, quant_configs, is_tensor)
         if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
+            print(f"[Hutton fake_tensor] node {node} dtype {dtype} quant_configs[dtype] {quant_configs.get('dtype')}")
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,
                 tensor_type,
@@ -329,8 +326,10 @@ def define_tensor(
                 tensor,
                 node.meta["quant_attrs"],
                 dtype,
-                quant_configs.get("bitwidth"),
+                quant_configs,
             )
+            print(f"[Hutton scalar] node {node}: dtype {dtype} quant_configs[dtype] {quant_configs['dtype']}")
+
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,
                 tensor_type,

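Note (not part of the diff): a minimal sketch of the div/add/round/to chain that get_quant_tensor_value applies, using made-up per-tensor scale/zero_point values and int32 as the storage dtype standing in for whatever quant_configs["dtype"] resolves to:

import torch

# Hypothetical per-tensor quantization parameters mirroring quant_attrs.
scale, zero_point = 0.05, 32768
storage_dtype = torch.int32  # placeholder for quant_configs["dtype"]

fp_tensor = torch.tensor([-1.0, 0.0, 0.5, 1.0])

# q = round(x / scale + zero_point), then cast to the storage dtype.
q_tensor = fp_tensor.div(scale).add(zero_point).round().to(storage_dtype)
print(q_tensor)  # tensor([32748, 32768, 32778, 32788], dtype=torch.int32)
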
backends/qualcomm/builders/op_mean_dim.py

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ def define_node(
     ) -> PyQnnWrapper.PyQnnOpWrapper:

         input_node = node.args[0]
-        print(f"[Hutton] {node.name} {node.meta}")
         input_tensor = self.get_tensor(input_node, node)
         input_tensor_wrapper = self.define_tensor(
             input_node,

backends/qualcomm/builders/op_mul.py

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
-        print(f"[Hutton] {node.name} {node.meta}")
         out_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,

backends/qualcomm/partition/common_defs.py

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,9 @@
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.index_put.default,
+    exir_ops.edge.aten.embedding.default,
+    # exir_ops.edge.aten.addmm.default,
+    # exir_ops.edge.aten.mm.default,
     # exir_ops.edge.aten.mul.Tensor,
     # exir_ops.edge.aten.sub.Tensor,
     # exir_ops.edge.aten.add.Tensor,

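Note (not part of the diff): a hypothetical sketch of how a denylist such as not_supported_operator is consulted during partitioning; the helper name is illustrative, not an API from this repo. Adding aten.embedding.default here keeps embedding nodes out of QNN partitions so they fall back to the CPU path:

def rejected_by_denylist(node, not_supported_operator) -> bool:
    # e.g. exir_ops.edge.aten.embedding.default is now rejected
    return node.target in not_supported_operator
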
backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 11 additions & 13 deletions
@@ -6,7 +6,6 @@
 import copy
 from typing import Any, Dict, List

-from executorch.examples.models.llama2.llama_transformer import RMSNorm
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
 import torch
 from executorch.backends.qualcomm.builders import node_visitor
@@ -28,7 +27,7 @@

 from .common_defs import allow_list_operator, not_supported_operator

-
+test = 0
 class QnnOperatorSupport(OperatorSupportBase):
     def __init__(
         self,
@@ -54,8 +53,8 @@ def __init__(
         self.qnn_manager = PyQnnManager.QnnManager(
             generate_qnn_executorch_option(compiler_specs)
         )
-        self.discard_modules = set([RMSNorm])
-
+        # from executorch.examples.models.llama2.llama_transformer import RMSNorm
+        self.discard_modules = ["executorch.examples.models.llama2.llama_transformer.RMSNorm"] #["executorch.examples.models.llama2.llama_transformer.RMSNorm"] # []
         self.qnn_manager.Init()

     def is_node_supported(self, _, node: torch.fx.Node) -> bool:
@@ -64,15 +63,14 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:

         if node.target in allow_list_operator:
             return True
-        # if "nn_module_stack" in node.meta:
-        #     module_values_list = list(node.meta["nn_module_stack"].values())
-        #     owning_module = module_values_list[-1][1]
-        #     if owning_module in self.discard_modules:
-        #         print(f"[QNN Partitioner Op Support]: {node.name} | Skipped since RMS norm")
-        #         return False
-        # if "quant_attrs" in node.meta and node.meta['quant_attrs']['scale'] > 1:
-        #     print(f"[QNN Partitioner Op Support]: {node.name} | Skipped since scale is greater than 1")
-        #     return False
+        global test
+        test += 1
+        if "nn_module_stack" in node.meta and test < 6:
+            module_values_list = list(node.meta["nn_module_stack"].values())
+            owning_module = module_values_list[-1][1]
+            if owning_module in self.discard_modules:
+                print(f"[QNN Partitioner Op Support]: {node.name} | Skipped since RMS norm")
+                return False
         if self.skip_node_id_set is not None and node.name in self.skip_node_id_set:
             print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
             return False

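Note (not part of the diff): a sketch of how the nn_module_stack lookup in is_node_supported resolves the owning module; the helper name is illustrative, and whether the stored entry is a class object or a dotted string like "executorch.examples.models.llama2.llama_transformer.RMSNorm" depends on the torch/export version:

def owning_module_of(node):
    # nn_module_stack maps a module path to a (qualified_name, module_type) pair;
    # the last entry is the innermost module that produced this node.
    stack = node.meta.get("nn_module_stack", {})
    if not stack:
        return None
    return list(stack.values())[-1][1]
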
backends/qualcomm/qnn_preprocess.py

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ def preprocess(

         pass_result = qnn_compiler_passes(edge_program.graph_module)
         assert pass_result is not None
-
+        # from executorch.backends.qualcomm.utils.utils import draw_graph
+        # draw_graph("qnn_preprocess", ".", pass_result.graph_module)
         enable_tensor_dump = qnn_manager.IsTensorDump()
         nodes_to_wrappers = {}
         node_visitors = get_node_visitors(

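Note (not part of the diff): if the commented-out draw_graph hook is not available, torch.fx can dump the same graph as a table; a small sketch, assuming the tabulate package is installed:

import torch.fx

def dump_graph(graph_module: torch.fx.GraphModule) -> None:
    # Prints opcode, name, target, args and kwargs for every node.
    graph_module.graph.print_tabular()
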
backends/qualcomm/quantizer/quantizer.py

Lines changed: 12 additions & 7 deletions
@@ -20,6 +20,7 @@
 from torch._ops import OpOverload
 from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

 from .utils import (
     get_16a4w_qnn_ptq_config,
@@ -75,12 +76,14 @@ def _annotate(self, gm: GraphModule) -> None:
         self.discard_modules = set([RMSNorm])
         for node in gm.graph.nodes:
             if node.name in self.discard_nodes:
+                print(f"[Hutton quantizer.py] discard nodes {node.name}")
                 continue
-            if "nn_module_stack" in node.meta:
-                module_values_list = list(node.meta["nn_module_stack"].values())
-                owning_module = module_values_list[-1][1]
-                if owning_module in self.discard_modules:
-                    continue
+            # if "nn_module_stack" in node.meta:
+            #     module_values_list = list(node.meta["nn_module_stack"].values())
+            #     owning_module = module_values_list[-1][1]
+            #     if owning_module in self.discard_modules:
+            #         print(f"[Hutton quantizer.py] discard modules {node.name}")
+            #         continue
             quant_config = self._get_quant_config(node.target)
             if quant_config:
                 OP_ANNOTATOR[node.target](node, quant_config)
@@ -207,8 +210,10 @@ def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:

             if non_const_arg is None or const_arg is None:
                 continue
-
-            tensor_constant = torch.tensor([const_arg])
+            if type(const_arg) is int:
+                tensor_constant = torch.tensor([const_arg], dtype=torch.int32)
+            else:
+                tensor_constant = torch.tensor([const_arg])
             tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
                 gm
             )

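Note (not part of the diff): why the new int branch in _lift_constant_scalar_operands matters — a bare Python int becomes an int64 tensor by default, and QNN has no int64 tensor data type (see the node_visitor comment above), so integer scalars are lifted as int32:

import torch

const_arg = 3
print(torch.tensor([const_arg]).dtype)                     # torch.int64 (default)
print(torch.tensor([const_arg], dtype=torch.int32).dtype)  # torch.int32
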
backends/qualcomm/quantizer/utils.py

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ def get_16a4w_qnn_ptq_config() -> QuantizationConfig:
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )

     weight_quantization_spec = QuantizationSpec(
@@ -157,7 +157,7 @@ def get_default_16bit_qnn_ptq_config() -> QuantizationConfig:
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )

     weight_quantization_spec = QuantizationSpec(

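Note (not part of the diff): the swap from MovingAverageMinMaxObserver to MinMaxObserver changes how activation ranges are calibrated. A small comparison sketch with made-up calibration batches — MinMaxObserver keeps the absolute min/max seen, while the moving-average variant smooths them and can under-cover outliers:

import torch
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver

x1 = torch.tensor([0.0, 1.0])
x2 = torch.tensor([0.0, 10.0])

minmax = MinMaxObserver()
moving = MovingAverageMinMaxObserver(averaging_constant=0.5)
for x in (x1, x2):
    minmax(x)
    moving(x)

print(minmax.max_val)  # tensor(10.) -- full observed range
print(moving.max_val)  # tensor(5.5) -- averaged, smaller than the true max
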
backends/qualcomm/runtime/backends/QnnBackendCache.cpp

Lines changed: 3 additions & 2 deletions
@@ -87,7 +87,8 @@ QnnBackendCache::QnnBackendCache(
     state_ = SERIALIZE;
     QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in SAVE MODE.");
     return;
-  } else {
+  }
+  /*else {
     // TODO: need fix on this since qnn context binary could somehow
     // pass the check of flatbuffer verifier
     // check if context binary came from flatbuffer
@@ -100,7 +101,7 @@ QnnBackendCache::QnnBackendCache(
       state_ = ONLINE_PREPARE;
       return;
     }
-  }
+  }*/

   if (qnn_sys_impl_.Load() != Error::Ok) {
     QNN_EXECUTORCH_LOG_ERROR(

backends/qualcomm/tests/models.py

Lines changed: 26 additions & 1 deletion
@@ -461,7 +461,7 @@ def __init__(self):
         super().__init__()

     def forward(self, x, y):
-        return torch.mul(x, y)
+        return x*y


 class MulConstantFloat(torch.nn.Module):
@@ -490,6 +490,16 @@ def forward(self, x):
         return out1


+class MulQQ(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(2))
+
+    def forward(self, x):
+        output = x
+        return output * self.weight
+
+
 class MultiheadAttention(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -588,6 +598,21 @@ def forward(self, x):
         return torch.rsqrt(x)


+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    # def forward(self, x):
+    #     return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+    def _norm(self, x):
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
 class ScaledDotProductAttention(torch.nn.Module):
     def __init__(self):
         super().__init__()

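Note (not part of the diff): a quick usage sketch for the RMSNorm test module added above, assuming it is importable from backends/qualcomm/tests/models.py; with the default all-ones weight the output equals plain RMS normalization over the last dimension:

import torch
from executorch.backends.qualcomm.tests.models import RMSNorm  # module added above

dim = 4
x = torch.randn(2, dim)
rms_norm = RMSNorm(dim)  # weight starts as torch.ones(dim)
y = rms_norm(x)

# Reference computation matching _norm() with the default eps.
expected = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(y, expected)
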