Skip to content

Commit 5e1c408

Browse files
committed
metal : fix kernel_norm
ggml-ci
1 parent fec2fb1 commit 5e1c408

File tree

2 files changed: +24 -23 lines changed

ggml-metal.m

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,8 +995,12 @@ void ggml_metal_graph_compute(
 995  995   else if (src0t == GGML_TYPE_Q6_K) {
 996  996   [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 997  997   } else {
 998      - int64_t ny = (ne11 + 3)/4;
 999      - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
      998 + [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
      999 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
     1000 +
     1001 + // TODO: this breaks for Q4_0 - understand why and fix it
     1002 + //int64_t ny = (ne11 + 3)/4;
     1003 + //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1000 1004   }
1001 1005   }
1002 1006   } break;

ggml-metal.metal

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ kernel void kernel_soft_max(
 111  111   uint3 tgpig[[threadgroup_position_in_grid]],
 112  112   uint3 tpitg[[thread_position_in_threadgroup]],
 113  113   uint3 ntg[[threads_per_threadgroup]]) {
 114      - const int64_t i03 = tgpig[2];
      114 + const int64_t i03 = tgpig[2];
 115  115   const int64_t i02 = tgpig[1];
 116  116   const int64_t i01 = tgpig[0];
 117  117
@@ -220,27 +220,26 @@ kernel void kernel_norm(
 220  220   }
 221  221   threadgroup_barrier(mem_flags::mem_threadgroup);
 222  222   }
 223      - //// broadcast
 224      - //if (tpitg == 0) {
 225      - // sum[0] /= ne00;
 226      - //}
 227      - //threadgroup_barrier(mem_flags::mem_threadgroup);
      223 + // broadcast
      224 + if (tpitg == 0) {
      225 + sum[0] /= ne00;
      226 + }
      227 + threadgroup_barrier(mem_flags::mem_threadgroup);
 228  228   const float mean = sum[0];
 229  229
 230      - // recenter and VARIANCE
      230 + // recenter
 231  231   device float * y = dst + tgpig*ne00;
 232      - sum[tpitg] = 0.0f;
 233  232   for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
 234  233   y[i00] = x[i00] - mean;
      234 + }
      235 +
      236 + // VARIANCE
      237 + // parallel sum
      238 + sum[tpitg] = 0.0f;
      239 + for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
 235  240   sum[tpitg] += y[i00] * y[i00];
 236  241   }
 237  242
 238      - //// VARIANCE
 239      - //// parallel sum
 240      - //sum[tpitg] = 0.0f;
 241      - //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
 242      - // sum[tpitg] += y[i00] * y[i00];
 243      - //}
 244  243   // reduce
 245  244   threadgroup_barrier(mem_flags::mem_threadgroup);
 246  245   for (uint i = ntg/2; i > 0; i /= 2) {
@@ -249,11 +248,11 @@ kernel void kernel_norm(
 249  248   }
 250  249   threadgroup_barrier(mem_flags::mem_threadgroup);
 251  250   }
 252      - //// broadcast
 253      - //if (tpitg == 0) {
 254      - // sum[0] /= ne00;
 255      - //}
 256      - //threadgroup_barrier(mem_flags::mem_threadgroup);
      251 + // broadcast
      252 + if (tpitg == 0) {
      253 + sum[0] /= ne00;
      254 + }
      255 + threadgroup_barrier(mem_flags::mem_threadgroup);
 257  256   const float variance = sum[0];
 258  257
 259  258   const float scale = 1.0f/sqrt(variance + eps);
@@ -262,7 +261,6 @@ kernel void kernel_norm(
 262  261   }
 263  262   }
 264  263
 265      -
 266  264   kernel void kernel_rms_norm(
 267  265   device const void * src0,
 268  266   device float * dst,
@@ -630,7 +628,6 @@ kernel void kernel_mul_mat_f16_f32(
 630  628   }
 631  629   }
 632  630   }
 633      -
 634  631   }
 635  632
 636  633   kernel void kernel_alibi_f32(

0 commit comments

Comments
 (0)