
Commit b5e5686

Qualcomm AI Engine Direct - GA CvT

1 parent d8c26ee commit b5e5686

File tree

3 files changed (+249, −1):
  backends/qualcomm/quantizer/annotators.py
  backends/qualcomm/tests/test_qnn_delegate.py
  examples/qualcomm/oss_scripts/cvt.py

backends/qualcomm/quantizer/annotators.py

Lines changed: 7 additions & 1 deletion

@@ -1170,7 +1170,13 @@ def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None
     )
 
 
-@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default])
+@register_annotator(
+    [
+        torch.ops.aten.split.Tensor,
+        torch.ops.aten.chunk.default,
+        torch.ops.aten.split_with_sizes.default,
+    ]
+)
 def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
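Why aten.split_with_sizes.default now needs the same annotation as split.Tensor and chunk.default: when torch.split is called with a list of sizes, as CvT's attention does to peel off the class token, the exported graph carries aten.split_with_sizes.default rather than aten.split.Tensor. A minimal sketch to confirm this, assuming a recent torch.export:

    import torch
    from torch.export import export

    class SplitSizes(torch.nn.Module):
        def forward(self, x):
            # a list of sizes lowers to aten.split_with_sizes.default
            return torch.split(x, [1, 3], dim=1)

    ep = export(SplitSizes(), (torch.randn(2, 4),))
    ops = {n.target for n in ep.graph.nodes if n.op == "call_function"}
    print(ops)  # expect torch.ops.aten.split_with_sizes.default among them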

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 36 additions & 0 deletions

@@ -3679,6 +3679,42 @@ def test_conv_former(self):
         self.assertGreaterEqual(msg["top_1"], 60)
         self.assertGreaterEqual(msg["top_5"], 80)
 
+    def test_cvt(self):
+        if not self.required_envs([self.image_dataset]):
+            self.skipTest("missing required envs")
+
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/cvt.py",
+            "--dataset",
+            self.image_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertGreaterEqual(msg["top_1"], 70)
+                self.assertGreaterEqual(msg["top_5"], 90)
+
     def test_dino_v2(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
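The test and the example script talk over multiprocessing.connection: the test opens a Listener, launches cvt.py, and blocks until the script connects back with a Client and sends a JSON payload of top-k results (or an "Error" key). A minimal sketch of that handshake, with a local thread standing in for the spawned script and an assumed loopback address:

    import json
    import threading
    from multiprocessing.connection import Client, Listener

    ADDRESS = ("127.0.0.1", 6000)  # assumed address for illustration

    def fake_cvt_script():
        # what cvt.py does once accuracy is computed
        with Client(ADDRESS) as conn:
            conn.send(json.dumps({"top_1": 75.0, "top_5": 92.0}))

    with Listener(ADDRESS) as listener:
        threading.Thread(target=fake_cvt_script).start()
        conn = listener.accept()
        msg = json.loads(conn.recv())
        assert "Error" not in msg and msg["top_1"] >= 70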

examples/qualcomm/oss_scripts/cvt.py

Lines changed: 206 additions & 0 deletions

@@ -0,0 +1,206 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import types
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    get_imagenet_dataset,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
    topk_accuracy,
)
from transformers import AutoModelForImageClassification
from transformers.models.cvt.modeling_cvt import CvtSelfAttention

# Copied from transformers/models/cvt/modeling_cvt.py in transformers 4.47.1
def attention_forward_without_einsum(self, hidden_state, height, width):
    if self.with_cls_token:
        cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
    batch_size, hidden_size, num_channels = hidden_state.shape
    # rearrange "b (h w) c -> b c h w"
    hidden_state = hidden_state.permute(0, 2, 1).view(
        batch_size, num_channels, height, width
    )

    key = self.convolution_projection_key(hidden_state)
    query = self.convolution_projection_query(hidden_state)
    value = self.convolution_projection_value(hidden_state)

    if self.with_cls_token:
        query = torch.cat((cls_token, query), dim=1)
        key = torch.cat((cls_token, key), dim=1)
        value = torch.cat((cls_token, value), dim=1)

    head_dim = self.embed_dim // self.num_heads

    query = self.rearrange_for_multi_head_attention(self.projection_query(query))
    key = self.rearrange_for_multi_head_attention(self.projection_key(key))
    value = self.rearrange_for_multi_head_attention(self.projection_value(value))
    # ====================Qualcomm Changed=================================
    attention_score = query @ key.transpose(-1, -2)
    attention_score = attention_score * self.scale
    # attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale
    # =====================================================================
    attention_probs = torch.nn.functional.softmax(attention_score, dim=-1)
    attention_probs = self.dropout(attention_probs)
    # ====================Qualcomm Changed=================================
    context = attention_probs @ value
    # context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value])
    # =====================================================================
    # rearrange "b h t d -> b t (h d)"
    _, _, hidden_size, _ = context.shape
    context = (
        context.permute(0, 2, 1, 3)
        .contiguous()
        .view(batch_size, hidden_size, self.num_heads * head_dim)
    )
    return context

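The two "Qualcomm Changed" rewrites above are the point of this patch: the comment in main() attributes the quantization prepare failure to einsum, so both einsum calls are replaced with equivalent batched matmuls. A quick numerical check of the equivalence, with assumed tensor shapes (b, h, sequence, head_dim):

    import torch

    q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
    probs = torch.randn(2, 4, 16, 16)

    assert torch.allclose(
        torch.einsum("bhlk,bhtk->bhlt", q, k), q @ k.transpose(-1, -2), atol=1e-6
    )
    assert torch.allclose(
        torch.einsum("bhlt,bhtv->bhlv", probs, v), probs @ v, atol=1e-6
    )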
def _replace_attention(
    module: torch.nn.Module,
):
    for _, child in module.named_children():
        if isinstance(child, CvtSelfAttention):
            child.forward = types.MethodType(  # pyre-ignore
                attention_forward_without_einsum, child
            )
        else:
            _replace_attention(child)
    return module

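The recursion swaps forward only on CvtSelfAttention instances; types.MethodType binds the free function as a bound method on each instance, so the class itself and any unpatched instances are untouched. A toy illustration of that binding (hypothetical names):

    import types

    class Greeter:
        def greet(self):
            return "hello"

    def greet_loudly(self):
        return "HELLO"

    a, b = Greeter(), Greeter()
    a.greet = types.MethodType(greet_loudly, a)  # rebinds only this instance
    print(a.greet(), b.greet())  # HELLO hello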
def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    data_num = 100
    if args.ci:
        inputs = [(torch.rand(1, 3, 224, 224),)]
        logging.warning(
            "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
        )
    else:
        inputs, targets, input_list = get_imagenet_dataset(
            dataset_path=f"{args.dataset}",
            data_size=data_num,
            image_shape=(256, 256),
            crop_size=224,
        )

    module = (
        AutoModelForImageClassification.from_pretrained("microsoft/cvt-13")
        .eval()
        .to("cpu")
    )
    # Fix prepare failure caused by einsum
    module = _replace_attention(module)
    pte_filename = "cvt_qnn_q8"
    build_executorch_binary(
        module.eval(),
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # top-k analysis
    predictions = []
    for i in range(data_num):
        predictions.append(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
            )
        )

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
    else:
        for i, k in enumerate(k_val):
            print(f"top_{k}->{topk[i]}%")

if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000"
        ),
        type=str,
        required=False,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./cvt",
        default="./cvt",
        type=str,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise Exception(e)
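A typical end-to-end invocation, using only flags confirmed by the test above (the serial, SoC model, and build folder values are placeholders):

    python examples/qualcomm/oss_scripts/cvt.py \
        --dataset imagenet-mini/val \
        --artifact ./cvt \
        --build_folder <build_folder> \
        --device <device_serial> \
        --model <soc_model>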

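topk_accuracy comes from executorch.examples.qualcomm.utils; for reference only, a hypothetical sketch of the same computation, not the helper's actual implementation:

    import torch

    def topk_accuracy_sketch(logits: torch.Tensor, targets: torch.Tensor, k: int) -> torch.Tensor:
        # fraction of samples whose target label is among the k highest logits
        pred = logits.topk(k, dim=-1).indices           # (N, k)
        hits = (pred == targets.unsqueeze(-1)).any(-1)  # (N,)
        return hits.float().mean() * 100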