Skip to content

kompute: improve backend to pass test_backend_ops #10542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits on Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions ggml/src/ggml-kompute/CMakeLists.txt
Original file line number | Diff line number | Diff line change
Expand Up @@ -105,8 +105,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
kompute-shaders/op_getrows_q4_0.comp
kompute-shaders/op_getrows_q4_1.comp
kompute-shaders/op_getrows_q6_k.comp
kompute-shaders/op_rope_f16.comp
kompute-shaders/op_rope_f32.comp
kompute-shaders/op_rope_norm_f16.comp
kompute-shaders/op_rope_norm_f32.comp
kompute-shaders/op_rope_neox_f16.comp
kompute-shaders/op_rope_neox_f32.comp
kompute-shaders/op_cpy_f16_f16.comp
kompute-shaders/op_cpy_f16_f32.comp
kompute-shaders/op_cpy_f32_f16.comp
Expand Down Expand Up @@ -139,8 +141,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
shaderop_getrows_q4_0.h
shaderop_getrows_q4_1.h
shaderop_getrows_q6_k.h
shaderop_rope_f16.h
shaderop_rope_f32.h
shaderop_rope_norm_f16.h
shaderop_rope_norm_f32.h
shaderop_rope_neox_f16.h
shaderop_rope_neox_f32.h
shaderop_cpy_f16_f16.h
shaderop_cpy_f16_f32.h
shaderop_cpy_f32_f16.h
Expand Down
176 changes: 115 additions & 61 deletions ggml/src/ggml-kompute/ggml-kompute.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ggml/src/ggml-kompute/kompute-shaders/common.comp
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
#extension GL_EXT_control_flow_attributes: enable
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable
Expand Down
6 changes: 4 additions & 2 deletions ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,12 +20,14 @@ layout (push_constant) uniform parameter {
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne10;
int ne11;
int ne12;
uint nb10;
uint nb11;
uint nb12;
uint nb13;
int ne0;
int ne1;
uint r2;
Expand All @@ -42,7 +44,7 @@ void main() {
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;

const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;

const uint x = offset0 / 2 + pcs.inAOff; // Based from inA

Expand All @@ -52,7 +54,7 @@ void main() {
break;
}

const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;

float sumf = 0;
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
Expand Down
19 changes: 13 additions & 6 deletions ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,8 +24,14 @@ layout (push_constant) uniform parameter {
int ne01;
int ne02;
int ne12;
int r2;
int r3;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;

void main() {
Expand All @@ -50,10 +56,11 @@ void main() {
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;

const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13;

const uint xblk = ib_row + offset0 + pcs.inAOff;
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
const uint xblk = offset0 + pcs.inAOff;
const uint y = (offset1 / 4) + pcs.inBOff;

float yl[16];
float yh[16];
Expand All @@ -74,7 +81,7 @@ void main() {
}

for (int row = 0; row < N_DST; row++) {
uint row_idx = row * nb;
uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);

uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
Expand Down
24 changes: 18 additions & 6 deletions ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,7 +21,16 @@ layout (push_constant) uniform parameter {
int ne0;
int ne1;
int ne01;
int gqa;
int ne02;
int ne12;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;

void main() {
Expand All @@ -34,12 +43,15 @@ void main() {

const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint r2 = gl_WorkGroupID.z;
const uint im = gl_WorkGroupID.z;

const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
const uint x = row * nb + offset0; // Based from inA without base offset
const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB

const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;

const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;

float sumf = 0;

Expand Down Expand Up @@ -89,6 +101,6 @@ void main() {

const float tot = subgroupAdd(sumf);
if (subgroupElect()) {
out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
}
}
14 changes: 9 additions & 5 deletions ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,10 +14,15 @@ void main() {
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;

const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
// pointers to src0 rows
uint ax[N_ROWS];
for (int row = 0; row < N_ROWS; ++row) {
const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);

ax[row] = offset0 + pcs.inAOff;
}

const uint x = offset0; // Based from inA without base offset
const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;

float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};

Expand All @@ -32,8 +37,7 @@ void main() {

for (uint ib = ix; ib < nb; ib += 16) {
for (int row = 0; row < N_ROWS; row++) {
const uint block_index = x + ib + row * nb;
sumf[row] += block_q_n_dot_y(block_index, yb, il);
sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
}

yb += BLOCKS_IN_QUANT * 16;
Expand Down
8 changes: 7 additions & 1 deletion ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
layout(local_size_x_id = 0) in;
layout(local_size_y = 1) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;

layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
Expand All @@ -17,6 +17,12 @@ layout (push_constant) uniform parameter {
int ne12;
int ne0;
int ne1;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
73 changes: 0 additions & 73 deletions ggml/src/ggml-kompute/kompute-shaders/op_rope_f16.comp

This file was deleted.

73 changes: 0 additions & 73 deletions ggml/src/ggml-kompute/kompute-shaders/op_rope_f32.comp

This file was deleted.

Loading
Loading