
Commit c4f4966

metal : fix kernel_norm (fixes Falcon on Metal) (#3057)
* metal : fix kernel_norm (ggml-ci)
* metal : put warning in kernel_norm to not combine the loops
* metal : restore original F16 mat-vec multiplication (it works after the norm fixes)
* common : don't do warm-up with more than n_batch tokens (close #3058) (ggml-ci)
* metal : minor

Parent: fec2fb1 · Commit: c4f4966

File tree: 2 files changed (+24 −21 lines)
common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -773,7 +773,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(
         LOG("warming up the model with an empty run\n");
 
         const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_eval(lctx, tmp.data(), tmp.size(), 0, params.n_threads);
+        llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
         llama_reset_timings(lctx);
     }
 
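The warm-up run feeds two tokens (BOS and EOS), which can exceed the batch size when the user asks for a very small -b/--batch-size, so the token count is now clamped to n_batch. The same clamp-and-chunk idea applies whenever a token list longer than n_batch has to go through llama_eval. A minimal sketch, assuming the llama_eval(ctx, tokens, n_tokens, n_past, n_threads) signature shown in the diff above (the helper name eval_in_batches is illustrative, not part of this commit):

    // Sketch: feed `tokens` to the model in chunks of at most n_batch tokens,
    // advancing n_past so each chunk continues where the previous one ended.
    #include "llama.h"

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    static bool eval_in_batches(llama_context * ctx,
                                const std::vector<llama_token> & tokens,
                                int n_batch, int n_threads) {
        int n_past = 0;
        for (size_t i = 0; i < tokens.size(); i += (size_t) n_batch) {
            // never pass more than n_batch tokens to a single llama_eval call
            const int n_eval = (int) std::min(tokens.size() - i, (size_t) n_batch);
            if (llama_eval(ctx, tokens.data() + i, n_eval, n_past, n_threads) != 0) {
                return false; // evaluation failed
            }
            n_past += n_eval;
        }
        return true;
    }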

ggml-metal.metal

Lines changed: 23 additions & 20 deletions
@@ -220,27 +220,32 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    //// broadcast
-    //if (tpitg == 0) {
-    //    sum[0] /= ne00;
-    //}
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     const float mean = sum[0];
 
-    // recenter and VARIANCE
+    // recenter
     device float * y = dst + tgpig*ne00;
-    sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    //
+    // WARNING: combining this loop with the one above will give you wrong results for nth == 256
+    //          I have no idea why, so for now I am keeping them separate. But this behavior is very concerning.
+    //          Tested with:
+    //          ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4
+    //
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         sum[tpitg] += y[i00] * y[i00];
     }
 
-    //// VARIANCE
-    //// parallel sum
-    //sum[tpitg] = 0.0f;
-    //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-    //    sum[tpitg] += y[i00] * y[i00];
-    //}
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -249,11 +254,11 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    //// broadcast
-    //if (tpitg == 0) {
-    //    sum[0] /= ne00;
-    //}
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     const float variance = sum[0];
 
     const float scale = 1.0f/sqrt(variance + eps);
@@ -262,7 +267,6 @@ kernel void kernel_norm(
     }
 }
 
-
 kernel void kernel_rms_norm(
         device const void * src0,
         device float * dst,
@@ -630,7 +634,6 @@ kernel void kernel_mul_mat_f16_f32(
             }
         }
     }
-
 }
 
 kernel void kernel_alibi_f32(
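For reference, kernel_norm computes standard (non-RMS) normalization over each row of ne00 values: mean, recenter, variance of the recentered values, then scaling by 1.0f/sqrt(variance + eps). A minimal scalar C++ sketch of the same computation, single-threaded so the per-thread partial sums, the tree reduction (for (uint i = ntg/2; i > 0; i /= 2)), and the threadgroup barriers collapse into plain loops (norm_row is an illustrative name, not a ggml function):

    // Scalar reference for one row of length ne00, mirroring kernel_norm:
    // mean -> recenter -> variance (kept as a separate pass, as in the fix) -> scale.
    #include <cmath>
    #include <cstddef>

    static void norm_row(const float * x, float * y, size_t ne00, float eps) {
        float mean = 0.0f;
        for (size_t i = 0; i < ne00; ++i) {
            mean += x[i];
        }
        mean /= (float) ne00;

        for (size_t i = 0; i < ne00; ++i) {
            y[i] = x[i] - mean;
        }

        float variance = 0.0f;
        for (size_t i = 0; i < ne00; ++i) {
            variance += y[i] * y[i];
        }
        variance /= (float) ne00;

        const float scale = 1.0f / std::sqrt(variance + eps);
        for (size_t i = 0; i < ne00; ++i) {
            y[i] *= scale;
        }
    }

In the Metal kernel, each of these sums is accumulated per thread into sum[tpitg] and combined with the threadgroup tree reduction; re-enabling the previously commented-out broadcast restores the division of the reduced sums by ne00, so mean and variance hold actual averages rather than raw sums.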
