Skip to content

Commit 2025fa6

Browse files
authored
kompute : improve backend to pass test_backend_ops (ggml-org#10542)
* kompute: op_unary: reject unsupported parameters Signed-off-by: Sergio Lopez <[email protected]> * kompute: softmax: implement ALiBi support Signed-off-by: Sergio Lopez <[email protected]> * kompute: rope: implement neox and phi3 support Signed-off-by: Sergio Lopez <[email protected]> * kompute: op_mul_mat_q4_k permutted support Signed-off-by: Sergio Lopez <[email protected]> * kompute: op_mul_mat_[q4_0|q4_1|q8_0] permutted support Signed-off-by: Sergio Lopez <[email protected]> * kompute: op_mul_mat_f16 permutted support Signed-off-by: Sergio Lopez <[email protected]> * kompute: op_mul_mat_q6_k permutted support Signed-off-by: Sergio Lopez <[email protected]> --------- Signed-off-by: Sergio Lopez <[email protected]>
1 parent c6bc739 commit 2025fa6

16 files changed

+403
-233
lines changed

ggml/src/ggml-kompute/CMakeLists.txt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
105105
kompute-shaders/op_getrows_q4_0.comp
106106
kompute-shaders/op_getrows_q4_1.comp
107107
kompute-shaders/op_getrows_q6_k.comp
108-
kompute-shaders/op_rope_f16.comp
109-
kompute-shaders/op_rope_f32.comp
108+
kompute-shaders/op_rope_norm_f16.comp
109+
kompute-shaders/op_rope_norm_f32.comp
110+
kompute-shaders/op_rope_neox_f16.comp
111+
kompute-shaders/op_rope_neox_f32.comp
110112
kompute-shaders/op_cpy_f16_f16.comp
111113
kompute-shaders/op_cpy_f16_f32.comp
112114
kompute-shaders/op_cpy_f32_f16.comp
@@ -139,8 +141,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
139141
shaderop_getrows_q4_0.h
140142
shaderop_getrows_q4_1.h
141143
shaderop_getrows_q6_k.h
142-
shaderop_rope_f16.h
143-
shaderop_rope_f32.h
144+
shaderop_rope_norm_f16.h
145+
shaderop_rope_norm_f32.h
146+
shaderop_rope_neox_f16.h
147+
shaderop_rope_neox_f32.h
144148
shaderop_cpy_f16_f16.h
145149
shaderop_cpy_f16_f32.h
146150
shaderop_cpy_f32_f16.h

ggml/src/ggml-kompute/ggml-kompute.cpp

Lines changed: 115 additions & 61 deletions
Large diffs are not rendered by default.

ggml/src/ggml-kompute/kompute-shaders/common.comp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
44
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
55
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
6+
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
67
#extension GL_EXT_control_flow_attributes: enable
78
#extension GL_KHR_shader_subgroup_arithmetic : require
89
#extension GL_EXT_debug_printf : enable

ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ layout (push_constant) uniform parameter {
2020
uint nb00;
2121
uint nb01;
2222
uint nb02;
23+
uint nb03;
2324
int ne10;
2425
int ne11;
2526
int ne12;
2627
uint nb10;
2728
uint nb11;
2829
uint nb12;
30+
uint nb13;
2931
int ne0;
3032
int ne1;
3133
uint r2;
@@ -42,7 +44,7 @@ void main() {
4244
const uint i12 = im%pcs.ne12;
4345
const uint i13 = im/pcs.ne12;
4446

45-
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
47+
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
4648

4749
const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
4850

@@ -52,7 +54,7 @@ void main() {
5254
break;
5355
}
5456

55-
const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
57+
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
5658

5759
float sumf = 0;
5860
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {

ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,14 @@ layout (push_constant) uniform parameter {
2424
int ne01;
2525
int ne02;
2626
int ne12;
27-
int r2;
28-
int r3;
27+
uint nb01;
28+
uint nb02;
29+
uint nb03;
30+
uint nb11;
31+
uint nb12;
32+
uint nb13;
33+
uint r2;
34+
uint r3;
2935
} pcs;
3036

3137
void main() {
@@ -50,10 +56,11 @@ void main() {
5056
const uint i12 = im%pcs.ne12;
5157
const uint i13 = im/pcs.ne12;
5258

53-
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
59+
const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
60+
const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13;
5461

55-
const uint xblk = ib_row + offset0 + pcs.inAOff;
56-
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
62+
const uint xblk = offset0 + pcs.inAOff;
63+
const uint y = (offset1 / 4) + pcs.inBOff;
5764

5865
float yl[16];
5966
float yh[16];
@@ -74,7 +81,7 @@ void main() {
7481
}
7582

7683
for (int row = 0; row < N_DST; row++) {
77-
uint row_idx = row * nb;
84+
uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
7885

7986
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
8087
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);

ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@ layout (push_constant) uniform parameter {
2121
int ne0;
2222
int ne1;
2323
int ne01;
24-
int gqa;
24+
int ne02;
25+
int ne12;
26+
uint nb01;
27+
uint nb02;
28+
uint nb03;
29+
uint nb11;
30+
uint nb12;
31+
uint nb13;
32+
uint r2;
33+
uint r3;
2534
} pcs;
2635

2736
void main() {
@@ -34,12 +43,15 @@ void main() {
3443

3544
const uint r0 = gl_WorkGroupID.x;
3645
const uint r1 = gl_WorkGroupID.y;
37-
const uint r2 = gl_WorkGroupID.z;
46+
const uint im = gl_WorkGroupID.z;
3847

3948
const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
40-
const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
41-
const uint x = row * nb + offset0; // Based from inA without base offset
42-
const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
49+
50+
const uint i12 = im%pcs.ne12;
51+
const uint i13 = im/pcs.ne12;
52+
53+
const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
54+
const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
4355

4456
float sumf = 0;
4557

@@ -89,6 +101,6 @@ void main() {
89101

90102
const float tot = subgroupAdd(sumf);
91103
if (subgroupElect()) {
92-
out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
104+
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
93105
}
94106
}

ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@ void main() {
1414
const uint i12 = im%pcs.ne12;
1515
const uint i13 = im/pcs.ne12;
1616

17-
const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
17+
// pointers to src0 rows
18+
uint ax[N_ROWS];
19+
for (int row = 0; row < N_ROWS; ++row) {
20+
const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
21+
22+
ax[row] = offset0 + pcs.inAOff;
23+
}
1824

19-
const uint x = offset0; // Based from inA without base offset
20-
const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
25+
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
2126

2227
float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
2328

@@ -32,8 +37,7 @@ void main() {
3237

3338
for (uint ib = ix; ib < nb; ib += 16) {
3439
for (int row = 0; row < N_ROWS; row++) {
35-
const uint block_index = x + ib + row * nb;
36-
sumf[row] += block_q_n_dot_y(block_index, yb, il);
40+
sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
3741
}
3842

3943
yb += BLOCKS_IN_QUANT * 16;

ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
layout(local_size_x_id = 0) in;
2-
layout(local_size_y = 1) in;
2+
layout(local_size_y = 8) in;
33
layout(local_size_z = 1) in;
44

55
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
@@ -17,6 +17,12 @@ layout (push_constant) uniform parameter {
1717
int ne12;
1818
int ne0;
1919
int ne1;
20+
uint nb01;
21+
uint nb02;
22+
uint nb03;
23+
uint nb11;
24+
uint nb12;
25+
uint nb13;
2026
uint r2;
2127
uint r3;
2228
} pcs;

ggml/src/ggml-kompute/kompute-shaders/op_rope_f16.comp

Lines changed: 0 additions & 73 deletions
This file was deleted.

ggml/src/ggml-kompute/kompute-shaders/op_rope_f32.comp

Lines changed: 0 additions & 73 deletions
This file was deleted.

0 commit comments

Comments
 (0)