@@ -150,8 +150,8 @@
 #define CUDA_USE_TENSOR_CORES
 #endif
 
-// max batch size to use MMQ kernels when tensor cores are available
-#define MMQ_MAX_BATCH_SIZE 32
+#define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
+#define  MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
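
The two constants above carve the quantized mat-mul work into regimes by batch size (src1->ne[1]): batches of at most 8 now go to the MMVQ matrix-vector kernels, quantized batches of at most 32 stay on the MMQ kernels even when tensor cores are present, and larger batches are left to the tensor-core/cuBLAS path. A minimal sketch of that routing rule, with a hypothetical helper name (the real decision further down also folds in data types and hardware checks):

    enum class mm_kernel { MMVQ, MMQ, CUBLAS };

    // batch_size is src1->ne[1]; the thresholds mirror the two defines above
    static mm_kernel choose_mm_kernel(int batch_size, bool tensor_cores_available) {
        if (batch_size <= 8 /* MMVQ_MAX_BATCH_SIZE */) {
            return mm_kernel::MMVQ;
        }
        if (!tensor_cores_available || batch_size <= 32 /* MMQ_MAX_BATCH_SIZE */) {
            return mm_kernel::MMQ;
        }
        return mm_kernel::CUBLAS; // large batches: tensor cores via cuBLAS
    }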
@@ -5310,51 +5310,59 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-#define MMVQ_NWARPS_NVIDIA    4
-#define MMVQ_NWARPS_AMD_RDNA2 1
-#define MMVQ_NWARPS_AMD_OLD   4
-
-template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
-    const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+    constexpr int nwarps              = 1;
+    constexpr int rows_per_cuda_block = 1;
+#else
+    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
+    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
 
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    const int row = blockIdx.x;
-
-    const int blocks_per_row_x = ncols_x / qk;
-    const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    const     int tid  = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int row0 = rows_per_cuda_block*blockIdx.x;
+    const     int blocks_per_row_x = ncols_x / qk;
+    const     int blocks_per_col_y = nrows_y / QK8_1;
+    constexpr int blocks_per_iter  = vdr * nwarps*WARP_SIZE / qi;
 
     // partial sum for each thread
-    float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
+    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
-        const int ibx = row*blocks_per_row_x + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
 
-        const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
+        // x block quant index when casting the quants to int
+        const int kqs = vdr * (tid % (qi/vdr));
 
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
-            tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp[j][i] += vec_dot_q_cuda(
+                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
+            }
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
-            tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+            }
         }
     }
     __syncthreads();
@@ -5366,13 +5374,16 @@ static __global__ void mul_mat_vec_q(
 #pragma unroll
     for (int j = 0; j < ncols_y; ++j) {
 #pragma unroll
-        for (int i = 0; i < nwarps-1; ++i) {
-            tmp[j] += tmp_shared[i][j][threadIdx.x];
+        for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+            for (int l = 0; l < nwarps-1; ++l) {
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+            }
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
-        tmp[j] = warp_reduce_sum(tmp[j]);
 
-        if (threadIdx.x == 0) {
-            dst[j*nrows_dst + row] = tmp[j];
+        if (threadIdx.x < rows_per_cuda_block) {
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
 }
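
The reworked kernel computes a small tile per CUDA block: rows_per_cuda_block rows of the quantized matrix against all ncols_y columns of y, keeping one partial sum per (column, row) pair in registers. Warps other than warp 0 stage their partials in shared memory, then warp 0 adds them and finishes with a shuffle reduction. The reduction pattern in isolation, as a self-contained sketch (warp_reduce_sum is written out here under the assumption that it is the usual shuffle-based butterfly, matching its use in ggml-cuda.cu):

    #define WARP_SIZE 32

    static __device__ float warp_reduce_sum(float x) {
        // butterfly reduction across the 32 lanes of a warp
        for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, mask, WARP_SIZE);
        }
        return x;
    }

    // each block of WARP_SIZE x nwarps threads sums one row of length n
    template <int nwarps>
    static __global__ void row_sum(const float * x, float * dst, const int n) {
        const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;

        float tmp = 0.0f;
        for (int i = tid; i < n; i += nwarps*WARP_SIZE) {
            tmp += x[blockIdx.x*n + i];
        }

        // warps 1..nwarps-1 stage their partial sums in shared memory ...
        __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][WARP_SIZE];
        if (threadIdx.y > 0) {
            tmp_shared[threadIdx.y-1][threadIdx.x] = tmp;
        }
        __syncthreads();
        if (threadIdx.y > 0) {
            return;
        }

        // ... and warp 0 accumulates them before the final in-warp reduction
        for (int l = 0; l < nwarps-1; ++l) {
            tmp += tmp_shared[l][threadIdx.x];
        }
        tmp = warp_reduce_sum(tmp);

        if (threadIdx.x == 0) {
            dst[blockIdx.x] = tmp;
        }
    }

Launched as, e.g., row_sum<4><<<nrows, dim3(WARP_SIZE, 4), 0, stream>>>(x, dst, n). mul_mat_vec_q runs the same dance once per (column, row) entry of its tile and writes rows_per_cuda_block results per block, which is why the final guard becomes threadIdx.x < rows_per_cuda_block instead of threadIdx.x == 0.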
@@ -6851,65 +6862,75 @@ static void mul_mat_vec_q_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % qk == 0);
-    GGML_ASSERT(ncols_y <= 4);
+    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
-    int nwarps;
-    if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
-        nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
-    } else {
-        nwarps = MMVQ_NWARPS_NVIDIA;
-    }
+    int64_t nwarps = 1;
+    int64_t rows_per_cuda_block = 1;
 
-    const dim3 block_nums(nrows_x, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    switch (nwarps) {
-        case 1: switch(ncols_y) {
+    if (g_device_caps[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+        switch(ncols_y) {
             case 1:
-                mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                nwarps = 4;
+                rows_per_cuda_block = 1;
                 break;
             case 2:
-                mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
             case 3:
-                mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
             case 4:
-                mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        } break;
-        case 4: switch(ncols_y) {
-            case 1:
-                mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                nwarps = 4;
+                rows_per_cuda_block = 2;
                 break;
-            case 2:
-                mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            case 3:
-                mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            case 4:
-                mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                nwarps = 2;
+                rows_per_cuda_block = 2;
                 break;
             default:
                 GGML_ASSERT(false);
                 break;
-        } break;
+        }
+    }
+    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
 
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 2:
+            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 3:
+            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 4:
+            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 5:
+            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 6:
+            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 7:
+            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 8:
+            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
 
         default:
             GGML_ASSERT(false);
            break;
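
Because each block now covers rows_per_cuda_block rows instead of exactly one, the grid size becomes a ceiling division over nrows_x, and the runtime ncols_y is dispatched onto a compile-time template parameter so the kernel can size its register and shared-memory arrays. A worked launch-geometry example under the NVIDIA branch (values hypothetical):

    #include <cstdio>

    int main() {
        const int WARP_SIZE = 32;
        const int nrows_x   = 4096; // rows of the quantized matrix
        const int ncols_y   = 6;    // batch size, <= MMVQ_MAX_BATCH_SIZE

        const int nwarps              = ncols_y <= 4 ? 4 : 2;
        const int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
        const int nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;

        // prints: grid: 2048 blocks of 64 threads
        printf("grid: %d blocks of %d threads\n", nblocks, nwarps*WARP_SIZE);
        return 0;
    }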
@@ -9735,7 +9756,7 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = (char *) dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
 
@@ -9893,39 +9914,69 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
     int64_t min_compute_capability = INT_MAX;
 
+    bool any_pascal_with_slow_fp16 = false;
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < g_device_count; ++id) {
-            if (min_compute_capability > g_device_caps[id].cc && tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
+            // skip devices that are not going to do any work:
+            if (tensor_split[id] >= (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
+                continue;
+            }
+
+            if (min_compute_capability > g_device_caps[id].cc) {
                 min_compute_capability = g_device_caps[id].cc;
             }
+            if (g_device_caps[id].cc == 610) {
+                any_pascal_with_slow_fp16 = true;
+            }
         }
     } else {
-        min_compute_capability = g_device_caps[g_main_device].cc;
+        min_compute_capability    = g_device_caps[g_main_device].cc;
+        any_pascal_with_slow_fp16 = g_device_caps[g_main_device].cc == 610;
     }
 
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
     const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
-    bool use_mul_mat_q = ggml_is_quantized(src0->type);
+
 #ifdef CUDA_USE_TENSOR_CORES
     use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
 #endif // CUDA_USE_TENSOR_CORES
 
 #else
 
-    const bool fp16_performance_good = min_compute_capability >= CC_VOLTA;
-    bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+    // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
+    const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
+
+    // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
+    use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
+    use_mul_mat_q     = use_mul_mat_q     && min_compute_capability >= MIN_CC_DP4A;
+
 #ifdef CUDA_USE_TENSOR_CORES
     // when tensor cores are available, use them for large batch size
     // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-    use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE);
+    use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // CUDA_USE_TENSOR_CORES
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
-    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+    // if mmvq is available it's a better choice than dmmv:
+#ifndef GGML_CUDA_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_CUDA_FORCE_DMMV
 
     // debug helpers
     // printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
@@ -9943,33 +9994,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
     } else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) {
-#ifdef GGML_CUDA_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
-#endif // GGML_CUDA_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
-            } else {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
-            } else if (use_mul_mat_q) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
-            } else {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
-            }
-        }
+        ggml_cuda_mul_mat_batched_cublas(src0, src1, dst);
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     }
 }
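
With the flags precomputed up front, kernel choice collapses into a single priority ladder: the old nested special-casing of src1->ne[1] == 1 versus ne[1] <= 4 disappears, and the GGML_ASSERT(false) dead end for unhandled types is replaced by a cuBLAS fallback. Condensed (illustrative only, using the flags computed earlier):

    // dmmv only when mmvq is unavailable or GGML_CUDA_FORCE_DMMV is set;
    // mmvq for quantized batches <= MMVQ_MAX_BATCH_SIZE; mmq for larger
    // quantized batches; cuBLAS as the general fallback (incl. f32 inputs)
    if (use_dequantize_mul_mat_vec) { /* dmmv   */ }
    else if (use_mul_mat_vec_q)     { /* mmvq   */ }
    else if (use_mul_mat_q)         { /* mmq    */ }
    else                            { /* cuBLAS */ }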