
Commit 2a24994

mtl : initial mul_mat Q4 kernel (wrong results)

1 parent 64afc0b · commit 2a24994

File tree

examples/mtl/mtl.m
examples/mtl/mtl.metal
llama.cpp

3 files changed: +144 −19 lines

examples/mtl/mtl.m

Lines changed: 47 additions & 1 deletion

@@ -38,6 +38,9 @@
 
     id<MTLFunction>             function_rms_norm;
     id<MTLComputePipelineState> pipeline_rms_norm;
+
+    id<MTLFunction>             function_mul_mat_q4_0;
+    id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
 };
 
 // MSL code
@@ -141,6 +144,10 @@
     ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
     ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
     fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
+
+    ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
+    ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
+    fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
 }
 
 // MTLBuffer approach
@@ -317,7 +324,9 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_MUL_MAT:
-                {
+                if (gf->nodes[i]->src0->type == GGML_TYPE_F32) {
+                    // for F32 x F32 we use MPS
+
                     if (encoder != nil) {
                         [encoder endEncoding];
                         encoder = nil;
@@ -354,6 +363,43 @@ int llama_mtl_eval(
                         transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
 
                     [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                } else {
+                    // for Q4 x F32 we use custom kernel
+
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    GGML_ASSERT(gf->nodes[i]->src0->ne[2] == 1);
+                    GGML_ASSERT(gf->nodes[i]->src1->ne[2] == 1);
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
+                    const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
+
+                    const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
+                    const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
+
+                    const int64_t ncols = gf->nodes[i]->ne[0];
+                    const int64_t nrows = gf->nodes[i]->ne[1];
+
+                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ncols0 length:sizeof(ncols0) atIndex:3];
+                    [encoder setBytes:&nrows0 length:sizeof(nrows0) atIndex:4];
+                    [encoder setBytes:&ncols1 length:sizeof(ncols1) atIndex:5];
+                    [encoder setBytes:&nrows1 length:sizeof(nrows1) atIndex:6];
+                    [encoder setBytes:&ncols  length:sizeof(ncols)  atIndex:7];
+                    [encoder setBytes:&nrows  length:sizeof(nrows)  atIndex:8];
+
+                    printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ncols0, nrows0, ncols1, nrows1, ncols, nrows);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows0, nrows1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                 } break;
             case GGML_OP_GET_ROWS:
                 {
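For orientation, here is a plain-C sketch of the computation this dispatch requests: the grid is nrows0 x nrows1 threadgroups of 32 threads, and threadgroup (r0, r1) produces the single output element dst[r1*ne0 + r0], the dot product of quantized row r0 of src0 with float row r1 of src1. This is an illustrative reference, not code from the commit; the block layout mirrors block_q4_0 from examples/mtl/mtl.metal below, and fp16_to_fp32 / mul_mat_q4_0_ref are hypothetical names.

// CPU reference sketch (illustrative) of the Q4_0 x F32 mul_mat dispatched above
#include <stdint.h>
#include <string.h>

#define QK4_0 32

typedef struct {
    uint16_t d;              // delta as raw fp16 bits (half in the MSL struct)
    uint8_t  qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;

// minimal fp16 -> fp32 decode; subnormals flushed to zero, no Inf/NaN handling
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h >> 15) << 31;
    const uint32_t exp  = (h >> 10) & 0x1F;
    const uint32_t mant =  h        & 0x3FF;
    const uint32_t bits = exp == 0 ? sign : sign | ((exp + 112) << 23) | (mant << 13);
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

// dst[r1*nrows0 + r0] = dot(q4_0 row r0 of src0, f32 row r1 of src1)
static void mul_mat_q4_0_ref(const block_q4_0 * src0, const float * src1, float * dst,
                             int64_t ncols0, int64_t nrows0, int64_t nrows1) {
    const int64_t nb = ncols0 / QK4_0; // blocks per quantized row

    for (int64_t r1 = 0; r1 < nrows1; r1++) {     // grid y in the dispatch above
        for (int64_t r0 = 0; r0 < nrows0; r0++) { // grid x in the dispatch above
            const block_q4_0 * x = src0 + r0*nb;
            const float      * y = src1 + r1*ncols0;

            float acc = 0.0f;
            for (int64_t i = 0; i < nb; i++) {
                const float d = fp16_to_fp32(x[i].d);
                for (int j = 0; j < QK4_0/2; j++) {
                    // low nibbles cover the first half of the block,
                    // high nibbles the second half, both offset by -8
                    const int lo = (x[i].qs[j] & 0x0F) - 8;
                    const int hi = (x[i].qs[j] >>   4) - 8;
                    acc += d*(lo*y[i*QK4_0 + j] + hi*y[i*QK4_0 + QK4_0/2 + j]);
                }
            }
            dst[r1*nrows0 + r0] = acc;
        }
    }
}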

examples/mtl/mtl.metal

Lines changed: 89 additions & 13 deletions

@@ -7,8 +7,8 @@ using namespace metal;
 #define QK4_0 32
 #define QR4_0 2
 typedef struct {
-    half d;                 // delta
-    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+    half    d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
 
 static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@@ -38,8 +38,8 @@ kernel void kernel_add(
         device const float * src0,
         device const float * src1,
         device       float * dst,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = src0[gid] + src1[gid];
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig];
 }
 
 // assumption: src1 is a row
@@ -49,15 +49,15 @@ kernel void kernel_mul(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = src0[gid] * src1[gid % ne00];
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % ne00];
 }
 
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = max(0.0f, src0[gid]);
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
 }
 
 // TODO: broken
@@ -85,8 +85,8 @@ kernel void kernel_get_rows_q4_0(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb1,
-        uint gid[[thread_position_in_grid]]) {
-    const int i = gid;
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
     dequantize_row_q4_0(
@@ -100,8 +100,8 @@ kernel void kernel_rms_norm(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant     float & eps,
-        uint gid[[thread_position_in_grid]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
+        uint tpig[[thread_position_in_grid]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
 
     float sum = 0.0f;
     for (int i00 = 0; i00 < ne00; i00++) {
@@ -111,8 +111,84 @@ kernel void kernel_rms_norm(
     const float mean = sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + gid*ne00;
+    device float * y = dst + tpig*ne00;
     for (int i00 = 0; i00 < ne00; i00++) {
         y[i00] = x[i00] * scale;
     }
 }
+
+kernel void kernel_mul_mat_q4_0(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    const int qk = QK4_0;
+    const int nb = ne00/qk;
+
+    device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb;
+    device const float      * y = (device const float      *) (src1) + r1*ne10;
+
+    threadgroup float sum[32]; // TODO: should be equal to threadgroup size
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = 0; i < nb; i += tptg.x) {
+        device const uint4  * x0p = (device const uint4  *) (x + i);
+        device const float4 * y0p = (device const float4 *) (y + i*qk);
+
+        const uint4 x0 = *x0p;
+
+        const uint4 x0l = x0 & uint4(0x0F0F0F0F);
+        const uint4 x0h = x0 >> 4;
+
+        const int4 x0ls = as_type<int4>(x0l) - int4(8);
+        const int4 x0hs = as_type<int4>(x0h) - int4(8);
+
+        thread const uchar * x0lsb = (thread const uchar *) &x0ls;
+        thread const uchar * x0hsb = (thread const uchar *) &x0hs;
+
+        const float4 y00 = *(y0p + 0);
+        const float4 y01 = *(y0p + 1);
+        const float4 y02 = *(y0p + 2);
+        const float4 y03 = *(y0p + 3);
+        const float4 y04 = *(y0p + 4);
+        const float4 y05 = *(y0p + 5);
+        const float4 y06 = *(y0p + 6);
+        const float4 y07 = *(y0p + 7);
+
+        const float d = (x + i)->d;
+
+        sum[tpitg.x] += (
+                x0lsb[ 0]*y00[0] + x0lsb[ 1]*y00[1] + x0lsb[ 2]*y00[2] + x0lsb[ 3]*y00[3] +
+                x0lsb[ 4]*y01[0] + x0lsb[ 5]*y01[1] + x0lsb[ 6]*y01[2] + x0lsb[ 7]*y01[3] +
+                x0lsb[ 8]*y02[0] + x0lsb[ 9]*y02[1] + x0lsb[10]*y02[2] + x0lsb[11]*y02[3] +
+                x0lsb[12]*y03[0] + x0lsb[13]*y03[1] + x0lsb[14]*y03[2] + x0lsb[15]*y03[3] +
+                x0hsb[ 0]*y04[0] + x0hsb[ 1]*y04[1] + x0hsb[ 2]*y04[2] + x0hsb[ 3]*y04[3] +
+                x0hsb[ 4]*y05[0] + x0hsb[ 5]*y05[1] + x0hsb[ 6]*y05[2] + x0hsb[ 7]*y05[3] +
+                x0hsb[ 8]*y06[0] + x0hsb[ 9]*y06[1] + x0hsb[10]*y06[2] + x0hsb[11]*y06[3] +
+                x0hsb[12]*y07[0] + x0hsb[13]*y07[1] + x0hsb[14]*y07[2] + x0hsb[15]*y07[3]
+                ) * d;
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    dst[r1*ne0 + r0] = sum[0];
+}
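Each thread accumulates a partial dot product into the threadgroup array, and the final loop folds the 32 partial sums together with a tree reduction. As a side note on the "(wrong results)" in the commit title: the block loop advances i by tptg.x but reads x + i with no per-thread offset, the uint4 load at x + i overlaps the 2-byte d field of the 18-byte block_q4_0, and int4(8) is subtracted per 32-bit lane rather than per byte; any of these could plausibly account for the incorrect output. The reduction pattern itself, rewritten as sequential C over an array standing in for the threadgroup buffer (illustrative names, not from the commit):

// Tree reduction over nthreads partial sums (nthreads a power of two, 32 here).
// In the kernel, threadgroup_barrier(mem_flags::mem_threadgroup) separates the
// rounds; sequential C needs no barrier.
static void reduce_partial_sums(float * sum, unsigned nthreads) {
    for (unsigned i = nthreads/2; i > 0; i /= 2) {
        for (unsigned t = 0; t < i; t++) { // threads with tpitg.x < i stay active
            sum[t] += sum[t + i];
        }
    }
    // sum[0] now holds the full dot product; thread 0 stores dst[r1*ne0 + r0]
}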

llama.cpp

Lines changed: 8 additions & 5 deletions

@@ -1266,16 +1266,19 @@ static bool llama_eval_internal(
 
             // cur = cur*attention_norm(broadcasted)
             cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(cur, "mtl-check");
-            }
         }
 
         // self-attention
         {
+            auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(x, "mtl-check");
+            }
+
             // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");
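This temporary block moves the "mtl-check" debug name from the attention-norm output onto the wq matmul result, i.e. onto a Q4_0 x F32 product in layer 0, so the new Metal kernel's output can be compared against the CPU path. A hedged sketch of how such a named tensor might be fished out of the graph and compared (ggml_graph_get_tensor is existing ggml API; the reference buffer and the helper itself are illustrative, not part of the commit):

#include <math.h>
#include <stdio.h>
#include "ggml.h"

// Fetch the tensor named by ggml_set_name above and compare it against a
// CPU-computed reference (hypothetical helper).
static void check_mtl_tensor(struct ggml_cgraph * gf, const float * ref, int64_t n) {
    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "mtl-check");
    if (t == NULL) {
        fprintf(stderr, "no tensor named 'mtl-check' in the graph\n");
        return;
    }

    double max_err = 0.0;
    for (int64_t i = 0; i < n; i++) {
        const double err = fabs(((const float *) t->data)[i] - ref[i]);
        if (err > max_err) {
            max_err = err;
        }
    }
    printf("mtl-check: max abs diff vs CPU = %g\n", max_err);
}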
