Qualcomm AI Engine Direct - Optimization in static llama #6849

Conversation

shewu-quic (Collaborator)

Summary:

- Fuse RMS norm
- Improve performance of div op
- Fix 16a8w annotation for matmul op (a hedged sketch of such a config follows below)
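
For reference, "16a8w" means 16-bit activations with 8-bit weights. Below is a minimal sketch of such a config and of annotating a matmul node with it, using the generic PT2E quantizer APIs; the ranges, observer choices, and the helper name `annotate_matmul_16a8w` are illustrative assumptions, not the exact code in `backends/qualcomm/quantizer`:

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationAnnotation, QuantizationSpec

# 16-bit activations: uint16 value range carried in an int32 container (illustrative choice).
act_16bit = QuantizationSpec(
    dtype=torch.int32,
    quant_min=0,
    quant_max=65535,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# 8-bit symmetric per-channel weights.
weight_8bit = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)

def annotate_matmul_16a8w(node: torch.fx.Node) -> None:
    # Hypothetical helper: both matmul inputs and the output stay 16-bit.
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={node.args[0]: act_16bit, node.args[1]: act_16bit},
        output_qspec=act_16bit,
        _annotated=True,
    )
```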

pytorch-bot bot commented Nov 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6849

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit 1732d06 with merge base 21eecff:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 14, 2024
@shewu-quic (Collaborator, Author)

Hi @cccclai,

Here are some optimizations to reproduce the performance numbers.
If you have any questions, please let me know.
We will create a PR to enable prefill on static llama ASAP.

Thanks a lot

cccclai previously approved these changes Nov 15, 2024

@cccclai (Contributor) left a comment

Thanks for the PR. Still working on reproducing the numbers...

@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai (Contributor) commented Nov 19, 2024

Here are my latency numbers. Without this PR, on main:

PyTorchObserver {"prompt_tokens":16,"generated_tokens":111,"model_load_start_ms":1731975750978,"model_load_end_ms":1731975751393,"inference_start_ms":1731975751393,"inference_end_ms":1731975753792,"prompt_eval_end_ms":1731975751694,"first_token_ms":1731975751715,"aggregate_sampling_time_ms":87,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:02.900777 executorch:runner.cpp:254] 	Prompt Tokens: 16    Generated Tokens: 111
I 00:00:02.900790 executorch:runner.cpp:260] 	Model Load Time:		0.415000 (seconds)
I 00:00:02.900817 executorch:runner.cpp:270] 	Total inference time:		2.399000 (seconds)		 Rate: 	46.269279 (tokens/second)
I 00:00:02.900830 executorch:runner.cpp:278] 		Prompt evaluation:	0.301000 (seconds)		 Rate: 	53.156146 (tokens/second)
I 00:00:02.900841 executorch:runner.cpp:289] 		Generated 111 tokens:	2.098000 (seconds)		 Rate: 	52.907531 (tokens/second)
I 00:00:02.900880 executorch:runner.cpp:297] 	Time to first generated token:	0.322000 (seconds)
I 00:00:02.900888 executorch:runner.cpp:304] 	Sampling time over 127 tokens:	0.087000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[WARNING] [Qnn ExecuTorch]:  <W> qnnOpPackageManager: hexagon unload op package function pointer is nullptr!

With this PR:

PyTorchObserver {"prompt_tokens":16,"generated_tokens":111,"model_load_start_ms":1731975765318,"model_load_end_ms":1731975765720,"inference_start_ms":1731975765720,"inference_end_ms":1731975767756,"prompt_eval_end_ms":1731975765979,"first_token_ms":1731975765995,"aggregate_sampling_time_ms":46,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:02.520441 executorch:runner.cpp:254] 	Prompt Tokens: 16    Generated Tokens: 111
I 00:00:02.520450 executorch:runner.cpp:260] 	Model Load Time:		0.402000 (seconds)
I 00:00:02.520458 executorch:runner.cpp:270] 	Total inference time:		2.036000 (seconds)		 Rate: 	54.518664 (tokens/second)
I 00:00:02.520480 executorch:runner.cpp:278] 		Prompt evaluation:	0.259000 (seconds)		 Rate: 	61.776062 (tokens/second)
I 00:00:02.520487 executorch:runner.cpp:289] 		Generated 111 tokens:	1.777000 (seconds)		 Rate: 	62.464828 (tokens/second)
I 00:00:02.520493 executorch:runner.cpp:297] 	Time to first generated token:	0.275000 (seconds)
I 00:00:02.520498 executorch:runner.cpp:304] 	Sampling time over 127 tokens:	0.046000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
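
So, from the two logs above, decode throughput on this device goes from roughly 52.9 to 62.5 tokens/second with this PR (about an 18% speedup), and time to first generated token drops from 0.322 s to 0.275 s.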

@shewu-quic (Collaborator, Author)

Did you test it on Gen 3?
I think ~62 tok/sec aligns with what we see on our side.
I will create another optimization this week that improves accuracy a bit, since the quantization annotations for the linear ops are wrong.

@cccclai (Contributor) commented Nov 19, 2024

I set the prompt and seq_len, just to be accurate, and the following are the numbers:

OP595DL1:/data/local/tmp/static_llama $./qnn_llama3_2_runner --model_path llama3_2_qnn_opt.pte --tokenizer_path tokenizer.model --seq_len 512 --temperature 0 --prompt "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nwhat is 1 + 1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"                                                                          <
I 00:00:00.000449 executorch:runner.cpp:59] creating module: model_path=llama3_2_qnn_main.pte
I 00:00:00.000497 executorch:runner.cpp:61] creating runner: tokenizer_path=tokenizer.model
I 00:00:00.070977 executorch:runner.cpp:80] creating io_memory
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 2
[WARNING] [Qnn ExecuTorch]:  <W> Initializing HtpProvider
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in RESTORE MODE.
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
PyTorchObserver {"prompt_tokens":28,"generated_tokens":483,"model_load_start_ms":1731976447760,"model_load_end_ms":1731976448177,"inference_start_ms":1731976448177,"inference_end_ms":1731976456396,"prompt_eval_end_ms":1731976448627,"first_token_ms":1731976448643,"aggregate_sampling_time_ms":228,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:08.709662 executorch:runner.cpp:254] 	Prompt Tokens: 28    Generated Tokens: 483
I 00:00:08.709674 executorch:runner.cpp:260] 	Model Load Time:		0.417000 (seconds)
I 00:00:08.709700 executorch:runner.cpp:270] 	Total inference time:		8.219000 (seconds)		 Rate: 	58.766273 (tokens/second)
I 00:00:08.709710 executorch:runner.cpp:278] 		Prompt evaluation:	0.450000 (seconds)		 Rate: 	62.222222 (tokens/second)
I 00:00:08.709718 executorch:runner.cpp:289] 		Generated 483 tokens:	7.769000 (seconds)		 Rate: 	62.170163 (tokens/second)
I 00:00:08.709727 executorch:runner.cpp:297] 	Time to first generated token:	0.466000 (seconds)
I 00:00:08.709733 executorch:runner.cpp:304] 	Sampling time over 511 tokens:	0.228000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[WARNING] [Qnn ExecuTorch]:  <W> qnnOpPackageManager: hexagon unload op package function pointer is nullptr!
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!
OP595DL1:/data/local/tmp/static_llama $ ./qnn_llama3_2_runner --model_path llama3_2_qnn_main.pte --tokenizer_path tokenizer.model --seq_len 512 --temperature 0 --prompt "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n>
I 00:00:00.001121 executorch:runner.cpp:59] creating module: model_path=llama3_2_qnn_opt.pte
I 00:00:00.001179 executorch:runner.cpp:61] creating runner: tokenizer_path=tokenizer.model
I 00:00:00.073712 executorch:runner.cpp:80] creating io_memory
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 2
[WARNING] [Qnn ExecuTorch]:  <W> Initializing HtpProvider
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in RESTORE MODE.
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
PyTorchObserver {"prompt_tokens":15,"generated_tokens":496,"model_load_start_ms":1731976623611,"model_load_end_ms":1731976624001,"inference_start_ms":1731976624001,"inference_end_ms":1731976633120,"prompt_eval_end_ms":1731976624260,"first_token_ms":1731976624278,"aggregate_sampling_time_ms":307,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:09.582835 executorch:runner.cpp:254] 	Prompt Tokens: 15    Generated Tokens: 496
I 00:00:09.582844 executorch:runner.cpp:260] 	Model Load Time:		0.390000 (seconds)
I 00:00:09.582852 executorch:runner.cpp:270] 	Total inference time:		9.119000 (seconds)		 Rate: 	54.391929 (tokens/second)
I 00:00:09.582871 executorch:runner.cpp:278] 		Prompt evaluation:	0.259000 (seconds)		 Rate: 	57.915058 (tokens/second)
I 00:00:09.582878 executorch:runner.cpp:289] 		Generated 496 tokens:	8.860000 (seconds)		 Rate: 	55.981941 (tokens/second)
I 00:00:09.582885 executorch:runner.cpp:297] 	Time to first generated token:	0.277000 (seconds)
I 00:00:09.582890 executorch:runner.cpp:304] 	Sampling time over 511 tokens:	0.307000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[WARNING] [Qnn ExecuTorch]:  <W> qnnOpPackageManager: hexagon unload op package function pointer is nullptr!
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

@cccclai (Contributor) commented Nov 19, 2024

> I think ~62 tok/sec aligns with what we see on our side.

Yeah, OnePlus 12 (SM8650, 16 GB RAM).

@cccclai (Contributor) commented Nov 19, 2024

> Did you test it on Gen 3? I think ~62 tok/sec aligns with what we see on our side. I will create another optimization this week that improves accuracy a bit, since the quantization annotations for the linear ops are wrong.

Oh, do you know which one is wrong? Naveen just sent an email regarding the accuracy gap between the fake-quantized model and the on-device model; could it be related?

@shewu-quic (Collaborator, Author) commented Nov 19, 2024

I am not sure whether it is related. I found that we annotate linear with MovingAverageMinMaxObserver but annotate the other ops in llama with MinMaxObserver. This makes the quant attributes of some ops unreasonable. For example, for a transpose op we should expect the same quant attributes on input and output, but we get different values due to the observer mismatch.
https://github.com/CodeLinaro/executorch/blob/1732d069f50203eeab12d522c6de8eb3ddfe7314/backends/qualcomm/quantizer/utils.py#L428
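
Because the transpose's input may come from a linear observed with MovingAverageMinMaxObserver while the transpose itself is observed with MinMaxObserver, the two sides can drift apart. A minimal sketch of one way to pin a transpose-like op to identical quant attributes on input and output, using the PT2E quantizer's SharedQuantizationSpec (illustrative only; `annotate_single_in_single_out_shared` is a hypothetical name, not the helper in utils.py):

```python
import torch
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)

def annotate_single_in_single_out_shared(
    node: torch.fx.Node, input_qspec: QuantizationSpec
) -> None:
    # Observe the input with whatever spec the config provides...
    input_act = node.args[0]
    # ...and tie the output to the (input, node) edge, so both sides of the
    # transpose end up with identical scale and zero-point regardless of
    # which observer produced them.
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input_act: input_qspec},
        output_qspec=SharedQuantizationSpec((input_act, node)),
        _annotated=True,
    )
```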

@shewu-quic shewu-quic closed this Nov 19, 2024
@shewu-quic shewu-quic reopened this Nov 19, 2024
@pytorch-bot pytorch-bot bot dismissed cccclai’s stale review November 19, 2024 01:44

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@shewu-quic (Collaborator, Author)

Hi @cccclai,
Sorry about that, I accidentally closed the PR :(

@cccclai (Contributor) left a comment

Looks good to me. Thanks!

@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Comment on lines +93 to +95
elif node.target == torch.ops.aten.cat.default:
    # Keep the k/v-cache concat in the 8-bit domain.
    annotate_cat(node, quantization_config_8a8w)
    # Continue walking up from the first tensor in the cat's input list.
    node = node.args[0][0]
Contributor

What pattern is this trying to capture?

Collaborator (Author)

The following pattern:

                        q (16 bits) --------------------\
                                                          matmul op (16 bits)
past k / v (8 bits) ---\                                 /
                        cat op (8 bits) ----------------/
new k / v (8 bits) ----/
(transpose after k)
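
In other words, the k/v path stays 8-bit up to and including the cat, and only q and the matmul itself are 16-bit. A rough sketch of the walk the snippet above performs, starting from the matmul's k/v input (the loop shape and the transpose handling are illustrative assumptions; `annotate_cat` and `quantization_config_8a8w` are the names from the snippet, passed in here so the sketch is self-contained):

```python
import torch

def annotate_kv_cache_path(node, annotate_cat, quantization_config_8a8w):
    """Walk producer-by-producer from the matmul's k/v input (sketch)."""
    while isinstance(node, torch.fx.Node):
        if node.target == torch.ops.aten.transpose.int:
            # Skip the transpose between the cat and the matmul on the k path.
            node = node.args[0]
        elif node.target == torch.ops.aten.cat.default:
            # The concat of past and new k/v stays in the 8-bit domain.
            annotate_cat(node, quantization_config_8a8w)
            # Continue from the first tensor in the cat list (the past cache).
            node = node.args[0][0]
        else:
            break
```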

@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai cccclai added the release notes: backends [DO NOT USE] Changes to any of the backend delegates label Nov 20, 2024
@facebook-github-bot facebook-github-bot merged commit 709e739 into pytorch:main Nov 20, 2024
75 of 78 checks passed
cccclai added a commit to cccclai/executorch-1 that referenced this pull request Nov 20, 2024
Summary: Looks like it's added in  pytorch#6849, maybe it was using the old api for the default 8bit quantization

Differential Revision: D66219251
@cccclai cccclai mentioned this pull request Nov 20, 2024
cccclai added a commit that referenced this pull request Nov 20, 2024
Summary: Looks like it's added in  #6849, maybe it was using the old api for the default 8bit quantization

Differential Revision: D66219251