@@ -173,6 +173,52 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
 }
 }
 
+template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+    const int qk = QK4_0;
+
+    const int row = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    __shared__ float tmp[block_size]; // separate sum for each thread
+    tmp[tid] = 0;
+
+    for (int i = 0; i < ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
+        const int ib = (row*ncols + col)/qk; // block index
+        const int iqs = (col%qk)/2; // quant index
+        const int iybs = col - col%qk; // y block start index
+
+        // dequantize
+        const float d = x[ib].d;
+
+        const uint8_t * pp = x[ib].qs;
+
+        const uint8_t vui = pp[iqs];
+
+        const int8_t vi0 = vui & 0xF;
+        const int8_t vi1 = vui >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        // matrix multiplication
+        tmp[tid] += v0 * y[iybs + iqs + 0];
+        tmp[tid] += v1 * y[iybs + iqs + qk/2];
+    }
+
+    // sum up partial sums and write back result
+    for (int s = block_size/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        __syncthreads();
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -198,6 +244,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    // static int block_size = -1;
+    // if (block_size == -1) {
+    //     int min_grid_size, max_block_size = 1;
+    //     CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
+    //     max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+    //     block_size = 1;
+    //     while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
+    //         block_size *= 2;
+    //     }
+    // }
+    // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+    const int block_size = 32;
+    GGML_ASSERT(ncols % block_size == 0);
+    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+}
+
 // TODO: optimize
 static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
     const half * x = (const half *) vx;
@@ -231,7 +294,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }
 
 // buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256
 
 struct scoped_spin_lock {
     std::atomic_flag& lock;
@@ -538,7 +601,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
 
     size_t x_size, y_size, d_size, q_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    float * d_X;
+    if (ne11 > 1) {
+        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
     float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
     char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
@@ -553,31 +619,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
             cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
 
-            float * c_X = d_X + i * x_ne;
             float * c_Y = d_Y + i * y_ne;
             float * c_D = d_D + i * d_ne;
             char * c_Q = d_Q + i * q_sz;
 
-            // copy src0 and convert to fp32 on device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-            to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-            CUDA_CHECK(cudaGetLastError());
-            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+            // copy src0 to device if necessary
+            if (src0->backend == GGML_BACKEND_CPU) {
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+            } else if (src0->backend == GGML_BACKEND_CUDA) {
+                c_Q = ((char *) src0->data) + i * q_sz;
+            } else {
+                GGML_ASSERT(false);
+            }
+            if (ne11 == 1) {
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
-            // copy src1 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
 
-            // wait for conversion
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+                // wait for data
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
 
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, ne00,
-                                c_Y, ne10,
-                        &beta, c_D, ne01));
+                // compute
+                dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+
+            } else {
+                float * c_X = d_X + i * x_ne;
+
+                // convert src0 to fp32 on device
+                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+                CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+                // wait for conversion
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+                // compute
+                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, c_X, ne00,
+                                    c_Y, ne10,
+                            &beta, c_D, ne01));
+            }
 
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -586,7 +675,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
+    if (ne11 > 1) {
+        ggml_cuda_pool_free(d_X, x_size);
+    }
     ggml_cuda_pool_free(d_Y, y_size);
     ggml_cuda_pool_free(d_D, d_size);
     ggml_cuda_pool_free(d_Q, q_size);
@@ -602,8 +693,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
         return true;
     }
 
@@ -655,3 +745,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
         return 0;
     }
 }
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    size_t q_size;
+    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+    // copy tensor to device
+    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    tensor->data = d_Q;
+    tensor->backend = GGML_BACKEND_CUDA;
+}