@@ -439,7 +439,6 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
-    bool copied;
 };

 // this is faster on Windows
@@ -4357,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,

 // rope == RoPE == rotary positional embedding

-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                            const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4369,7 +4369,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int i = row*ncols + col;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
+    const int p = has_pos ? pos[i2] : 0;
     const float p0 = p * freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
@@ -4382,8 +4382,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }

-static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                     const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                 const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4394,7 +4395,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
+    const int p = has_pos ? pos[i2] : 0;
     const float p0 = p * freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
@@ -5371,22 +5372,32 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }

-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template<typename T>
+static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

-static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template<typename T>
+static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
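Note: the launchers above select a kernel instantiation depending on whether a positions buffer was supplied, so the per-element `pos != nullptr` check performed by the old `rope_f32` kernel is now resolved at compile time. A minimal standalone sketch of the same pattern (illustrative only, not part of the diff; the kernel, names and block size here are made up):

    // Illustrative sketch of the template-dispatch pattern used by rope_cuda/rope_neox_cuda:
    // the boolean template parameter becomes a compile-time constant inside the kernel,
    // so no per-element branch on the pointer is emitted.
    template<bool has_pos>
    static __global__ void shift_by_pos(const float * x, float * dst, const int32_t * pos, const int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= n) {
            return;
        }
        const int p = has_pos ? pos[i] : 0; // folds to 0 when has_pos == false
        dst[i] = x[i] + (float) p;
    }

    static void shift_by_pos_cuda(const float * x, float * dst, const int32_t * pos, const int n, cudaStream_t stream) {
        const int num_blocks = (n + 255) / 256;
        if (pos == nullptr) {
            shift_by_pos<false><<<num_blocks, 256, 0, stream>>>(x, dst, pos, n);
        } else {
            shift_by_pos<true ><<<num_blocks, 256, 0, stream>>>(x, dst, pos, n);
        }
    }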
@@ -6036,7 +6047,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const int64_t ne0 = dst->ne[0];
     const int64_t row_diff = row_high - row_low;

-    float * src0_ddq_as_f32;
+    float * src0_ddq_as_f32 = nullptr;
     size_t src0_as = 0;

     if (src0->type != GGML_TYPE_F32) {
@@ -6074,8 +6085,9 @@ inline void ggml_cuda_op_rope(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);

     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -6093,23 +6105,16 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    //const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;

-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(src1->ne[0] == ne2);
-    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
-
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-
-    int * pos = nullptr;
+    int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
+        GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        GGML_ASSERT(src1->ne[0] == ne2);
+        GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
         struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-        pos = (int *) src1_extra->data_device[id];
-        if (!src1_extra->copied) {
-            CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
-            src1_extra->copied = true;
-        }
+        int id;
+        CUDA_CHECK(cudaGetDevice(&id));
+        pos = (int32_t *) src1_extra->data_device[id];
     }

     const bool is_neox = mode & 2;
@@ -6121,9 +6126,21 @@ inline void ggml_cuda_op_rope(
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     } else {
-        rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     }

     (void) src1;
@@ -6294,7 +6311,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     }
 }

-void ggml_cuda_set_peer_access(const int n_tokens) {
+static void ggml_cuda_set_peer_access(const int n_tokens) {
     static bool peer_access_enabled = false;

     const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
@@ -6622,27 +6639,27 @@ static void ggml_cuda_op_mul_mat(
     }
 }

-void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }

-void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }

-void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }

-void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }

-void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }

-void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }

@@ -6663,7 +6680,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }

-void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -6692,7 +6709,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }

-void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@@ -6726,7 +6743,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;

@@ -6770,11 +6787,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }

-void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }

-void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));

@@ -6822,29 +6839,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     (void) dst;
 }

-void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }

-void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
 }

-void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
 }

-void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
 }

-void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

-void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
     (void) dst;
@@ -6967,11 +6984,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
+static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }

+    tensor->backend = GGML_BACKEND_GPU;
+
     // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
@@ -6983,8 +7002,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }

-    tensor->backend = GGML_BACKEND_GPU;
-
     if (scratch && no_alloc) {
         return;
     }
@@ -7069,6 +7086,16 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
     tensor->extra = extra;
 }

+void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+
+    struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    CUDA_CHECK(cudaMemcpyAsync(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice, main_stream));
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, true, false, false);
 }
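Note: `ggml_cuda_copy_to_device` replaces the lazy, one-shot upload (guarded by the removed `copied` flag) that `ggml_cuda_op_rope` used to perform for the RoPE positions; the caller is now responsible for pushing the positions to the GPU whenever they change. A rough usage sketch under assumed caller-side names (`KQ_pos`, `n_past`, `n_tokens` are hypothetical and not part of this diff):

    // KQ_pos: a contiguous GGML_TYPE_I32 tensor with backend == GGML_BACKEND_GPU
    // and a device buffer already assigned (e.g. via ggml_cuda_assign_buffers).
    int32_t * data = (int32_t *) KQ_pos->data;
    for (int i = 0; i < n_tokens; ++i) {
        data[i] = n_past + i;            // absolute position of each token in the batch
    }
    ggml_cuda_copy_to_device(KQ_pos);    // async host-to-device copy on the main device's stream 0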