
Commit 57bda67

skip rms norm
1 parent 5e0e947 commit 57bda67

9 files changed: +104 −7 lines

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 1 deletion

@@ -164,8 +164,9 @@ def get_quant_encoding_conf(
             else node.meta["quant_attrs"]
         )
         if quant_attrs["encoding"] in PER_CHANNEL_ENCODING:
+            print(f"[Hutton define_tensor] {node.name} {quant_attrs['scales']}, {-quant_attrs['zero_points']}")
             return self.make_qnn_per_channel_config(node, quant_attrs)
-
+        print(f"[Hutton define_tensor] {node.name} {quant_attrs['scale']}, {-quant_attrs['zero_point']}")
         return self.make_qnn_per_tensor_config(quant_attrs)
 
     def get_quant_tensor_value(
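
The two debug prints above dump the quantization parameters attached to a node before a per-channel or per-tensor QNN config is built. As a hedged illustration only (the field names come from the prints; the encoding labels and numbers below are invented), node.meta["quant_attrs"] carries one scale/zero_point pair in the per-tensor case and per-channel lists otherwise:

# Hedged illustration, not taken from the repo: field names from the prints above,
# values made up.
per_tensor_attrs = {
    "encoding": "per-tensor affine",   # not a member of PER_CHANNEL_ENCODING
    "scale": 0.05,                     # one scale for the whole tensor
    "zero_point": 128,                 # single offset (printed negated above)
}
per_channel_attrs = {
    "encoding": "per-channel affine",  # a member of PER_CHANNEL_ENCODING
    "scales": [0.04, 0.06, 0.05],      # one scale per output channel
    "zero_points": [0, 0, 0],          # one zero point per output channel
}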

backends/qualcomm/builders/op_mean_dim.py

Lines changed: 2 additions & 0 deletions

@@ -27,7 +27,9 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+
         input_node = node.args[0]
+        print(f"[Hutton] {node.name} {node.meta}")
         input_tensor = self.get_tensor(input_node, node)
         input_tensor_wrapper = self.define_tensor(
             input_node,

backends/qualcomm/builders/op_mul.py

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+        print(f"[Hutton] {node.name} {node.meta}")
         out_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
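
Both builder changes (op_mean_dim.py above and op_mul.py here) only add temporary prints of node.meta while debugging which mean/mul nodes come out of the RMSNorm decomposition. A hedged, self-contained way to poke at the same metadata outside the builders, using a toy module (the Toy class and inputs below are illustrative, not from the repo):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return (x * 2.0).mean(dim=-1)

ep = torch.export.export(Toy(), (torch.randn(2, 3),))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function":
        # node.meta carries, among other things, the fake tensor ("val") and the
        # owning-module stack ("nn_module_stack") that other parts of this commit rely on.
        print(node.name, node.target, list(node.meta.keys()))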

backends/qualcomm/partition/common_defs.py

Lines changed: 6 additions & 0 deletions

@@ -14,6 +14,12 @@
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.index_put.default,
+    # exir_ops.edge.aten.mul.Tensor,
+    # exir_ops.edge.aten.sub.Tensor,
+    # exir_ops.edge.aten.add.Tensor,
+    # exir_ops.edge.aten.rsqrt.default,
+    # exir_ops.edge.aten.matmul.default,
+    # exir_ops.edge.aten.unsqueeze_copy.default,
 ]
 
 allow_list_operator = [
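
Several of the commented-out entries (mul, add, rsqrt) are exactly the elementwise ops an RMSNorm lowers to; enabling them in not_supported_operator would push those ops back to CPU for the whole model, which is much coarser than the module-scoped skip used elsewhere in this commit. For reference, a minimal eager-mode RMSNorm (a sketch, not the llama_transformer implementation) whose decomposition produces those ops:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # mean of squares -> rsqrt -> mul, plus the learned elementwise weight
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight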

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 11 additions & 1 deletion

@@ -6,6 +6,7 @@
 import copy
 from typing import Any, Dict, List
 
+from executorch.examples.models.llama2.llama_transformer import RMSNorm
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
 import torch
 from executorch.backends.qualcomm.builders import node_visitor
@@ -53,6 +54,7 @@ def __init__(
         self.qnn_manager = PyQnnManager.QnnManager(
             generate_qnn_executorch_option(compiler_specs)
         )
+        self.discard_modules = set([RMSNorm])
 
         self.qnn_manager.Init()
 
@@ -62,7 +64,15 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
 
         if node.target in allow_list_operator:
             return True
-
+        # if "nn_module_stack" in node.meta:
+        #     module_values_list = list(node.meta["nn_module_stack"].values())
+        #     owning_module = module_values_list[-1][1]
+        #     if owning_module in self.discard_modules:
+        #         print(f"[QNN Partitioner Op Support]: {node.name} | Skipped since RMS norm")
+        #         return False
+        # if "quant_attrs" in node.meta and node.meta['quant_attrs']['scale'] > 1:
+        #     print(f"[QNN Partitioner Op Support]: {node.name} | Skipped since scale is greater than 1")
+        #     return False
         if self.skip_node_id_set is not None and node.name in self.skip_node_id_set:
             print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
             return False
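
The commented-out block relies on the nn_module_stack metadata that export attaches to each node: its values are (qualified name, module class) pairs, and the last entry is the innermost module that produced the node. A hedged helper distilled from that lookup (the function name is mine; the indexing pattern is the one in the diff):

from typing import Optional, Type

from torch import fx, nn

def owning_module_type(node: fx.Node) -> Optional[Type[nn.Module]]:
    """Return the innermost module class recorded in node.meta['nn_module_stack']."""
    stack = node.meta.get("nn_module_stack")
    if not stack:
        return None
    # Values are (qualified_name, module_class) pairs, ordered outermost -> innermost.
    return list(stack.values())[-1][1]

# Mirroring the commented-out check in is_node_supported:
#   if owning_module_type(node) in {RMSNorm}:
#       return False  # keep RMSNorm ops out of the QNN partition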

backends/qualcomm/quantizer/quantizer.py

Lines changed: 53 additions & 3 deletions

@@ -6,6 +6,7 @@
 from enum import IntEnum, unique
 from typing import Callable, Dict, Optional, Sequence, Set
 
+from executorch.examples.models.llama2.llama_transformer import RMSNorm
 import torch
 from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
 from executorch.backends.qualcomm.passes.decompose_scaled_dot_product_attention import (
@@ -62,7 +63,7 @@ def __init__(self):
 
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
-
+        self.discard_modules: Set[torch.nn.Module] = set()
         self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
         # the weight quantized for activation 8 bits and 16 bits
         self.per_channel_weight_dtype: Dict = {
@@ -71,10 +72,15 @@ def __init__(self):
         }
 
     def _annotate(self, gm: GraphModule) -> None:
+        self.discard_modules = set([RMSNorm])
         for node in gm.graph.nodes:
             if node.name in self.discard_nodes:
                 continue
-
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                owning_module = module_values_list[-1][1]
+                if owning_module in self.discard_modules:
+                    continue
             quant_config = self._get_quant_config(node.target)
             if quant_config:
                 OP_ANNOTATOR[node.target](node, quant_config)
@@ -176,6 +182,50 @@ def set_per_channel_linear_quant(self, enable: bool) -> None:
             torch.ops.aten.linear.default,
         }
         self._update_per_channel_weight_quant_ops(linear_ops, enable)
+
+    def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:
+        """
+        For a case like mul(x, 2), convert the scalar to a tensor.
+        """
+        for n in gm.graph.nodes:
+            if n.op != "call_function" or n.target not in (
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.mul.Tensor,
+                torch.ops.aten.mul.Scalar,
+                torch.ops.aten.rsub.Scalar,
+            ):
+                continue
+
+            const_arg = None
+            non_const_arg = None
+            for arg in n.args:
+                if isinstance(arg, torch.fx.Node):
+                    non_const_arg = arg
+                else:
+                    const_arg = arg
+
+            if non_const_arg is None or const_arg is None:
+                continue
+
+            tensor_constant = torch.tensor([const_arg])
+            tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
+                gm
+            )
+            gm.register_buffer(tensor_constant_name, tensor_constant)
+
+            fake_mode = n.meta["val"].fake_mode
+            with gm.graph.inserting_before(n):
+                get_attr_node = gm.graph.get_attr(tensor_constant_name)
+                get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)
+
+            if n.target == torch.ops.aten.rsub.Scalar:
+                n.args = (get_attr_node, non_const_arg) + n.args[2:]
+                n.target = torch.ops.aten.sub.Tensor
+            else:
+                n.args = (non_const_arg, get_attr_node) + n.args[2:]
+
+        gm.recompile()
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
@@ -184,7 +234,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = DecomposeScaledDotProductAttention()(model).graph_module
         model = DecomposeSilu()(model).graph_module
         model = ReplaceInfBuffer()(model).graph_module
-
+        # self._lift_constant_scalar_operands(model)
        return model
 
     def validate(self, model: GraphModule) -> None:
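
The new _lift_constant_scalar_operands pass (still unused here; its call in transform_for_annotation is commented out) rewrites binary ops whose second operand is a Python scalar so that both operands are graph tensors and can be observed and quantized. A hedged before/after sketch of its effect, with illustrative node and buffer names:

# Before the pass:
#     mul = torch.ops.aten.mul.Tensor(x, 2)       # literal scalar operand
#     r   = torch.ops.aten.rsub.Scalar(x, 1.0)    # rsub(x, c) == c - x
#
# After the pass (buffer names follow the "_tensor_constant_" prefix above):
#     _c0 = get_attr("_tensor_constant_0")        # buffer holding tensor([2])
#     mul = torch.ops.aten.mul.Tensor(x, _c0)     # tensor-only operands
#     _c1 = get_attr("_tensor_constant_1")        # buffer holding tensor([1.0])
#     r   = torch.ops.aten.sub.Tensor(_c1, x)     # rsub rewritten to sub with swapped args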

backends/qualcomm/utils/utils.py

Lines changed: 1 addition & 0 deletions

@@ -183,6 +183,7 @@ def generate_htp_compiler_spec(
     htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst
     htp_options.use_multi_contexts = use_multi_contexts
     htp_options.use_dlbc = use_dlbc
+    # htp_options.use_conv_hmx = False
     return QnnExecuTorchBackendOptions(
         backend_type=QnnExecuTorchBackendType.kHtpBackend,
         htp_options=htp_options,

examples/qualcomm/llama2/llama.py

Lines changed: 21 additions & 1 deletion

@@ -24,6 +24,26 @@
 
 from sentencepiece import SentencePieceProcessor
 
+def annotate_rms_norm(gm: torch.fx.GraphModule) -> None:
+    from executorch.backends.qualcomm.quantizer.quantizer import (
+        get_default_16bit_qnn_ptq_config,
+    )
+    from executorch.backends.qualcomm.quantizer.utils import (
+        OP_ANNOTATOR,
+    )
+    from executorch.examples.models.llama2.llama_transformer import RMSNorm
+
+    quantization_config = get_default_16bit_qnn_ptq_config()
+    SUPPORTED_OPS = set(OP_ANNOTATOR.keys())
+    for node in gm.graph.nodes:
+        if "nn_module_stack" in node.meta:
+            module_values_list = list(node.meta["nn_module_stack"].values())
+            owning_module = module_values_list[-1][1]
+            if owning_module in [RMSNorm]:
+                if node.target in SUPPORTED_OPS:
+                    print(f"[16 bits quant] {node.name}")
+                    OP_ANNOTATOR[node.target](node, quantization_config)
+
 
 def create_device_inputs(example_inputs):
     # TODO: support batch inputs if necessary
@@ -194,7 +214,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
         args.model,
         f"{args.artifact}/{pte_filename}",
         partial(calibrate, inputs),
-        custom_annotations=(),
+        custom_annotations=(annotate_rms_norm,),
         quant_dtype=quant_dtype,
         per_channel_linear=per_channel_linear,
         shared_buffer=args.shared_buffer,
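
custom_annotations is a tuple of callables that build_executorch_binary applies to the captured GraphModule before calibration; annotate_rms_norm uses it to re-annotate the RMSNorm-owned nodes (which the quantizer now skips) with a 16-bit config. A hedged sketch of another callback in the same style, assuming the same module-level torch import llama.py already has; the function name and the choice of matmul as target are mine, not part of the commit:

def annotate_matmul_16bit(gm: torch.fx.GraphModule) -> None:
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_default_16bit_qnn_ptq_config,
    )
    from executorch.backends.qualcomm.quantizer.utils import OP_ANNOTATOR

    quantization_config = get_default_16bit_qnn_ptq_config()
    for node in gm.graph.nodes:
        # re-annotate every matmul the annotator knows about with the 16-bit config
        if node.target == torch.ops.aten.matmul.default and node.target in OP_ANNOTATOR:
            OP_ANNOTATOR[node.target](node, quantization_config)

# Passed the same way as annotate_rms_norm:
#   custom_annotations=(annotate_rms_norm, annotate_matmul_16bit),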

examples/qualcomm/scripts/utils.py

Lines changed: 7 additions & 1 deletion

@@ -196,6 +196,8 @@ def build_executorch_binary(
            raise AssertionError(f"No support for QuantDtype {quant_dtype}.")
 
        captured_model = torch._export.capture_pre_autograd_graph(model, inputs)
+        # from executorch.backends.qualcomm.utils.utils import draw_graph
+        # draw_graph("before_quantized", ".", captured_model)
        annotated_model = prepare_pt2e(captured_model, quantizer)
        print("Quantizing the model...")
        # calibration
@@ -205,7 +207,8 @@
        for data in dataset:
            annotated_model(*data)
        quantized_model = convert_pt2e(annotated_model)
-
+        # from executorch.backends.qualcomm.utils.utils import draw_graph
+        # draw_graph("after_quantized", ".", quantized_model)
        edge_prog = capture_program(quantized_model, inputs)
    else:
        edge_prog = capture_program(model, inputs)
@@ -261,6 +264,9 @@ def build_executorch_binary(
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner)
+    # from executorch.backends.qualcomm.utils.utils import draw_graph
+    # draw_graph("after_pte_default_8a8w_rmsnorm_16a16w", ".", edge_prog_mgr.exported_program().graph_module)
+
    exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
    with open(f"{file_name}.pte", "wb") as file:
        file.write(exec_prog_mgr.buffer)
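
The commented-out draw_graph calls are debugging hooks for dumping the graph before quantization, after convert_pt2e, and after QNN partitioning. For orientation, the quantization path they bracket is the standard PT2E flow; a minimal, self-contained sketch of that flow under the same API assumptions as the code above (the wrapper function name is mine):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_pt2e(model, example_inputs, quantizer, calibration_data):
    # capture -> annotate/observe -> calibrate -> fold observers into q/dq ops
    captured = torch._export.capture_pre_autograd_graph(model, example_inputs)
    annotated = prepare_pt2e(captured, quantizer)
    for data in calibration_data:
        annotated(*data)
    return convert_pt2e(annotated)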
