@@ -3887,13 +3887,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
         return;
     }
 
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
 
     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
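For reference, the angle computed on the last line above is used to rotate each adjacent pair of values in the row; that rotation itself is unchanged by this diff. A minimal standalone sketch of the per-pair rotation (illustrative only, plain C with math.h, not copied from the kernel body):

#include <math.h>

// Rotate the adjacent pair (x[i], x[i + 1]) by theta, writing the result to dst.
static void rope_rotate_pair(const float * x, float * dst, int i, float theta) {
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + 1];

    dst[i + 0] = x0*cos_theta - x1*sin_theta;
    dst[i + 1] = x0*sin_theta + x1*cos_theta;
}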
@@ -3965,8 +3965,8 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }
 
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (col >= ncols) {
         return;
@@ -3982,9 +3982,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // values are also not normalized to the maximum value by subtracting it in the exponential function
 // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
 
     float tmp = 0.0;
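The two comments above contrast this kernel with the standard numerically stable softmax, which subtracts the per-row maximum before exponentiating so that exp() cannot overflow. A minimal host-side sketch of that stable variant, for comparison only (not part of this file):

#include <math.h>

// Stable softmax over one row of ncols values: shift by the maximum, exponentiate, normalize.
static void soft_max_row_stable(const float * x, float * dst, int ncols) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        max_val = fmaxf(max_val, x[i]);
    }

    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = expf(x[i] - max_val); // exponent is never positive, so no overflow
        sum += dst[i];
    }

    for (int i = 0; i < ncols; ++i) {
        dst[i] /= sum;
    }
}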
@@ -4776,9 +4776,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
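Note on the swapped launch geometry: CUDA caps gridDim.y and gridDim.z at 65535 blocks while gridDim.x allows far more, so mapping the row count to the x dimension (and the column blocks to y) keeps launches valid when a tensor has many rows. A small standalone check of the per-device limits (a sketch, not part of this commit):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    // Typically x allows 2147483647 blocks while y and z are capped at 65535.
    printf("maxGridSize: x=%d y=%d z=%d\n", prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    return 0;
}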
@@ -4800,15 +4800,15 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
 }
 
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
@@ -6313,7 +6313,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6322,14 +6322,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6381,16 +6386,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }
 
+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }
 
 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
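Caller-side sketch of the new two-step path added above (hypothetical example: the helper name and planned_offset are placeholders; only ggml_cuda_assign_buffers_no_alloc and ggml_cuda_assign_scratch_offset come from this diff):

// Mark the tensor as GPU-backed without touching the scratch buffer, then bind
// it to a precomputed offset once the caller has planned the scratch layout.
static void assign_tensor_with_planned_offset(struct ggml_tensor * t, size_t planned_offset) {
    ggml_cuda_assign_buffers_no_alloc(t);                // sets backend, defers allocation
    ggml_cuda_assign_scratch_offset(t, planned_offset);  // points the tensor at scratch + offset
}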