
Commit 488c1fc

ggml-cuda : add rope f16, restore performance

1 parent 6028879 commit 488c1fc

File tree: ggml-cuda.cu, ggml-cuda.h, ggml.c, llama.cpp

4 files changed, +87 -55 lines

ggml-cuda.cu

Lines changed: 78 additions & 51 deletions
@@ -439,7 +439,6 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
-    bool copied;
 };

 // this is faster on Windows
@@ -4357,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,

 // rope == RoPE == rotary positional embedding

-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                            const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4369,7 +4369,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int i = row*ncols + col;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
+    const int p = has_pos ? pos[i2] : 0;
     const float p0 = p * freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
@@ -4382,8 +4382,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }

-static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                     const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                 const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4394,7 +4395,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
+    const int p = has_pos ? pos[i2] : 0;
     const float p0 = p * freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
@@ -5371,22 +5372,32 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }

-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template<typename T>
+static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

-static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template<typename T>
+static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                                const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
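
Note: the templated host wrappers above move the position check out of the kernel. The branch on pos runs once at launch time on the host, and the chosen specialization bakes has_pos in as a compile-time constant, so the false variant never reads pos[]. A minimal self-contained sketch of the same dispatch idiom; the add_pos kernel and names below are illustrative only and not part of this commit:

#include <cstdint>
#include <cuda_runtime.h>

// Sketch only: has_pos is a template (compile-time) parameter, mirroring the
// rope/rope_neox kernels above; the <false> instantiation never touches pos[].
template<bool has_pos>
static __global__ void add_pos(const float * x, float * dst, const int32_t * pos, const int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    const int p = has_pos ? pos[i] : 0;
    dst[i] = x[i] + (float) p;
}

static void add_pos_cuda(const float * x, float * dst, const int32_t * pos, const int n, cudaStream_t stream) {
    const int block_size = 256;
    const int num_blocks = (n + block_size - 1) / block_size;
    // pick the specialization once on the host instead of branching per element
    if (pos == nullptr) {
        add_pos<false><<<num_blocks, block_size, 0, stream>>>(x, dst, pos, n);
    } else {
        add_pos<true><<<num_blocks, block_size, 0, stream>>>(x, dst, pos, n);
    }
}
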
@@ -6036,7 +6047,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const int64_t ne0 = dst->ne[0];
     const int64_t row_diff = row_high - row_low;

-    float * src0_ddq_as_f32;
+    float * src0_ddq_as_f32 = nullptr;
     size_t src0_as = 0;

     if (src0->type != GGML_TYPE_F32) {
@@ -6074,8 +6085,9 @@ inline void ggml_cuda_op_rope(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);

     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -6093,23 +6105,16 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    // const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;

-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(src1->ne[0] == ne2);
-    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
-
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-
-    int * pos = nullptr;
+    int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
+        GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        GGML_ASSERT(src1->ne[0] == ne2);
+        GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
         struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-        pos = (int *) src1_extra->data_device[id];
-        if (!src1_extra->copied) {
-            CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
-            src1_extra->copied = true;
-        }
+        int id;
+        CUDA_CHECK(cudaGetDevice(&id));
+        pos = (int32_t *) src1_extra->data_device[id];
     }

     const bool is_neox = mode & 2;
@@ -6121,9 +6126,21 @@ inline void ggml_cuda_op_rope(
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     } else {
-        rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     }

     (void) src1;
@@ -6294,7 +6311,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     }
 }

-void ggml_cuda_set_peer_access(const int n_tokens) {
+static void ggml_cuda_set_peer_access(const int n_tokens) {
     static bool peer_access_enabled = false;

     const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
@@ -6622,27 +6639,27 @@ static void ggml_cuda_op_mul_mat(
     }
 }

-void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }

-void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }

-void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }

-void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }

-void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }

-void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }

@@ -6663,7 +6680,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }

-void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -6692,7 +6709,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }

-void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@@ -6726,7 +6743,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;

@@ -6770,11 +6787,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }

-void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }

-void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));

@@ -6822,29 +6839,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     (void) dst;
 }

-void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }

-void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
 }

-void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
 }

-void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
 }

-void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

-void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
     (void) dst;
@@ -6967,11 +6984,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
+static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }

+    tensor->backend = GGML_BACKEND_GPU;
+
     // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
@@ -6983,8 +7002,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }

-    tensor->backend = GGML_BACKEND_GPU;
-
     if (scratch && no_alloc) {
         return;
     }
@@ -7069,6 +7086,16 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
     tensor->extra = extra;
 }

+void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+
+    struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    CUDA_CHECK(cudaMemcpyAsync(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice, main_stream));
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, true, false, false);
 }

ggml-cuda.h

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens

 GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor);

 GGML_API void ggml_cuda_set_main_device(int main_device);
 GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
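
The new ggml_cuda_copy_to_device entry point declared above uploads a contiguous GPU-backend tensor's host data to its assigned device buffer on the main stream. A hedged usage sketch mirroring the llama.cpp call site later in this commit; upload_gpu_leafs and buf_alloc_base are illustrative names, not part of the API:

#include "ggml.h"
#include "ggml-cuda.h"

// Sketch: after scratch offsets have been assigned, push the host data of every
// GPU leaf that does not yet have a device-side extra onto the device.
static void upload_gpu_leafs(struct ggml_cgraph * gf, const char * buf_alloc_base) {
    for (int i = 0; i < gf->n_leafs; ++i) {
        struct ggml_tensor * node = gf->leafs[i];
        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
            ggml_cuda_assign_scratch_offset(node, (const char *) node->data - buf_alloc_base);
            ggml_cuda_copy_to_device(node); // asynchronous host-to-device copy on the main stream
        }
    }
}
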

ggml.c

Lines changed: 1 addition & 1 deletion

@@ -6343,7 +6343,7 @@ static struct ggml_tensor * ggml_cpy_impl(
     }

     // make a view of the destination
-    struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+    struct ggml_tensor * result = b->op == GGML_OP_NONE ? b : ggml_view_tensor(ctx, b);
     if (strlen(b->name) > 0) {
         ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
     } else {

llama.cpp

Lines changed: 7 additions & 3 deletions

@@ -1256,10 +1256,10 @@ static bool llama_kv_cache_init(

     (void) n_gpu_layers;
 #ifdef GGML_USE_CUBLAS
-    if (n_gpu_layers > n_layer + 1) {
+    if (n_gpu_layers > (int)n_layer + 1) {
         ggml_cuda_assign_buffers_no_scratch(cache.v);
     }
-    if (n_gpu_layers > n_layer + 2) {
+    if (n_gpu_layers > (int)n_layer + 2) {
         ggml_cuda_assign_buffers_no_scratch(cache.k);
     }
 #endif // GGML_USE_CUBLAS
@@ -2619,7 +2619,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int n_gpu_layers = model.n_gpu_layers;

     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
+    const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : std::max(1, (int)kv_self.cell_max);
     const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;

     const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
@@ -2700,6 +2700,7 @@ static struct ggml_cgraph * llm_build_llama(

     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    offload_func_kq(KQ_mask);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
@@ -2722,6 +2723,7 @@ static struct ggml_cgraph * llm_build_llama(
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     offload_func_kq(KQ_pos);
+    ggml_set_name(KQ_pos, "KQ_pos");
     ggml_allocr_alloc(lctx.alloc, KQ_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) KQ_pos->data;
@@ -2734,6 +2736,7 @@ static struct ggml_cgraph * llm_build_llama(
     if (do_rope_shift) {
         struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
         offload_func_kq(K_shift);
+        ggml_set_name(K_shift, "K_shift");
         ggml_allocr_alloc(lctx.alloc, K_shift);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             int * data = (int *) K_shift->data;
@@ -4116,6 +4119,7 @@ static int llama_decode_internal(
         ggml_tensor * node = gf->leafs[i];
         if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
             ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+            ggml_cuda_copy_to_device(node);
         }
     }
