@@ -486,7 +486,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
-    __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
+    __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y*4];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
 
     *x_ql = tile_x_qs;
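Context for the wider padding (an inference from the rest of the diff, not stated in it): the q8_0 x tile is now stored with a row stride of WARP_SIZE + 4 ints instead of WARP_SIZE + 1, which should keep the 4-consecutive-int fragment loads in the new mma path further down free of shared-memory bank conflicts. A minimal sketch of the resulting indexing, with tile_x_index as a hypothetical helper name:

    // Hypothetical helper (not part of the diff): one x-tile row now spans WARP_SIZE + 4 ints.
    // With a 36-int stride, neighbouring rows start 4 banks apart, so a warp in which groups of
    // 4 threads read 4 consecutive ints from 8 different rows touches all 32 banks exactly once.
    static __device__ __forceinline__ int tile_x_index(const int row, const int k) {
        return row*(WARP_SIZE + 4) + k;
    }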
@@ -519,7 +519,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
+        x_ql[i * (WARP_SIZE + 4) + k] = get_int_from_int8(bxi->qs, kqsx);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
@@ -547,9 +547,20 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     const float * x_dmf = (const float *) x_dm;
     const float * y_df = (const float *) y_ds;
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
-         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+    const int * v = &x_ql[i * (WARP_SIZE + 1) + k];
+    const int * u = &y_qs[j * WARP_SIZE + k];
+    const float d8_0 = x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0];
+    const float d8_1 = y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1];
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q8_0_Q8_1_MMQ; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
+
+    return d8_0*d8_1 * sumi;
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
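For reference, __dp4a treats each 32-bit operand as four signed 8-bit lanes, multiplies them lane-wise and adds the products to the accumulator, so the unrolled loop above accumulates VDR_Q8_0_Q8_1_MMQ*4 int8 products before the two block scales are applied. A scalar sketch of the same operation (dp4a_ref is a hypothetical name, not part of the diff):

    // Scalar equivalent of __dp4a(a, b, c) for signed 8-bit lanes (illustration only).
    static __device__ __forceinline__ int dp4a_ref(const int a, const int b, int c) {
        const int8_t * a8 = (const int8_t *) &a;
        const int8_t * b8 = (const int8_t *) &b;
        for (int l = 0; l < 4; ++l) {
            c += a8[l]*b8[l]; // multiply-accumulate one byte lane
        }
        return c;
    }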
@@ -1066,6 +1077,167 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
+          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+static __device__ __forceinline__ void mul_mat_q_test(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+    const block_q_t * x = (const block_q_t *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+    const int & ncols_dst = ncols_y;
+
+    const int row_dst_0 = blockIdx.x*mmq_y;
+    const int & row_x_0 = row_dst_0;
+
+    const int col_dst_0 = blockIdx.y*mmq_x;
+    const int & col_y_0 = col_dst_0;
+
+    int * tile_x_ql = nullptr;
+    half2 * tile_x_dm = nullptr;
+    int * tile_x_qh = nullptr;
+    int * tile_x_sc = nullptr;
+
+    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
+
+    __shared__ int tile_y_qs[mmq_x * WARP_SIZE + mmq_x*4];
+    __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
+
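+    // Accumulators for the mma path: each warp computes 16x8 output tiles, and every thread keeps
+    // 4 int32 values per tile (a 2x2 sub-block of the tile), hence the trailing [4]. mmq_x must be
+    // divisible by 8*nwarps because each of the nwarps warps covers 8 output columns per iteration.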
+    static_assert(mmq_x % (8*nwarps) == 0, "assert");
+    float sum[mmq_x/(8*nwarps)][mmq_y/16][4] = {{{0.0f}}};
+
+    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
+
+#pragma unroll
+        for (int ir = 0; ir < qr; ++ir) {
+            const int kqs = ir*WARP_SIZE + threadIdx.x;
+            const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+            for (int i = 0; i < mmq_x; i += nwarps) {
+                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+
+                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
+
+                const int index_y = (threadIdx.y + i) * (WARP_SIZE + 4) + kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
+            }
+
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
+                half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = __low2float(*dsi_src);
+                }
+            }
+
+            __syncthreads();
+
+            const float * x_dmf = (const float *) tile_x_dm;
+            const float * y_df = (const float *) tile_y_ds;
+
+            static_assert(!need_sum);
+            static_assert(vdr == 32/sizeof(int));
+#pragma unroll
+            for (int k0 = ir*WARP_SIZE/qr; k0 < (ir+1)*WARP_SIZE/qr; k0 += vdr) {
+#pragma unroll
+                for (int j00 = 0; j00 < mmq_x; j00 += 8*nwarps) {
+                    const int j0 = j00 + 8*threadIdx.y;
+#pragma unroll
+                    for (int i0 = 0; i0 < mmq_y; i0 += 16) {
+                        int v[4];
+#pragma unroll
+                        for (int l = 0; l < 4; ++l) {
+                            v[l] = tile_x_ql[(i0 + (l%2)*8 + threadIdx.x/4) * (WARP_SIZE + 4) + k0 + (l/2)*4 + threadIdx.x%4];
+                        }
+                        int u[2];
+#pragma unroll
+                        for (int l = 0; l < 2; ++l) {
+                            u[l] = tile_y_qs[(j0 + threadIdx.x/4) * (WARP_SIZE + 4) + k0 + 4*l + threadIdx.x%4];
+                        }
+
+                        int sumi[4] = {0};
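+                        // Warp-wide, v covers a 16 (rows of x) x 32 (k) int8 operand and u an
+                        // 8 (columns of y) x 32 (k) int8 operand; sumi accumulates this thread's
+                        // 2x2 piece of the 16x8 int32 result. Ampere can process k=32 with one mma
+                        // instruction, pre-Ampere tensor-core archs (e.g. Turing) need two k=16 steps.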
+#if __CUDA_ARCH__ >= CC_AMPERE
+                        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+                            : "+r"(sumi[0]), "+r"(sumi[1]), "+r"(sumi[2]), "+r"(sumi[3])
+                            : "r"(v[0]), "r"(v[1]), "r"(v[2]), "r"(v[3]), "r"(u[0]), "r"(u[1]));
+#else
+                        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+                            : "+r"(sumi[0]), "+r"(sumi[1]), "+r"(sumi[2]), "+r"(sumi[3])
+                            : "r"(v[0]), "r"(v[1]), "r"(u[0]));
+                        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+                            : "+r"(sumi[0]), "+r"(sumi[1]), "+r"(sumi[2]), "+r"(sumi[3])
+                            : "r"(v[2]), "r"(v[3]), "r"(u[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+
+                        float d8_0[2];
+#pragma unroll
+                        for (int l = 0; l < 2; ++l) {
+                            const int i = i0 + 8*l + threadIdx.x/4;
+                            d8_0[l] = x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+                        }
+                        float d8_1[2];
+#pragma unroll
+                        for (int l = 0; l < 2; ++l) {
+                            const int j = j0 + 2*(threadIdx.x%4) + l;
+                            d8_1[l] = y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1];
+                        }
+
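+                        // Dequantize: each int32 partial result is scaled by the q8_0 block scale of
+                        // its row (d8_0) and the q8_1 block scale of its column (d8_1).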
+#pragma unroll
+                        for (int l = 0; l < 4; ++l) {
+                            sum[j00/(8*nwarps)][i0/16][l] += d8_0[l/2]*d8_1[l%2] * sumi[l];
+                        }
+                    }
+                }
+            }
+
+            __syncthreads();
+        }
+    }
+
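+    // Write the accumulators back, undoing the mma fragment layout: within each 16x8 tile this thread
+    // owns rows i00 + threadIdx.x/4 + 8*(l/2) and columns j00 + 8*threadIdx.y + 2*(threadIdx.x%4) + l%2;
+    // out-of-range rows and columns are skipped.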
+#pragma unroll
+    for (int j00 = 0; j00 < mmq_x; j00 += 8*nwarps) {
+        const int j0 = j00 + 8*threadIdx.y + 2*(threadIdx.x%4);
+#pragma unroll
+        for (int i00 = 0; i00 < mmq_y; i00 += 16) {
+            const int i0 = i00 + threadIdx.x/4;
+
+#pragma unroll
+            for (int l = 0; l < 4; ++l) {
+                const int i = i0 + 8*(l/2);
+                const int j = j0 + (l%2);
+
+                const int row_dst = row_dst_0 + i;
+                const int col_dst = col_dst_0 + j;
+                if (row_dst >= nrows_dst) {
+                    continue;
+                }
+                if (col_dst >= ncols_dst) {
+                    continue;
+                }
+                dst[col_dst*nrows_dst + row_dst] = sum[j00/(8*nwarps)][i00/16][l];
+            }
+        }
+    }
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -1304,7 +1476,7 @@ template <bool need_check> static __global__ void
 #if __CUDA_ARCH__ >= MIN_CC_DP4A
     constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q8_0);
 
-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q8_0<arch_config.y>,
+    mul_mat_q_test<QK8_0, QR8_0, QI8_0, false, block_q8_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q8_0<arch_config.y>,
         load_tiles_q8_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else