Skip to content

Commit 80b5b51

Browse files
committed
metal : reorder write loop
1 parent e40b85a commit 80b5b51

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

ggml/src/ggml-metal.metal

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6403,11 +6403,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
64036403

64046404
threadgroup_barrier(mem_flags::mem_threadgroup);
64056405

6406-
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
64076406
if (sgitg == 0) {
6408-
for (int i = 0; i < n_rows; i++) {
6409-
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6410-
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
6407+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6408+
device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
6409+
device float4 * D4 = (device float4 *) D;
6410+
6411+
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
6412+
threadgroup float4 * C4 = (threadgroup float4 *) C;
6413+
6414+
int i = 0;
6415+
for (; i < n_rows/4; i++) {
6416+
*(D4 + i) = *(C4 + i);
6417+
}
6418+
6419+
i *= 4;
6420+
for (; i < n_rows; i++) {
6421+
*(D + i) = *(C + i);
64116422
}
64126423
}
64136424
}

0 commit comments

Comments
 (0)