
sycl: add usage of enqueue_functions extension #14244


Merged

merged 12 commits into ggml-org:master on Jun 20, 2025

Conversation

@s-Nick (Collaborator) commented Jun 17, 2025

This PR enables the use of the sycl_ext_oneapi_enqueue_functions extension. The goal is to submit kernels to the queue without keeping track of the resulting events, since the SYCL backend does not rely on them: it uses an in-order queue.
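For illustration, a minimal sketch of the before/after submission shape, assuming a recent DPC++ compiler that ships the extension; the buffer name and kernel bodies are made up and this is not the backend code itself:

```cpp
#include <sycl/sycl.hpp>  // recent DPC++ exposes the extension through this header

namespace syclex = sycl::ext::oneapi::experimental;

int main() {
    // The SYCL backend uses an in-order queue, so kernels already run in
    // submission order and the returned events carry no information it needs.
    sycl::queue q{sycl::property::queue::in_order{}};

    constexpr size_t n = 1024;
    float *x = sycl::malloc_device<float>(n, q);

    // Before: handler-based submit() always constructs and returns a
    // sycl::event, which the backend simply discarded.
    sycl::event e = q.submit([&](sycl::handler &cgh) {
        cgh.parallel_for(sycl::range<1>{n},
                         [=](sycl::id<1> i) { x[i] = float(i); });
    });
    (void)e;  // never waited on

    // After: the extension's free function submits the same work without
    // creating an event at all.
    syclex::submit(q, [&](sycl::handler &cgh) {
        cgh.parallel_for(sycl::range<1>{n},
                         [=](sycl::id<1> i) { x[i] *= 2.0f; });
    });

    q.wait();
    sycl::free(x, q);
}
```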

This patch provides a good performance improvement on small models and does not negatively impact performance on larger models.
All tests from test-backend-ops pass.

Battlemage B580 results on Linux with icpx 2025.1

| model | size | params | backend | ngl | sm | test | master t/s | PR t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | none | pp512 | 7471.23 ± 62.35 | 7560.18 ± 44.79 |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | none | tg128 | 135.97 ± 2.59 | 160.35 ± 1.18 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | none | pp512 | 7531.73 ± 27.27 | 7657.91 ± 55.43 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | none | tg128 | 119.68 ± 0.75 | 140.02 ± 0.57 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | none | pp512 | 2167.02 ± 0.92 | 2179.78 ± 4.04 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | none | tg128 | 65.26 ± 0.79 | 74.26 ± 0.15 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | none | pp512 | 2199.95 ± 5.83 | 2222.65 ± 3.65 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | none | tg128 | 55.31 ± 0.41 | 60.64 ± 0.27 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | none | pp512 | 5715.46 ± 23.04 | 5779.09 ± 19.23 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | none | tg128 | 93.04 ± 2.14 | 106.92 ± 0.26 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | none | pp512 | 3048.85 ± 8.22 | 3070.99 ± 10.90 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | none | tg128 | 92.00 ± 0.99 | 107.82 ± 0.20 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | none | pp512 | 3159.13 ± 11.88 | 3174.88 ± 8.29 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | none | tg128 | 70.79 ± 0.56 | 79.58 ± 0.15 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | none | pp512 | 1493.59 ± 1.30 | 1494.54 ± 2.45 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | none | tg128 | 23.40 ± 0.10 | 24.14 ± 0.04 |

build: bb157ae (5695)

Lunar Lake results on Linux with icpx 2025.1

| model | size | params | backend | ngl | threads | sm | test | master t/s | PR t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | 8 | none | pp512 | 1493.16 ± 56.87 | 1506.92 ± 8.22 |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | 8 | none | tg128 | 54.61 ± 0.42 | 58.78 ± 0.50 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | 8 | none | pp512 | 1503.95 ± 9.41 | 1543.20 ± 7.70 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | 8 | none | tg128 | 47.64 ± 0.24 | 49.52 ± 0.20 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | 8 | none | pp512 | 432.48 ± 6.04 | 379.39 ± 0.78 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | 8 | none | tg128 | 21.22 ± 0.09 | 22.09 ± 0.06 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | 8 | none | pp512 | 226.01 ± 22.67 | 512.31 ± 1.32 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | 8 | none | tg128 | 18.55 ± 0.04 | 19.35 ± 0.01 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | 8 | none | pp512 | 688.10 ± 80.92 | 680.70 ± 3.77 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | 8 | none | tg128 | 29.94 ± 0.49 | 33.59 ± 0.28 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | 8 | none | pp512 | 454.78 ± 10.68 | 807.77 ± 2.62 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | 8 | none | tg128 | 33.34 ± 0.22 | 34.69 ± 0.06 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | 8 | none | pp512 | 657.45 ± 17.82 | 567.05 ± 34.67 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | 8 | none | tg128 | 24.96 ± 0.07 | 28.04 ± 0.11 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | 8 | none | pp512 | 147.02 ± 15.35 | 135.84 ± 9.97 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | 8 | none | tg128 | 6.78 ± 0.00 | 8.13 ± 0.03 |

build: bb157ae (5695)

Lunar Lake results on Windows with icpx 2025.1

| model | size | params | backend | ngl | sm | test | master t/s | PR t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | none | pp512 | 1900.55 ± 19.40 | 1966.16 ± 38.87 |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | none | tg128 | 57.43 ± 0.22 | 58.26 ± 1.32 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | none | pp512 | 1992.80 ± 10.43 | 2162.81 ± 13.62 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | none | tg128 | 44.36 ± 0.13 | 45.02 ± 0.16 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | none | pp512 | 464.20 ± 0.76 | 534.68 ± 2.49 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | none | tg128 | 22.03 ± 0.37 | 23.03 ± 0.05 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | none | pp512 | 472.15 ± 2.47 | 570.86 ± 0.82 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | none | tg128 | 16.55 ± 0.10 | 17.34 ± 0.23 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | none | pp512 | 1187.75 ± 31.26 | 1555.91 ± 11.14 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | none | tg128 | 29.46 ± 0.17 | 30.23 ± 0.22 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | none | pp512 | 592.88 ± 1.62 | 827.23 ± 1.67 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | none | tg128 | 34.56 ± 1.66 | 36.39 ± 0.15 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | none | pp512 | 767.13 ± 9.03 | 808.58 ± 3.09 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | none | tg128 | 21.78 ± 0.24 | 24.85 ± 0.17 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | none | pp512 | 294.83 ± 0.91 | 319.27 ± 2.97 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | none | tg128 | 7.24 ± 0.05 | 7.83 ± 0.05 |

build: bb157ae (5695)

A770 results on Linux with icpx 2025.1

| model | size | params | backend | ngl | sm | test | master t/s | PR t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | none | pp512 | 4445.46 ± 17.36 | 4501.42 ± 10.74 |
| qwen2 1.5B Q4_0 | 1013.62 MiB | 1.78 B | SYCL | 99 | none | tg128 | 44.73 ± 0.04 | 48.49 ± 0.22 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | none | pp512 | 4468.23 ± 10.92 | 4528.13 ± 10.29 |
| qwen2 1.5B Q4_K - Medium | 1.04 GiB | 1.78 B | SYCL | 99 | none | tg128 | 43.54 ± 0.14 | 46.82 ± 0.04 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | none | pp512 | 1713.05 ± 0.88 | 1723.20 ± 1.34 |
| llama 7B Q4_0 | 3.57 GiB | 6.74 B | SYCL | 99 | none | tg128 | 33.79 ± 0.05 | 35.34 ± 0.24 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | none | pp512 | 1731.70 ± 1.59 | 1739.91 ± 0.56 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | SYCL | 99 | none | tg128 | 32.09 ± 0.26 | 33.89 ± 0.25 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | none | pp512 | 3641.67 ± 2.08 | 3676.88 ± 1.95 |
| gemma2 2B Q4_K - Medium | 1.59 GiB | 2.61 B | SYCL | 99 | none | tg128 | 38.28 ± 0.33 | 40.13 ± 0.22 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | none | pp512 | 2460.02 ± 1.36 | 2474.68 ± 3.21 |
| phi3 3B Q4_0 | 2.03 GiB | 3.82 B | SYCL | 99 | none | tg128 | 39.31 ± 0.31 | 41.40 ± 0.22 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | none | pp512 | 2510.45 ± 3.18 | 2526.88 ± 1.64 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | SYCL | 99 | none | tg128 | 34.53 ± 0.25 | 35.84 ± 0.23 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | none | pp512 | 1029.68 ± 0.68 | 1034.40 ± 0.36 |
| llama 34B Q6_K | 8.20 GiB | 10.73 B | SYCL | 99 | none | tg128 | 17.15 ± 0.12 | 17.70 ± 0.11 |

build: bb157ae (5695)

github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) and SYCL (https://en.wikipedia.org/wiki/SYCL - GPU programming language) labels on Jun 17, 2025
@Rbiessy (Collaborator) left a comment:

This is going to conflict with the PRs #14158 and #14181. @CISC do you have a rough estimation of when your PRs could be merged? We can wait to merge this one or help with the rebase if needed.

@CISC (Collaborator) commented Jun 17, 2025

> This is going to conflict with the PRs #14158 and #14181. @CISC do you have a rough estimation of when your PRs could be merged? We can wait to merge this one or help with the rebase if needed.

They are ready once @ggerganov updates the Metal implementation.

Edit: ...or not, looks like this PR will change course!

```cpp
#endif
}

template <int NR = 3, typename L>
```
A Contributor commented:
Suggested change:

```diff
-template <int NR = 3, typename L>
+template <int NR, typename L>
```

You needn't pass the default value here.
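For context, a default template argument may be given for a template parameter only once in a scope, so repeating it on the definition is ill-formed. A standalone sketch with hypothetical names (not the backend code):

```cpp
#include <iostream>

// The declaration carries the default template argument.
template <int NR = 3, typename L> void launch_rows(L loader);

// The definition must not repeat "= 3": specifying the default again in
// the same scope is a redefinition and is rejected by compilers.
template <int NR, typename L> void launch_rows(L loader) {
    for (int r = 0; r < NR; ++r) {
        loader(r);  // process NR rows per call
    }
}

int main() {
    launch_rows([](int r) { std::cout << r << '\n'; });     // NR == 3 (default)
    launch_rows<5>([](int r) { std::cout << r << '\n'; });  // NR == 5
}
```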

Contributor: And similarly elsewhere.

s-Nick (Collaborator, Author) replied:

Addressed in 49663d5

@Rbiessy (Collaborator) commented Jun 18, 2025

LGTM, but not approving yet since it should be rebased on master to take #14181 into account.

@CISC (Collaborator) commented Jun 18, 2025

> LGTM, but not approving yet since it should be rebased on master to take #14181 into account.

It was merged into #14158 (which has since changed goals and will take a while), so do not hold this one up on its account. Any help rebasing that one after this PR is merged will be much appreciated, though! :)

@qnixsynapse (Collaborator) commented:
I will rebase SYCL code after this gets merged. Feel free to approve.

@s-Nick (Collaborator, Author) commented Jun 20, 2025

@CISC Thank you for the update.
@qnixsynapse I think it should be easy enough for you to rebase on this PR, but if you need help, ping me. I am happy to help rebase the SYCL backend.

s-Nick merged commit 8308f98 into ggml-org:master on Jun 20, 2025
47 checks passed
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Jun 20, 2025
* mamba2-sync: (24 commits)
sync : ggml
Add `ggml_roll` (ggml/1274)
docs : fix the link to llama.h (ggml-org#14293)
CUDA: add conv_2d_transpose (ggml-org#14287)
lint : remove trailing whitepace (ggml-org#14304)
vocab : prevent tokenizer overflow (ggml-org#14301)
sycl: add usage of enqueue_functions extension (ggml-org#14244)
Implement GGML_CPU_ALL_VARIANTS for PowerPC (ggml-org#14286)
llama : improve sep token handling (ggml-org#14272)
cuda : synchronize graph capture and cublas handle destruction (ggml-org#14288)
ggml : fix repack work size for mul_mat_id (ggml-org#14292)
ggml: Update KleidiAI to v1.9.0 (ggml-org#14277)
model : more uniform output id handling (ggml-org#14275)
ubatch : new splitting logic (ggml-org#14217)
CUDA: add conv_2d_dw (ggml-org#14265)
ggml-cpu : remove unnecesary arm feature detection (ggml-org#14281)
gguf-py : make sentencepiece optional (ggml-org#14200)
server : add server parameters for draft model cache type (ggml-org#13782)
build : suppress gcc15 compile warnings (ggml-org#14261)
sycl: Cleanup codepaths in Get Rows in sycl backend (ggml-org#14215)
...