Skip to content

Commit f1ef702

Browse files
authored
Qualcomm AI Engine Direct - Flags for CI (#9536)
### Summary Previously, `--compile_only` used a single random input, a behavior intended only for CI purposes. This commit introduces a separate `--ci` flag for that behavior, so users who pass `--compile_only` no longer end up with models that have poor accuracy.
1 parent ef30b25 commit f1ef702

File tree

11 files changed

+63
-40
lines changed

11 files changed

+63
-40
lines changed

.ci/scripts/test_model.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ test_model_with_qnn() {
201201
# TODO(guangyang): Make QNN chipset matches the target device
202202
QNN_CHIPSET=SM8450
203203

204-
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only $EXTRA_FLAGS
204+
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
205205
EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
206206
}
207207

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,12 @@
3131
class AnnotateQuantAttrs(ExportPass):
3232
"""
3333
Add "quant_attrs" to graph nodes' meta from the QDQ information
34-
generated after quatization process.
34+
generated after quantization process.
3535
"""
3636

37-
def __init__(
38-
self, edge_program: torch.export.ExportedProgram, skip_advanced_requat: bool
39-
):
37+
def __init__(self, edge_program: torch.export.ExportedProgram):
4038
super(AnnotateQuantAttrs, self).__init__()
4139
self.edge_program = edge_program
42-
self.skip_advanced_requant = skip_advanced_requat
4340

4441
def _annotate_source_nodes(
4542
self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
@@ -88,30 +85,21 @@ def _annotate_requant(self, n):
8885
dq_attrs = get_quant_attrs(self.edge_program, dq_node)
8986
# TODO: Store multiple pairs of requantize attributes when we have an op builder
9087
# that has multiple outputs that requires quant attributes.
91-
if self.skip_advanced_requant:
92-
if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
93-
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
94-
user_node = list(dq_node.users)[0]
95-
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
96-
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
97-
else:
98-
# When dtype is the same but other specs such as scale and offset are different,
99-
# insert requant to improve accuracy.
100-
# Users can turn this feature off if any inference speed drop is observed.
101-
if any(
102-
q_attrs[attr] != dq_attrs[attr]
103-
for attr in [
104-
QCOM_SCALE,
105-
QCOM_ZERO_POINT,
106-
QCOM_QUANT_MIN,
107-
QCOM_QUANT_MAX,
108-
QCOM_DTYPE,
109-
]
110-
):
111-
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
112-
user_node = list(dq_node.users)[0]
113-
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
114-
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
88+
89+
if any(
90+
q_attrs[attr] != dq_attrs[attr]
91+
for attr in [
92+
QCOM_SCALE,
93+
QCOM_ZERO_POINT,
94+
QCOM_QUANT_MIN,
95+
QCOM_QUANT_MAX,
96+
QCOM_DTYPE,
97+
]
98+
):
99+
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
100+
user_node = list(dq_node.users)[0]
101+
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
102+
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
115103

116104
# Dequant all the fold_quant parameters back to fp32.
117105
# If an operation is not supported by QNN and got fallback, it will expect a fp32 param.

examples/qualcomm/scripts/deeplab_v3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
import random
1011
import re
@@ -74,8 +75,11 @@ def main(args):
7475
)
7576

7677
data_num = 100
77-
if args.compile_only:
78+
if args.ci:
7879
inputs = [(torch.rand(1, 3, 224, 224),)]
80+
logging.warning(
81+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
82+
)
7983
else:
8084
inputs, targets, input_list = get_dataset(
8185
data_size=data_num, dataset_dir=args.artifact, download=args.download

examples/qualcomm/scripts/edsr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
import re
1011
from multiprocessing.connection import Client
@@ -113,8 +114,11 @@ def main(args):
113114
)
114115

115116
instance = EdsrModel()
116-
if args.compile_only:
117+
if args.ci:
117118
inputs = instance.get_example_inputs()
119+
logging.warning(
120+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
121+
)
118122
else:
119123
dataset = get_dataset(
120124
args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact

examples/qualcomm/scripts/inception_v3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
from multiprocessing.connection import Client
1011

@@ -37,8 +38,11 @@ def main(args):
3738
)
3839

3940
data_num = 100
40-
if args.compile_only:
41+
if args.ci:
4142
inputs = [(torch.rand(1, 3, 224, 224),)]
43+
logging.warning(
44+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
45+
)
4246
else:
4347
inputs, targets, input_list = get_imagenet_dataset(
4448
dataset_path=f"{args.dataset}",

examples/qualcomm/scripts/inception_v4.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
from multiprocessing.connection import Client
1011

@@ -37,8 +38,11 @@ def main(args):
3738
)
3839

3940
data_num = 100
40-
if args.compile_only:
41-
inputs = [(torch.rand(1, 3, 299, 299),)]
41+
if args.ci:
42+
inputs = [(torch.rand(1, 3, 224, 224),)]
43+
logging.warning(
44+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
45+
)
4246
else:
4347
inputs, targets, input_list = get_imagenet_dataset(
4448
dataset_path=f"{args.dataset}",

examples/qualcomm/scripts/mobilenet_v2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
from multiprocessing.connection import Client
1011

@@ -37,8 +38,11 @@ def main(args):
3738
)
3839

3940
data_num = 100
40-
if args.compile_only:
41+
if args.ci:
4142
inputs = [(torch.rand(1, 3, 224, 224),)]
43+
logging.warning(
44+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
45+
)
4246
else:
4347
inputs, targets, input_list = get_imagenet_dataset(
4448
dataset_path=f"{args.dataset}",

examples/qualcomm/scripts/mobilenet_v3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
from multiprocessing.connection import Client
1011

@@ -36,8 +37,11 @@ def main(args):
3637
)
3738

3839
data_num = 100
39-
if args.compile_only:
40+
if args.ci:
4041
inputs = [(torch.rand(1, 3, 224, 224),)]
42+
logging.warning(
43+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
44+
)
4145
else:
4246
inputs, targets, input_list = get_imagenet_dataset(
4347
dataset_path=f"{args.dataset}",

examples/qualcomm/scripts/torchvision_vit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import json
8+
import logging
89
import os
910
from multiprocessing.connection import Client
1011

@@ -28,8 +29,11 @@ def main(args):
2829
os.makedirs(args.artifact, exist_ok=True)
2930

3031
data_num = 100
31-
if args.compile_only:
32+
if args.ci:
3233
inputs = [(torch.rand(1, 3, 224, 224),)]
34+
logging.warning(
35+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
36+
)
3337
else:
3438
inputs, targets, input_list = get_imagenet_dataset(
3539
dataset_path=f"{args.dataset}",

examples/qualcomm/scripts/wav2letter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ def main(args):
134134

135135
# retrieve dataset, will take some time to download
136136
data_num = 100
137-
if args.compile_only:
137+
if args.ci:
138138
inputs = [(torch.rand(1, 1, 700, 1),)]
139139
logging.warning(
140-
"With compile_only, accuracy will be bad due to insufficient datasets for quantization."
140+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
141141
)
142142
else:
143143
inputs, targets, input_list = get_dataset(

examples/qualcomm/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,13 @@ def setup_common_args_and_variables():
585585
action="store_true",
586586
)
587587

588+
parser.add_argument(
589+
"--ci",
590+
help="This flag is for Continuous Integration(CI) purpose and is NOT recommended to turn on for typical use cases. It will use random inputs instead of real inputs.",
591+
action="store_true",
592+
default=False,
593+
)
594+
588595
# QNN_SDK_ROOT might also be an argument, but it is used in various places.
589596
# So maybe it's fine to just use the environment.
590597
if "QNN_SDK_ROOT" not in os.environ:

0 commit comments

Comments
 (0)