Qualcomm AI Engine Direct - FbNet enablement #2706
@@ -0,0 +1,128 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import re
import sys
from multiprocessing.connection import Client

import numpy as np
import timm
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.scripts.inception_v4 import get_dataset
from executorch.examples.qualcomm.scripts.utils import (
    build_executorch_binary,
    make_output_dir,
    setup_common_args_and_variables,
    SimpleADB,
    topk_accuracy,
)


if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./fbnet",
        default="./fbnet",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    instance = timm.create_model("fbnetc_100", pretrained=True).eval()

    data_num = 100
    inputs, targets, input_list = get_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
    )

    pte_filename = "fbnet"

    build_executorch_binary(
        instance,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        quant_dtype=QuantDtype.use_8a8w,
    )

    if args.compile_only:
        sys.exit(0)

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        artifact_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    output_raws = []

    def post_process():
        # keep only the first output per sample (output_<i>_0.raw) and
        # remove any additional output tensors pulled from the device
        for f in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            filename = os.path.join(output_data_folder, f)
            if re.match(r"^output_[0-9]+_[1-9].raw$", f):
                os.remove(filename)
            else:
                output = np.fromfile(filename, dtype=np.float32)
                output_raws.append(output)

    adb.pull(output_path=args.artifact, callback=post_process)

    # top-k analysis
    predictions = []
    for i in range(data_num):
        predictions.append(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
            )
        )

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
    else:
        for i, k in enumerate(k_val):
            print(f"top_{k}->{topk[i]}%")
@@ -0,0 +1 @@
pip install timm
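The example pulls FBNet from timm, so this is the only extra Python dependency. As a quick sanity check (a sketch, not part of the PR), one can confirm that the fbnetc_100 variant loads and emits ImageNet-style logits; the 224x224 input size is assumed from the model's standard ImageNet configuration:

# Sketch: verify the timm dependency and the FBNet-C variant used by the example.
import timm
import torch

model = timm.create_model("fbnetc_100", pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # 224x224 RGB input (assumed)
print(logits.shape)  # expected: torch.Size([1, 1000]) ImageNet logits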
@@ -39,6 +39,9 @@ def create_device_inputs(example_inputs, use_kv_cache):


if __name__ == "__main__":
    print(
        "[WARNING] The module of llama is changing frequently. This script might not work"
    )
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",

Review discussion on the added warning:

Did you run into any issue with the script?

I tested it last week and it seemed ok.

Hi @cccclai, we found some non-ideal behavior in our CI. For the following reasons we think it's better to have this warning:
python dummy_llama2.py --ptq 8a8w ...
python dummy_llama2.py --ptq 16a4w ...
Would you mind sharing your command? We can also reproduce it and find out what the difference is. Thanks! :D

Ah, I take my word back. I just tried exporting the model and saw this error when loading the model in the runtime. Any chance you know the reason?

Oh, also I think the code change in llama_transformer.py might be the culprit behind the issue you saw.

Actually the error message might be specific to me because I only have SM8450. I just opened an issue here: #2788

Thank you for pointing out the possibility. We will investigate it later. We will find an SM8450 device and try to reproduce it. Once we have any news we will reply in issue 2788. Thank you for the report.

May I ask what device you've been using? Is it SM8450?

No, I usually work on SM8550. I haven't even tested on an SM8450 device personally.
Any specific reason we turn it on? I guess I didn't realize it was OFF before.
We explicitly turned it OFF before. Because PR 2466 recently turned it off by default, we don't need to set it again here.