Skip to content

Commit 7844a11

Browse files
committed
1 parent 54a3e73 commit 7844a11

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4447,7 +4447,7 @@ void kernel_mul_mv_q2_K_f32_impl(
44474447

44484448
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
44494449

4450-
for (int row = 0; row < N_DST; ++row) {
4450+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
44514451
all_sum = simd_sum(sumf[row]);
44524452
if (tiisg == 0) {
44534453
dst_f32[first_row + row] = all_sum;
@@ -4613,7 +4613,7 @@ void kernel_mul_mv_q3_K_f32_impl(
46134613
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
46144614

46154615
if (tiisg == 0) {
4616-
for (int row = 0; row < 2; ++row) {
4616+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
46174617
dst_f32[first_row + row] = sumf1[row];
46184618
}
46194619
}
@@ -4729,7 +4729,7 @@ void kernel_mul_mv_q4_K_f32_impl(
47294729

47304730
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
47314731

4732-
for (int row = 0; row < N_DST; ++row) {
4732+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
47334733
all_sum = simd_sum(sumf[row]);
47344734
if (tiisg == 0) {
47354735
dst_f32[first_row + row] = all_sum;
@@ -4861,7 +4861,7 @@ void kernel_mul_mv_q5_K_f32_impl(
48614861

48624862
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
48634863

4864-
for (int row = 0; row < 2; ++row) {
4864+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
48654865
const float tot = simd_sum(sumf[row]);
48664866
if (tiisg == 0) {
48674867
dst_f32[first_row + row] = tot;
@@ -4906,6 +4906,10 @@ void kernel_mul_mv_q6_K_f32_impl(
49064906

49074907
const int row = 2*r0 + sgitg;
49084908

4909+
if (row >= args.ne0) {
4910+
return;
4911+
}
4912+
49094913
const uint i12 = im%args.ne12;
49104914
const uint i13 = im/args.ne12;
49114915

@@ -5061,7 +5065,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
50615065

50625066
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
50635067

5064-
for (int row = 0; row < N_DST; ++row) {
5068+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
50655069
all_sum = simd_sum(sumf[row]);
50665070
if (tiisg == 0) {
50675071
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5179,7 +5183,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
51795183

51805184
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
51815185

5182-
for (int row = 0; row < N_DST; ++row) {
5186+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
51835187
all_sum = simd_sum(sumf[row]);
51845188
if (tiisg == 0) {
51855189
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5289,7 +5293,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
52895293

52905294
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
52915295

5292-
for (int row = 0; row < N_DST; ++row) {
5296+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
52935297
all_sum = simd_sum(sumf[row]);
52945298
if (tiisg == 0) {
52955299
dst_f32[first_row + row] = all_sum * 0.5f;
@@ -5401,7 +5405,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
54015405

54025406
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
54035407

5404-
for (int row = 0; row < N_DST; ++row) {
5408+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
54055409
all_sum = simd_sum(sumf[row]);
54065410
if (tiisg == 0) {
54075411
dst_f32[first_row + row] = all_sum;
@@ -5514,7 +5518,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
55145518

55155519
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
55165520

5517-
for (int row = 0; row < N_DST; ++row) {
5521+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
55185522
all_sum = simd_sum(sumf[row]);
55195523
if (tiisg == 0) {
55205524
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5614,7 +5618,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
56145618

56155619
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
56165620

5617-
for (int row = 0; row < N_DST; ++row) {
5621+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
56185622
all_sum = simd_sum(sumf[row]);
56195623
if (tiisg == 0) {
56205624
dst_f32[first_row + row] = all_sum;
@@ -5709,7 +5713,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
57095713

57105714
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
57115715

5712-
for (int row = 0; row < N_DST; ++row) {
5716+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
57135717
all_sum = simd_sum(sumf[row]);
57145718
if (tiisg == 0) {
57155719
dst_f32[first_row + row] = all_sum;
@@ -5799,7 +5803,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
57995803

58005804
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
58015805

5802-
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
5806+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
58035807
all_sum = simd_sum(sumf[row]);
58045808
if (tiisg == 0) {
58055809
dst_f32[first_row + row] = all_sum;
@@ -5888,7 +5892,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
58885892

58895893
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
58905894

5891-
for (int row = 0; row < 2; ++row) {
5895+
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
58925896
all_sum = simd_sum(sumf[row]);
58935897
if (tiisg == 0) {
58945898
dst_f32[first_row + row] = all_sum;

0 commit comments

Comments
 (0)