
Qualcomm AI Engine Direct Backend #490


Closed · wants to merge 18 commits

Conversation

@haowhsu-quic (Collaborator) commented Sep 26, 2023

Hi there,

This commit adds the QNN backend. Please check setup.md and the README under backends/qnn for more details.

We add the pybind11 dependency at the top level because we expect you might need it in the near future. If that's not the case, please let us know and we'll move it under backends/qnn.

Below is the SoC model mapping (a small sketch follows the list):

  • SM8550 -> Snapdragon 8 Gen 2
  • SM8475 -> Snapdragon 8 Gen 1+
  • SM8450 -> Snapdragon 8 Gen 1
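
For readers following the examples, these model strings are what get passed to the compile specs and the example scripts' -m flag; a trivial Python sketch of the mapping above (the dictionary name is ours, purely illustrative):

```python
# Illustrative only: maps the SoC model strings used by the examples'
# -m flag to their marketing names.
SOC_TO_CHIPSET = {
    "SM8550": "Snapdragon 8 Gen 2",
    "SM8475": "Snapdragon 8 Gen 1+",
    "SM8450": "Snapdragon 8 Gen 1",
}
```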

Please feel free to contact us if you find any problems.


Co-authored-by: shewu-quic [email protected]
Co-authored-by: chunit-quic [email protected]
Co-authored-by: chiwwang [email protected]
Co-authored-by: harshs-qti [email protected]


@facebook-github-bot added the CLA Signed label Sep 26, 2023
@chiwwang (Contributor) commented:

@haowhsu-quic, missing co-author harshs-qti [email protected]

Harsh helped us to review some documents.

@@ -0,0 +1,344 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor: No need for Meta here either, and probably other places as well.

Collaborator (author): Removed, thanks for pointing out.

int inference_index = 0;
double elapsed_time = 0;
while (std::getline(input_list, file_path)) {
auto input_files = split(file_path, " ");
Contributor: Nit: ensure input_files.size() == num_inputs?

Collaborator (author): Changed, thanks for pointing out.

// Warm up
if (FLAGS_warm_up) {
ET_LOG(Info, "Perform 3 inference for warming up");
for (int i = 0; i < 3; ++i) {
Contributor: Nit: use FLAGS_warm_up from the args, or use a named constant instead of the literal 3.

Collaborator (author): Changed to an integer the user can specify, thank you.

@mergennachin (Contributor) left a comment:

Could you move backends/qnn/examples to top-level examples/qualcomm instead?

Still reviewing...


## Compile a model

### for executorch
Contributor: We are using the ExecuTorch (with E and T capitalized) convention, at least in documentation.

Collaborator (author): Scanned and changed all related keywords, thanks for pointing out.

You can find `qnn_executor_runner` under `build_android/backends/qnn`.


## Compile a model
Contributor: Could you continue with the 'Step X' enumeration?

Collaborator (author): Changed, please take a look at the new version, thanks.


To gain the best performance, we also employ quantization:

```python
Contributor: Could you move this code under examples, so that it just takes the model file name? That way we can test it in CI later and make sure it doesn't go out of date.

Collaborator (author): Done, added a new example, following what examples/export/export_example.py does, for generating a QNN-compatible .pte.

# LICENSE file in the root directory of this source tree.

import sys

Contributor: Could you mention what this file is about in the docblock?

Collaborator (author): Thanks for the suggestion, we will add docstrings incrementally in the future.

@cccclai (Contributor) left a comment:

This PR looks great! The code is well organized and easy to read. Thank you! Just left some comments on naming, probably due to rebase.

def preprocess(
edge_program: ExportedProgram,
compile_specs: List[CompileSpec],
) -> bytes:
Contributor: Suggested change:
- ) -> bytes:
+ ) -> PreprocessResult:

Collaborator (author): Changed, thanks for pointing out.
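
For context, a minimal sketch of the suggested signature, assuming the PreprocessResult API from executorch.exir.backend.backend_details; the class name QnnBackend and the empty blob are placeholders, not the PR's actual implementation:

```python
from typing import List

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram


class QnnBackend(BackendDetails):  # placeholder name for illustration
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # Lower edge_program to a serialized QNN context binary (elided) and
        # wrap the bytes in a PreprocessResult rather than returning raw bytes.
        processed_bytes = b""  # placeholder for the real QNN blob
        return PreprocessResult(processed_bytes=processed_bytes)
```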

exir_ops.edge.aten.full.default,
]

white_list_operator = [
Contributor: Suggested change:
- white_list_operator = [
+ allow_list_operator = [

Collaborator (author): Changed, thank you.

self.partition_tags[delegation_tag] = self.delegation_spec

# override
def partition(self, exported_program: ExportedProgram) -> torch.fx.GraphModule:
Contributor: Suggested change:
- def partition(self, exported_program: ExportedProgram) -> torch.fx.GraphModule:
+ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

Collaborator (author): Changed, thanks for pointing out.
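
Similarly, a hedged sketch of a partitioner returning PartitionResult (field names as in executorch.exir.backend.partitioner at the time; the tagging logic is elided):

```python
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from torch.export import ExportedProgram


class QnnPartitioner(Partitioner):  # skeleton for illustration
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # Node tagging elided; self.partition_tags maps each delegation tag
        # to its DelegationSpec, as in the snippet above.
        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=self.partition_tags,
        )
```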

source_n = n.args[0]

# To make constant value/tensor be tagged as delegatable during partition
if source_n.op == "get_attr":
Contributor: Is source_n meant to be a call_delegate node or a regular call_function node? Just want to make sure we're using get_attr correctly.

Collaborator: No problem. It is a call delegate node.

Collaborator (author): It would be a regular call_function node. We use this to bring all constant params from get_attr nodes together.
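
A hedged sketch of the tagging pattern being discussed, assuming graph_module is the module under partition; it copies the first user's source_fn_stack onto constant-producing get_attr nodes so they are tagged into the same delegate as the op that consumes them:

```python
import torch.fx


def tag_constant_inputs(graph_module: torch.fx.GraphModule) -> None:
    for n in graph_module.graph.nodes:
        source_n = n.args[0] if n.args else None
        if (
            isinstance(source_n, torch.fx.Node)
            and source_n.op == "get_attr"
            and n.users
        ):
            # Make the constant inherit the metadata the partitioner keys on.
            first_user = list(n.users.keys())[0]
            source_n.meta["source_fn_stack"] = first_user.meta.get(
                "source_fn_stack", []
            )
```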

@kimishpatel (Contributor) left a comment:

Some early comments. Still reviewing.


`qnn_executor_runner` is an executable running the compiled model.

You might want to ensure the correct `flatc`. `flatc` can be built along with the above step. For example, we can find `flatc` in `build_x86_64/third-party/flatbuffers/`.
Contributor: ExecuTorch setup requires flatbuffers to be installed, as listed here: https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md. Is that sufficient?

Contributor: The flatc from the current setup is a bit problematic. I updated it in #447.

Collaborator (author): We encountered a compilation issue before due to an incompatible flatc version (conda-installed vs. submodule). We might recommend users build it from source.

example_input = mv2.get_example_inputs()
model = mv2.get_eager_model().eval()

captured_program = exir.capture(model, example_input)
Contributor: You probably want to include quantization API calls here.

Contributor: Oh I see, it is listed below. You don't need to have the "non-delegated" example listed here. Also note that this is setup.md, so you probably want the examples in either the examples folder's README.md or in qnn/README.md.

Collaborator (author): Changed, thanks for the suggestion.
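
A rough sketch of the PT2E post-training quantization flow being discussed; QnnQuantizer and its import path are assumptions based on this PR, and the capture API reflects the PT2E tooling of that era:

```python
import torch
import torchvision.models as models
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# The quantizer this PR introduces; the import path is an assumption.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

model = models.mobilenet_v2(weights="DEFAULT").eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Capture a pre-autograd graph for PT2E quantization.
captured = torch._export.capture_pre_autograd_graph(model, example_inputs)

prepared = prepare_pt2e(captured, QnnQuantizer())
prepared(*example_inputs)  # calibrate with representative inputs
quantized = convert_pt2e(prepared)
# `quantized` can then go through capture/to_edge and the QNN partitioner
# as in the surrounding example.
```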

@@ -0,0 +1,251 @@
# Setting up QNN Backend
Contributor: Except for some minor nits, the instructions here are great.

inputs,
qnn_capture_config(),
).to_edge(qnn_edge_config())
return edge_prog.transform(qnn_partitioner_passes)
Contributor: Things of this nature, which are non-standard in the export workflow, should be user visible. Hence my preference is that you use the utils from examples/export/utils.py directly and add edge_prog.transform(...) in the examples and READMEs.

Collaborator (author): We will decompose all of this for the user and ultimately remove capture_program.
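
A hedged sketch of the user-visible flow the reviewer is asking for, reusing the helper names from the snippet above (qnn_capture_config, qnn_edge_config, qnn_partitioner_passes):

```python
import executorch.exir as exir

# Capture and lower explicitly in user code, instead of hiding the steps
# inside a capture_program helper.
edge_prog = exir.capture(
    model,
    example_inputs,
    qnn_capture_config(),
).to_edge(qnn_edge_config())

# The backend-specific transform step stays visible to the user.
edge_prog = edge_prog.transform(qnn_partitioner_passes)
```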



def qnn_capture_config():
return exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True)
Contributor: Do you need a separate capture_config? I see that you want _unlift=True. I think this is fine for now, but we should try to make it work with the default capture, as these options are going away.

Collaborator (author): This argument means a lot to us. We hope to prefetch as many known tensors as possible instead of treating them as runtime arguments. This is important for performing graph optimization inside QNN.

Contributor: So _unlift=True doesn't necessarily mean that the parameters of the model actually become arguments to the function. 100% agree with you that otherwise it becomes a performance issue. If you look at the XNNPACK delegation example, it actually works with an exported_program with "lifted" parameters, but during delegation it consumes all the parameters, and only those that are inputs are supplied.

Let's go over this in the next meeting, as longer term that is the route we need to take.

"SM8475": PyQnnManager.QcomModel.SM8475,
"SM8550": PyQnnManager.QcomModel.SM8550,
}
backend_type = CompileSpec("backend_type", bytes([2]))
Contributor: Can the constants be named better?

@shewu-quic (Collaborator) commented Oct 3, 2023:

There are other accelerators in QNN, but it is indeed inappropriate to use backend_type in ExecuTorch. It should be changed to accelerator_type in the future. What do you think?

Collaborator (author): By the way, the constant in bytes([2]) has been changed to a meaningful enumeration.
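
For illustration, a hedged sketch of what a self-describing enumeration might look like; the enum name and values are hypothetical (the PR's actual enum is generated), with HTP assumed to be value 2 from bytes([2]) above:

```python
from enum import IntEnum

from executorch.exir.backend.compile_spec_schema import CompileSpec


class QnnAccelerator(IntEnum):  # hypothetical; the PR uses a generated enum
    CPU = 0
    GPU = 1
    HTP = 2


# The magic constant bytes([2]) becomes self-describing:
backend_type = CompileSpec("backend_type", bytes([QnnAccelerator.HTP]))
```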

@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor: Can you use this directly from backend/transforms? There shouldn't be a need to copy-paste this.

Collaborator (author): Changed, thanks for pointing out.

name = arg_schema.name
value = quant_attrs[name]
if type(arg_schema.type) == torch.tensor and type(value) in [int, float]:
value = torch.tensor(value)
Contributor: Is the node added here also consumed by the delegate?

Collaborator: Yes, we use it to quantize inputs or dequantize outputs. These added nodes do exist in our backend IR.

@kimishpatel requested a review from sxu September 26, 2023 23:03
@facebook-github-bot (Contributor): @kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


# To make constant value/tensor be tagged as delegatable during partition
if source_n.op == "get_attr":
source_n.meta["source_fn_stack"] = list(n.users.keys())[0].meta.get(
Contributor: Not necessary to do in this PR, but if you update it, please include "TODO: remove this hack as source_fn_stack is an internal implementation detail of torch.export". We will follow up on this. There is some issue as to why this metadata is not preserved.

Collaborator (author): Comment added; we will keep tracking this, too.

Tsai, Chun-I (Joey) and others added 2 commits October 4, 2023 12:30
- Add TODO to remove the source_fn_stack
- Fix the comment of fold_qdq pass
- Remove redundant ops in annotate_quant_attrs
- move qnn_executor_runner to examples/backend/qualcomm
- move qualcomm/runtime/wrappers to qualcomm/aot/wrappers
- move qualcomm/runtime/python to qualcomm/aot/python
- change file structure in README accordingly
@mergennachin (Contributor) left a comment:

Thanks for making the suggested changes.

One last minor request.

Would you mind moving examples/backends/qualcomm to just examples/qualcomm?

And move all the relevant scripts (e.g. export_example.py) to examples/qualcomm/scripts/...?

Also, keep the executor_runner in examples/qualcomm/executor_runner/ as is.

Currently we are doing refactoring within our examples folder, and it'd be great if qualcomm had a similar structure to ours.

Otherwise, LGTM.

cc Will let @kimishpatel take one last look.

@haowhsu-quic (Collaborator, author):

> Would you mind moving examples/backends/qualcomm to just examples/qualcomm? And move all the relevant scripts (e.g. export_example.py) to examples/qualcomm/scripts/...?

Hi Mergen, thank you for giving us great comments.
For the relevant scripts, do you mean we should also put all the examples, like mobilenet_v2.py, inception_v4.py, etc., into examples/qualcomm/scripts?

@haowhsu-quic reopened this Oct 4, 2023
@mergennachin (Contributor) commented Oct 4, 2023:

Yes, any user-facing main script that users have to run to see the example models end-to-end should preferably live in examples/qualcomm/scripts/*.

@mergennachin (Contributor):

Also, don't close the PR yet; someone from Meta will do the merging properly. Most likely Kimish.

@@ -0,0 +1,205 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor: You don't need Meta here.

Collaborator (author): Removed, thank you.

@@ -0,0 +1,774 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor: Is this here because the file was copied from somewhere?

Collaborator (author): Removed; things look very different now, thank you.

@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor: Not sure if it is right to include Meta here.

Collaborator (author): Removed, thank you.

@@ -0,0 +1,1209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor: Also not sure if Meta needs to be here.

Collaborator (author): Removed, thank you.

@kimishpatel (Contributor):

Also, please take a look at the various copyrights. I have commented on some. There are a few places where Meta is in the copyright, but I am not sure there is a need for it.

@facebook-github-bot (Contributor): @kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. (1 similar comment)

- move examples/backend/qualcomm -> example/qualcomm
- change documents, CMakeLists.txt accordingly
- remove copyright from Meta
@haowhsu-quic (Collaborator, author):

> Also, don't close the PR yet; someone from Meta will do the merging properly.

Sorry, misclicked. Examples are rearranged now.

> Also, please take a look at the various copyrights.

Thank you for the reminder. I've scanned all related files and made sure all the copyrights are suitable.


@kimishpatel (Contributor) left a comment:

Looks good. I will try to land this by importing it internally.

@facebook-github-bot (Contributor): @kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Comment on lines +37 to +39
python mobilenet_v2.py -s <device_serial> -m "SM8550" -b path/to/build_android/ -d /path/to/imagenet-mini/val

python deeplab_v3.py -s <device_serial> -m "SM8550" -b path/to/build_android/ --download
Contributor: Since we are grouping the scripts under qualcomm/scripts/, it would be nice to ask users to cd $EXECUTORCH_ROOT/examples/qualcomm/scripts first, so they can copy/paste the commands here and run them hassle-free on their end.

Collaborator (author): Yes, it's a good suggestion. We'll change this in a following PR, thank you.

@@ -7,7 +7,11 @@
import logging

import torch
from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50 # @manual
Contributor: @haowhsu-quic Just want to confirm that Qualcomm is not interested in using deeplabv3_resnet50, right? We may want to clean it up later, since there is already a resnet50 model in our model portfolio.

Collaborator (author): Yes, we'll use the resnet101 backbone. Resnet50 does not perform well in our experiments.

@guangy10 (Contributor) left a comment:

The demo code structure in examples/ looks good to me!

@facebook-github-bot (Contributor): @kimishpatel merged this pull request in ef97148.

Gasoonjia pushed a commit that referenced this pull request Jul 30, 2024
* Fix runner-et on Mac and Android

* Fix typo

* Avoid fatal message

* Change fatal to warning