
Commit e04dc51

slaren and ggerganov authored
ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)
* ggml-cuda : add rope f16, restore performance
* offload KQ_mask with all models
* fix rope shift

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent db0fc2d commit e04dc51
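
For orientation before the diffs: the rope/rope_neox kernels rotate pairs of elements by an angle derived from the token position (the two kernels pair elements differently). Below is a minimal host-side reference of the shared arithmetic — plain C, not part of the commit, with the hypothetical name rope_pair_ref; the formulas follow the kernel code in the diff:

#include <math.h>
#include <stdint.h>

// Reference rotation for a single (x0, x1) pair at column `col` of a row whose
// position id is `p`; mirrors the arithmetic in the rope/rope_neox kernels below.
static void rope_pair_ref(float x0, float x1, int col, int32_t p,
                          float freq_scale, float theta_scale,
                          float * out0, float * out1) {
    const float p0        = p*freq_scale;
    const float theta     = p0*powf(theta_scale, col/2); // col/2 is integer division, as in the kernels
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);

    *out0 = x0*cos_theta - x1*sin_theta;
    *out1 = x0*sin_theta + x1*cos_theta;
}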


4 files changed, +110 -67 lines changed


ggml-cuda.cu

Lines changed: 76 additions & 54 deletions
@@ -439,7 +439,6 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
-    bool copied;
 };

 // this is faster on Windows
@@ -4357,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,

 // rope == RoPE == rotary positional embedding

-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                            const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4369,8 +4369,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int i = row*ncols + col;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
-    const float p0 = p * freq_scale;
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
@@ -4382,8 +4382,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }

-static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                     const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                 const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4394,8 +4395,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
-    const float p0 = p * freq_scale;
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
@@ -5371,22 +5372,32 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }

-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template<typename T>
+static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

-static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template<typename T>
+static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                                const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
@@ -6036,7 +6047,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const int64_t ne0 = dst->ne[0];
     const int64_t row_diff = row_high - row_low;

-    float * src0_ddq_as_f32;
+    float * src0_ddq_as_f32 = nullptr;
     size_t src0_as = 0;

     if (src0->type != GGML_TYPE_F32) {
@@ -6074,8 +6085,9 @@ inline void ggml_cuda_op_rope(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);

     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -6093,23 +6105,12 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    // const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(src1->ne[0] == ne2);
-    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-
-    int * pos = nullptr;
+    const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
-        struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-        pos = (int *) src1_extra->data_device[id];
-        if (!src1_extra->copied) {
-            CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
-            src1_extra->copied = true;
-        }
+        GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        GGML_ASSERT(src1->ne[0] == ne2);
+        pos = (const int32_t *) src1_dd;
     }

     const bool is_neox = mode & 2;
@@ -6121,9 +6122,21 @@ inline void ggml_cuda_op_rope(
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     } else {
-        rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     }

     (void) src1;
@@ -6294,7 +6307,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     }
 }

-void ggml_cuda_set_peer_access(const int n_tokens) {
+static void ggml_cuda_set_peer_access(const int n_tokens) {
     static bool peer_access_enabled = false;

     const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
@@ -6622,27 +6635,27 @@ static void ggml_cuda_op_mul_mat(
     }
 }

-void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }

-void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }

-void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }

-void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }

-void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }

-void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }

@@ -6663,7 +6676,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }

-void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -6692,7 +6705,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }

-void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@@ -6726,7 +6739,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;

@@ -6770,11 +6783,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }

-void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }

-void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));

@@ -6822,29 +6835,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     (void) dst;
 }

-void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }

-void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
 }

-void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
 }

-void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
 }

-void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

-void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
     (void) dst;
@@ -6967,11 +6980,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
+static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }

+    tensor->backend = GGML_BACKEND_GPU;
+
     // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
@@ -6983,8 +6998,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }

-    tensor->backend = GGML_BACKEND_GPU;
-
     if (scratch && no_alloc) {
         return;
     }
@@ -7069,6 +7082,15 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
     tensor->extra = extra;
 }

+void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+
+    struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, true, false, false);
 }
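
The launcher changes above replace the kernels' runtime `pos != nullptr` check with a compile-time `has_pos` template parameter, and template the element type so the same kernel serves F32 and F16. Below is a stripped-down, standalone sketch of that dispatch pattern; the `scale_rows` kernel and all of its names are invented for illustration and are not part of the commit:

#include <cuda_fp16.h>

// Hypothetical kernel: scales each row by a per-row factor when one is given,
// otherwise copies the data unchanged. `has_factor` is resolved at compile time,
// so the "no factor" instantiation carries no branch and no pointer load.
template<typename T, bool has_factor>
static __global__ void scale_rows(const T * x, T * dst, const int ncols, const float * factors) {
    const int row = blockIdx.x;
    const int col = blockDim.x*blockIdx.y + threadIdx.x;

    if (col >= ncols) {
        return;
    }

    const float f = has_factor ? factors[row] : 1.0f;
    const int i = row*ncols + col;

    dst[i] = (T)((float)x[i] * f);
}

template<typename T>
static void scale_rows_cuda(const T * x, T * dst, const int ncols, const int nrows,
                            const float * factors, cudaStream_t stream) {
    const dim3 block_dims(256, 1, 1);
    const dim3 block_nums(nrows, (ncols + 255)/256, 1);
    if (factors == nullptr) {
        scale_rows<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, factors);
    } else {
        scale_rows<T, true ><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, factors);
    }
}

The nullptr check is made once per launch to pick the instantiation, so the per-element branch disappears from the generated kernel, which is the same trade the rope/rope_neox launchers make above.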

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens

 GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor);

 GGML_API void ggml_cuda_set_main_device(int main_device);
 GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
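
The new header entry exposes `ggml_cuda_copy_to_device` from the ggml-cuda.cu hunk above: it copies a contiguous tensor's host data into the device buffer already assigned to it on the main device. A hypothetical caller-side sketch follows; the `upload_positions` helper and its setup are invented for illustration, and the commit's actual call sites are not shown in this excerpt:

#include <stdint.h>
#include <string.h>

#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#endif

// Hypothetical helper: write host-side data into an I32 tensor, then push it
// into the CUDA buffer that was previously assigned to that tensor.
static void upload_positions(struct ggml_tensor * pos_tensor, const int32_t * positions, int n) {
    GGML_ASSERT(pos_tensor->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_nelements(pos_tensor) >= n);

    memcpy(pos_tensor->data, positions, n*sizeof(int32_t)); // host copy first

#ifdef GGML_USE_CUBLAS
    if (pos_tensor->backend == GGML_BACKEND_GPU) {
        ggml_cuda_copy_to_device(pos_tensor); // then mirror it in the device buffer
    }
#endif
}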

ggml.c

Lines changed: 1 addition & 1 deletion
@@ -6343,7 +6343,7 @@ static struct ggml_tensor * ggml_cpy_impl(
     }

     // make a view of the destination
-    struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+    struct ggml_tensor * result = b->op == GGML_OP_NONE ? b : ggml_view_tensor(ctx, b);
     if (strlen(b->name) > 0) {
         ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
     } else {
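
The ggml.c change affects what `ggml_cpy` returns: when the destination `b` is a plain leaf tensor (`b->op == GGML_OP_NONE`), the copy node now reuses `b` itself instead of wrapping it in a view. A small illustration, with a hypothetical helper name and tensor sizes:

#include "ggml.h"

// Hypothetical illustration of the ggml_cpy() return-value change.
static void cpy_into_leaf_example(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16); // source
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16); // leaf destination, b->op == GGML_OP_NONE

    struct ggml_tensor * c = ggml_cpy(ctx, a, b);
    // Before this commit, c was always a fresh view of b; with the change
    // above, c is b itself because b is a leaf.
    GGML_ASSERT(c == b);
}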
