
Qualcomm AI Engine Direct - Support batch prefill mode for llama3.2 #6983


Merged
merged 8 commits into pytorch:main from dev1/chunit/bert_mode on Dec 3, 2024

Conversation

chunit-quic
Collaborator

  • Enable bert mode
  • Change input sequence of static_llama
  • Tag bert output as uint8
  • Unify both 1b and 3b in 1 runner
  • Add hybrid IO memory for llama3_2 runner
  • Align timer with llama


pytorch-bot bot commented Nov 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6983

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job

As of commit ef2e1e5 with merge base a347665:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Nov 20, 2024
@cccclai
Contributor

cccclai commented Nov 21, 2024

Hey do you mind sharing the command for AoT and runtime so I can try on my end?

@chunit-quic
Collaborator Author

chunit-quic commented Nov 22, 2024

Hey do you mind sharing the command for AoT and runtime so I can try on my end?

Sure! Switch between modes (kv or bert) by setting the model_mode argument.

python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ${ARCHIVE}/ -b build-android -H ${HOST} -s ${DEVICE} -m ${SOC} --checkpoint Llama3.2-1B-Instruct/consolidated.00.pth --params Llama3.2-1B-Instruct/params.json --tokenizer_model Llama3.2-1B-Instruct/tokenizer.model --prompt "<|start_header_id|>" --ptq 16a4w --temperature 0 --model_size 1B --seq_len 16 --model_mode bert

@chunit-quic chunit-quic marked this pull request as ready for review November 22, 2024 01:04
@cccclai
Contributor

cccclai commented Nov 22, 2024

Hey do you mind sharing the command for AoT and runtime so I can try on my end?

Sure! Switch between modes (kv or bert) by setting the model_mode argument.

python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ${ARCHIVE}/ -b build-android -H ${HOST} -s ${DEVICE} -m ${SOC} --checkpoint Llama3.2-1B-Instruct/consolidated.00.pth --params Llama3.2-1B-Instruct/params.json --tokenizer_model Llama3.2-1B-Instruct/tokenizer.model --prompt "<|start_header_id|>" --ptq 16a4w --temperature 0 --model_size 1B --seq_len 16 --model_mode bert

Ah I see - do you mind renaming bert mode to batch_prefill? The context is that bert isn't a common name.

@chunit-quic chunit-quic force-pushed the dev1/chunit/bert_mode branch from 5dc7b3f to 0cff7c9 Compare November 22, 2024 01:39
@chunit-quic
Collaborator Author

Ah I see - do you mind renaming bert mode to batch_prefill? The context is that bert isn't a common name.

No problem, let me change it.

@facebook-github-bot
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@chunit-quic chunit-quic changed the title Qualcomm AI Engine Direct - Support bert mode for llama3.2 Qualcomm AI Engine Direct - Support batch prefill mode for llama3.2 Nov 22, 2024
@cccclai
Contributor

cccclai commented Nov 22, 2024

There are some errors here:

executorch/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp:50:7: error: field 'eval_mode_' will be initialized after field 'stats_' [-Werror,-Wreorder-ctor]
   50 |       eval_mode_(eval_mode),
      |       ^~~~~~~~~~~~~~~~~~~~~
      |       stats_({})
   51 |       stats_({}) {
      |       ~~~~~~~~~~
      |       eval_mode_(eval_mode)

@chunit-quic
Collaborator Author

executorch/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp:50:7: error: field 'eval_mode_' will be initialized after field 'stats_' [-Werror,-Wreorder-ctor]
50 | eval_mode_(eval_mode),
| ^~~~~~~~~~~~~~~~~~~~~
| stats_({})
51 | stats_({}) {
| ~~~~~~~~~~
| eval_mode_(eval_mode)

Thanks for pointing out. Fixed.
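For reference, a minimal sketch of this kind of fix (not the exact diff; the class and member names come from the error message above, everything else is a simplified placeholder). Member initializers always run in declaration order, so listing them in that same order is what -Werror,-Wreorder-ctor asks for:

```cpp
#include <cstdint>

struct Stats {};  // placeholder for the real stats type

class Runner {
 public:
  explicit Runner(int32_t eval_mode)
      // Initializer order now matches the declaration order below,
      // which silences -Wreorder-ctor.
      : stats_({}),
        eval_mode_(eval_mode) {}

 private:
  Stats stats_;        // declared first, so it is initialized first
  int32_t eval_mode_;  // declared second
};

int main() {
  Runner runner(/*eval_mode=*/0);
  (void)runner;
  return 0;
}
```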

@facebook-github-bot
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -137,7 +137,7 @@ def python_is_compatible():
     "timm==1.0.7",
     f"torchaudio==2.5.0.{NIGHTLY_VERSION}" if USE_PYTORCH_NIGHTLY else "torchaudio",
     "torchsr==1.0.4",
-    "transformers==4.46.1",
+    "transformers==4.42.4",  # TODO update back to 4.46.1 once the error is fixed
Contributor

What was the issue with this?

Collaborator Author

Sorry for the confusion. This is for our internal CI. Removed.

One more thing: just in case you want to reproduce the performance profiling results right now, we are still checking and working on some related passes. It might be better to wait for our next profiling results to see if the execution times are aligned. Thanks!

Contributor

Sounds good, thanks!

Contributor

Maybe let's bring more CI to OSS, so these errors can be caught when the PR is created

@cccclai
Contributor

cccclai commented Nov 23, 2024

Still has a lint error...

@chunit-quic chunit-quic force-pushed the dev1/chunit/bert_mode branch from ff3ce36 to 08c4742 Compare November 25, 2024 01:50
@cccclai
Contributor

cccclai commented Dec 1, 2024

test-llama-runner-qnn-linux seems to be failing, but it seems unrelated to this PR and should be resolved in the main branch. Can you help with rebasing?

@cccclai
Contributor

cccclai commented Dec 1, 2024

I'm getting the following error:

I 00:00:00.000856 executorch:runner.cpp:54] creating module: model_path=llama3_2_qnn_batch_prefill.pte
I 00:00:00.000978 executorch:runner.cpp:56] creating runner: tokenizer_path=tokenizer.model
I 00:00:00.001265 executorch:runner.cpp:135] get_max_seq_len: 16
I 00:00:00.001331 executorch:runner.cpp:135] get_vocab_size: 128256
I 00:00:00.001370 executorch:runner.cpp:135] get_n_layers: 16
I 00:00:00.001419 executorch:runner.cpp:135] get_head_dim: 64
I 00:00:00.001476 executorch:runner.cpp:135] get_n_kv_heads: 8
I 00:00:00.443705 executorch:runner.cpp:81] creating io_memory
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 2
[WARNING] [Qnn ExecuTorch]:  <W> Initializing HtpProvider

[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in RESTORE MODE.
[WARNING] [Qnn ExecuTorch]: Failed to interpret QNN context binary. Error code 30010. Try verifying binary with online-prepare format.
[ERROR] [Qnn ExecuTorch]: Failed to parse QNN Graph Info. The cache might be broken. Please consider to re-generate the cache.
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[ERROR] [Qnn ExecuTorch]: QNN context cache is invalid.
E 00:00:00.575785 executorch:QnnManager.cpp:303] Fail to configure Qnn context
E 00:00:00.575857 executorch:QnnExecuTorchBackend.cpp:60] Fail to initialize Qnn Manager
E 00:00:00.575926 executorch:method.cpp:109] Init failed for backend QnnBackend: 0x1
[WARNING] [Qnn ExecuTorch]:  <W> Backend 1 free cleanup called during process exit

[WARNING] [Qnn ExecuTorch]:  <W> qnnOpPackageManager: hexagon unload op package function pointer is nullptr!

The repro steps are:

  1. AoT
python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ./llama3_2_1B_512 -b build-android --compile_only -m "SM8650" --checkpoint /data/users/chenlai/models/llama3.2/1B/consolidated.00.pth --params /data/users/chenlai/models/llama3.2/1B/params.json --tokenizer_model /data/users/chenlai/models/llama3.2/1B/tokenizer.model --prompt "<|start_header_id|>" --ptq 16a4w --temperature 0 --model_size 1B --seq_len 16  --model_mode batch_prefill
  2. Runtime, after pushing the binary/library to the device and setting up the env path.
./qnn_llama3_2_runner --model_path llama3_2_qnn_batch_prefill.pte --tokenizer_path tokenizer.model                                                

Joey Tsai added 8 commits December 2, 2024 07:11
- Enable bert mode
- Change input sequence of static_llama
- Tag bert output as uint8
- Unify both 1b and 3b in 1 runner
- Add hybrid IO memory for llama3_2 runner
- Align timer with llama
- Fix rebase conflict
- Change input dtype of calibration function
- Fix transformers version
- Refine pass quantization tagging function
- Rebase
@chunit-quic chunit-quic force-pushed the dev1/chunit/bert_mode branch from 767887d to ef2e1e5 Compare December 1, 2024 23:26
@chunit-quic
Collaborator Author

test-llama-runner-qnn-linux seems to be failing, but it seems unrelated to this PR and should be resolved in the main branch. Can you help with rebasing?

Sure, just rebased.

I'm getting the following error:
[WARNING] [Qnn ExecuTorch]: Failed to interpret QNN context binary. Error code 30010. Try verifying binary with online-prepare format.
[ERROR] [Qnn ExecuTorch]: Failed to parse QNN Graph Info. The cache might be broken. Please consider to re-generate the cache.
[WARNING] [Qnn ExecuTorch]: Function not called, PrepareLib isn't loaded!
[ERROR] [Qnn ExecuTorch]: QNN context cache is invalid.

May I know which QNN version you used? It seems to me that it might be related to PR6811. I just tested with a smaller one (1 layer) without encountering an error. My QNN SDK version is 2.26.1.

@cccclai
Contributor

cccclai commented Dec 1, 2024

I'm using the version downloaded from https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.26.0.240828.zip

Maybe let me get your latest PR and see if it passes. Seems like both batch_prefill and kv fail.

@cccclai
Contributor

cccclai commented Dec 1, 2024

Is the PR tested on qnn 2.28?

@chunit-quic
Collaborator Author

chunit-quic commented Dec 2, 2024

Is the PR tested on qnn 2.28?

No, we tested on QNN 2.26 previously. I'm running batch_prefill mode with 16 layers now with a clean build again, and will do kv mode later.

@facebook-github-bot
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@chunit-quic
Collaborator Author

chunit-quic commented Dec 2, 2024

No, we tested on QNN 2.26 previously. I'm running batch_prefill mode with 16 layers now with a clean build again, and will do kv mode later.

Hi @cccclai
Just a quick update.

I ran both kv and batch_prefill modes successfully (16 layers, normal size) with the command below, switching the --model_mode arg. I didn't see the graph info error. My QNN SDK is 2.26.1.240827.

If you would like to export and execute seperately, maybe you could consider using --pre_gen_pte arg of llama.py. It will help you to push libs and generate command on device.

Please let me know if you are still facing the error. Would try to find it out. :)

 python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ./${bert_test} -b build-android -H ${HOST} -s ${DEVICE} -m "SM8650" --checkpoint ${Llama3.2-1B-Instruct}/consolidated.00.pth --params ${Llama3.2-1B-Instruct}/params.json --tokenizer_model ${Llama3.2-1B-Instruct}/tokenizer.model --prompt "<|start_header_id|>" --ptq 16a4w --temperature 0 --model_size 1B --seq_len 16  --model_mode batch_prefill
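For example, a hypothetical usage sketch (the ${PRE_GEN_DIR} placeholder stands for wherever the previously exported artifacts were written, and whether all of the AoT arguments are still required alongside --pre_gen_pte is an assumption worth checking against llama.py's argument parser):

 python examples/qualcomm/oss_scripts/llama3_2/llama.py -a ./${bert_test} -b build-android -H ${HOST} -s ${DEVICE} -m "SM8650" --checkpoint ${Llama3.2-1B-Instruct}/consolidated.00.pth --params ${Llama3.2-1B-Instruct}/params.json --tokenizer_model ${Llama3.2-1B-Instruct}/tokenizer.model --prompt "<|start_header_id|>" --ptq 16a4w --temperature 0 --model_size 1B --seq_len 16 --model_mode batch_prefill --pre_gen_pte ${PRE_GEN_DIR}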

@cccclai
Contributor

cccclai commented Dec 3, 2024

Sort of orthogonal - did we verify the runner cpp correctness? We have a calibrated quantized model which shows reasonable results, but the runner gives pure nonsense results.

@chunit-quic
Collaborator Author

chunit-quic commented Dec 3, 2024

Sort of orthogonal - did we verify the runner cpp correctness? We have a calibrated quantized model which shows reasonable results, but the runner gives pure nonsense results.

Yes, we performed some correctness checks.

  1. Given a sequence of 15 tokens in prefill mode, we obtained its results (15 tokens). We then fed the same prompts to kv mode and obtained its results one by one (also 15 tokens in the end). The tokens from kv and prefill modes seem to be almost the same (llama3_2 1b fp); a sketch of this comparison is below.
  2. We have another internal PR based on this one to check correctness using stories 110M. The flow involves using prefill mode for prompts and kv mode to generate tokens based on the logits and kv cache from prefill mode. The results seem to be the same as pure kv mode. (fb)

We will follow up soon with another PR based on this one, which supports hybrid mode. Maybe we can check correctness using that PR? Any other ideas are appreciated. :D
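For what it's worth, a minimal standalone sketch of the comparison in item 1 (hypothetical, not the actual test harness; in practice the token ids would be collected from the two runner runs rather than hard-coded):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Compare tokens produced by batch-prefill mode against tokens produced
// token-by-token in kv mode for the same prompt.
bool tokens_match(
    const std::vector<uint64_t>& prefill_tokens,
    const std::vector<uint64_t>& kv_tokens) {
  if (prefill_tokens.size() != kv_tokens.size()) {
    std::cerr << "length mismatch: " << prefill_tokens.size() << " vs "
              << kv_tokens.size() << "\n";
    return false;
  }
  bool ok = true;
  for (size_t i = 0; i < prefill_tokens.size(); ++i) {
    if (prefill_tokens[i] != kv_tokens[i]) {
      std::cerr << "mismatch at position " << i << ": " << prefill_tokens[i]
                << " vs " << kv_tokens[i] << "\n";
      ok = false;
    }
  }
  return ok;
}

int main() {
  // Hypothetical token ids; replace with the tokens dumped from each mode.
  std::vector<uint64_t> prefill = {128006, 882, 128007, 271};
  std::vector<uint64_t> kv = {128006, 882, 128007, 271};
  std::cout << (tokens_match(prefill, kv) ? "match" : "mismatch") << "\n";
  return 0;
}
```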

@cccclai
Contributor

cccclai commented Dec 3, 2024

Ah yes, that will be great. If we enable the stories model, then we can add it to CI easily.

@cccclai
Contributor

cccclai commented Dec 3, 2024

With that being said, should we just aim to merge that PR instead and keep this PR on hold?

@chunit-quic
Collaborator Author

With that being said, should we just aim to merge that PR instead and keep this PR on hold?

No, I would say we can merge this one first. Then the upcoming PR will have fewer code changes.

@cccclai cccclai added the partner: qualcomm and release notes: backends labels Dec 3, 2024
@cccclai cccclai merged commit d89d3e7 into pytorch:main Dec 3, 2024
42 of 43 checks passed