1 file changed: +15 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -6403,11 +6403,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
6403
6403
6404
6404
threadgroup_barrier (mem_flags::mem_threadgroup);
6405
6405
6406
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
6407
6406
if (sgitg == 0 ) {
6408
- for (int i = 0 ; i < n_rows; i++) {
6409
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6410
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
6407
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6408
+ device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
6409
+ device float4 * D4 = (device float4 *) D;
6410
+
6411
+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
6412
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
6413
+
6414
+ int i = 0 ;
6415
+ for (; i < n_rows/4 ; i++) {
6416
+ *(D4 + i) = *(C4 + i);
6417
+ }
6418
+
6419
+ i *= 4 ;
6420
+ for (; i < n_rows; i++) {
6421
+ *(D + i) = *(C + i);
6411
6422
}
6412
6423
}
6413
6424
}
You can’t perform that action at this time.
0 commit comments