Skip to content

Qualcomm AI Engine Direct - GA efficientnet #11212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions backends/qualcomm/builders/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class AvgPool2d(NodeVisitor):
def __init__(self, *args) -> None:
super().__init__(*args)

def _get_filter_size(self, node):
filter_size = cast(List[int], node.args[1])
if len(filter_size) == 1:
filter_size = filter_size + filter_size
return filter_size

def define_node(
self,
node: torch.fx.Node,
Expand All @@ -46,31 +52,44 @@ def define_node(
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

pt_ceil_mode = node.args[4] if len(node.args) >= 4 else False

# kernel info
filter_size = cast(List[int], node.args[1])
if len(filter_size) == 1:
filter_size = filter_size + filter_size
input_shape = input_node.meta["val"].shape
input_h, input_w = input_shape[2], input_shape[3]
filter_size = self._get_filter_size(node)
if pt_ceil_mode:
# filter_size might larger than input_h, input_w, use min of them
filter_size = [min(filter_size[0], input_h), min(filter_size[1], input_w)]
filter_size_shape = [len(filter_size)]

# stride info - default to kernel_size if not given
stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
if len(stride) == 1:
stride = stride + stride
stride_shape = [len(stride)]

padding = [0, 0]
if len(node.args) > 3:
padding = cast(List[int], node.args[3])
if len(padding) == 1:
padding = padding + padding
if pt_ceil_mode:
ori_filter_h, ori_filter_w = self._get_filter_size(node)
padding = [
0 if ori_filter_h > input_h else padding[0],
0 if ori_filter_w > input_w else padding[1],
]

padding_shape = [len(padding), len(padding)]

# if ceil mode is True, use ceil instead of floor to compute the output shape
mode = OpPoolAvg2d.RoundingMode.FLOOR
if len(node.args) > 4:
ceil_mode = cast(bool, node.args[4])
if ceil_mode:
mode = OpPoolAvg2d.RoundingMode.CEIL
mode = (
OpPoolAvg2d.RoundingMode.CEIL
if pt_ceil_mode
else OpPoolAvg2d.RoundingMode.FLOOR
)

# stride info - default to kernel_size if not given
stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
if len(stride) == 1:
stride = stride + stride
stride_shape = [len(stride)]

count_include_pad = True
if len(node.args) > 5:
Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,7 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None:
@register_annotator(
[
torch.ops.aten.conv2d.default,
torch.ops.aten.conv2d.padding,
torch.ops.aten.conv1d.default,
torch.ops.aten.conv_transpose2d.input,
torch.ops.aten.conv_transpose1d.default,
Expand Down
9 changes: 5 additions & 4 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,13 @@ def forward(self, x, y):


class AvgPoolModule(torch.nn.Module):
def __init__(self):
def __init__(self, kernel_size, stride, padding, ceil_mode):
super().__init__()
self.avgPool = torch.nn.AvgPool2d(
kernel_size=(2, 2),
padding=(1, 1),
stride=(1, 1),
kernel_size=kernel_size,
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=False,
)

Expand Down
69 changes: 62 additions & 7 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,19 @@ def test_qnn_backend_argmin(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_avg_pool2d(self):
module = AvgPoolModule() # noqa: F405
sample_input = (torch.randn(1, 3, 2, 2),)
self.lower_module_and_test_output(module, sample_input)
modules = [
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405
AvgPoolModule((1280, 1280), (1280, 1280), (0, 0), True), # noqa: F405
AvgPoolModule((1280, 1280), (1280, 1280), (320, 320), True), # noqa: F405
] # noqa: F405
sample_inputs = [
(torch.randn(1, 3, 2, 2),),
(torch.randn(1, 1280, 7, 7),),
(torch.randn(1, 1280, 7, 7),),
]
for i, module in enumerate(modules):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_inputs[i])

def test_qnn_backend_batch_norm(self):
modules = [BatchNorm(32), BatchNorm(32, False)] # noqa: F405
Expand Down Expand Up @@ -1271,10 +1281,20 @@ def test_qnn_backend_argmin(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_avg_pool2d(self):
module = AvgPoolModule() # noqa: F405
sample_input = (torch.randn(1, 3, 2, 2),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)
modules = [
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405
AvgPoolModule((1280, 1280), (1280, 1280), (0, 0), True), # noqa: F405
AvgPoolModule((1280, 1280), (1280, 1280), (320, 320), True), # noqa: F405
] # noqa: F405
sample_inputs = [
(torch.randn(1, 3, 2, 2),),
(torch.randn(1, 1280, 7, 7),),
(torch.randn(1, 1280, 7, 7),),
]
for i, module in enumerate(modules):
with self.subTest(i=i):
module = self.get_qdq_module(module, sample_inputs[i])
self.lower_module_and_test_output(module, sample_inputs[i])

def test_qnn_backend_batch_norm(self):
modules = [BatchNorm(32), BatchNorm(32, False)] # noqa: F405
Expand Down Expand Up @@ -3864,6 +3884,41 @@ def test_dino_v2(self):
self.assertGreaterEqual(msg["top_1"], 70)
self.assertGreaterEqual(msg["top_5"], 85)

def test_efficientnet(self):
if not self.required_envs([self.image_dataset]):
self.skipTest("missing required envs")
cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientnet.py"
"--dataset",
self.image_dataset,
"--artifact",
self.artifact_dir,
"--build_folder",
self.build_folder,
"--device",
self.device,
"--model",
self.model,
"--ip",
self.ip,
"--port",
str(self.port),
]
if self.host:
cmds.extend(["--host", self.host])

p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
p.communicate()
msg = json.loads(conn.recv())
if "Error" in msg:
self.fail(msg["Error"])
else:
self.assertGreaterEqual(msg["top_1"], 70)
self.assertGreaterEqual(msg["top_5"], 85)

def test_efficientSAM(self):
if not self.required_envs(
[self.image_dataset, self.pretrained_weight, self.oss_repo]
Expand Down
145 changes: 145 additions & 0 deletions examples/qualcomm/oss_scripts/efficientnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
build_executorch_binary,
get_imagenet_dataset,
make_output_dir,
parse_skip_delegation_node,
setup_common_args_and_variables,
SimpleADB,
topk_accuracy,
)
from transformers import AutoModelForImageClassification


def main(args):
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

# ensure the working directory exist.
os.makedirs(args.artifact, exist_ok=True)

if not args.compile_only and args.device is None:
raise RuntimeError(
"device serial is required if not compile only. "
"Please specify a device serial by -s/--device argument."
)

data_num = 100
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_imagenet_dataset(
dataset_path=f"{args.dataset}",
data_size=data_num,
image_shape=(256, 256),
crop_size=224,
)

module = (
AutoModelForImageClassification.from_pretrained("google/efficientnet-b0")
.eval()
.to("cpu")
)
pte_filename = "efficientnet_qnn_q16"
build_executorch_binary(
module.eval(),
inputs[0],
args.model,
f"{args.artifact}/{pte_filename}",
inputs,
skip_node_id_set=skip_node_id_set,
skip_node_op_set=skip_node_op_set,
quant_dtype=QuantDtype.use_16a16w,
shared_buffer=args.shared_buffer,
)

if args.compile_only:
return

adb = SimpleADB(
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
build_path=f"{args.build_folder}",
pte_path=f"{args.artifact}/{pte_filename}.pte",
workspace=f"/data/local/tmp/executorch/{pte_filename}",
device_id=args.device,
host_id=args.host,
soc_model=args.model,
shared_buffer=args.shared_buffer,
)
adb.push(inputs=inputs, input_list=input_list)
adb.execute()

# collect output data
output_data_folder = f"{args.artifact}/outputs"
make_output_dir(output_data_folder)

adb.pull(output_path=args.artifact)

# top-k analysis
predictions = []
for i in range(data_num):
predictions.append(
np.fromfile(
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
)
)

k_val = [1, 5]
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
else:
for i, k in enumerate(k_val):
print(f"top_{k}->{topk[i]}%")


if __name__ == "__main__":
parser = setup_common_args_and_variables()

parser.add_argument(
"-d",
"--dataset",
help=(
"path to the validation folder of ImageNet dataset. "
"e.g. --dataset imagenet-mini/val "
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
),
type=str,
required=False,
)

parser.add_argument(
"-a",
"--artifact",
help="path for storing generated artifacts by this example. "
"Default ./efficientnet",
default="./efficientnet",
type=str,
)

args = parser.parse_args()
try:
main(args)
except Exception as e:
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
conn.send(json.dumps({"Error": str(e)}))
else:
raise Exception(e)
Loading