@@ -3886,13 +3886,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
3886
3886
// rope == RoPE == rotary positional embedding
3887
3887
static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const float p0,
3888
3888
const float p_delta, const int p_delta_rows, const float theta_scale) {
3889
- const int col = 2 *(blockDim .x *blockIdx .x + threadIdx .x );
3889
+ const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
3890
3890
3891
3891
if (col >= ncols) {
3892
3892
return ;
3893
3893
}
3894
3894
3895
- const int row = blockDim .y *blockIdx .y + threadIdx .y ;
3895
+ const int row = blockDim .x *blockIdx .x + threadIdx .x ;
3896
3896
const int i = row*ncols + col;
3897
3897
3898
3898
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf (theta_scale, col/2 );
@@ -3941,8 +3941,8 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
3941
3941
}
3942
3942
3943
3943
static __global__ void diag_mask_inf_f32 (const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
3944
- const int col = blockDim .x *blockIdx .x + threadIdx .x ;
3945
- const int row = blockDim .y *blockIdx .y + threadIdx .y ;
3944
+ const int col = blockDim .y *blockIdx .y + threadIdx .y ;
3945
+ const int row = blockDim .x *blockIdx .x + threadIdx .x ;
3946
3946
3947
3947
if (col >= ncols) {
3948
3948
return ;
@@ -3958,9 +3958,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
3958
3958
// values are also not normalized to the maximum value by subtracting it in the exponential function
3959
3959
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
3960
3960
static __global__ void soft_max_f32 (const float * x, float * dst, const int ncols) {
3961
- const int row = blockDim .y *blockIdx .y + threadIdx .y ;
3962
- const int block_size = blockDim .x ;
3963
- const int tid = threadIdx .x ;
3961
+ const int row = blockDim .x *blockIdx .x + threadIdx .x ;
3962
+ const int block_size = blockDim .y ;
3963
+ const int tid = threadIdx .y ;
3964
3964
3965
3965
float tmp = 0.0 ;
3966
3966
@@ -4752,9 +4752,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
4752
4752
// Host-side launcher for the rope_f32 kernel; the launch is enqueued on `stream`.
//
// Launch layout (post-patch state of this diff):
//   block = (1, 2*CUDA_ROPE_BLOCK_SIZE, 1)  -> threads of a block are laid out along y
//   grid  = (nrows, num_blocks_y, 1)        -> grid x indexes rows, grid y tiles the columns
// The kernel is expected to derive `row` from the x index and `col` from the y index.
// The CUDA grid limit is 2^31-1 blocks along x but only 65535 along y, so placing the
// rows on x presumably avoids hitting the y limit for large nrows (confirm against the
// commit message).
//
// Parameters:
//   x, dst        input / output buffers handed straight to the kernel
//   ncols, nrows  extent of the data; nrows determines the grid x size
//   p0, p_delta, p_delta_rows, theta_scale
//                 forwarded unchanged to the kernel
//   stream        CUDA stream the kernel is launched on
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
    // NOTE(review): the kernel steps `col` in units of 2; if it also touches col + 1
    // (kernel body not fully visible here) the real precondition would be an even
    // ncols rather than an even nrows -- confirm which one this assert is meant to guard.
    GGML_ASSERT(nrows % 2 == 0);

    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);

    // Ceil-divide the columns into chunks of 2*CUDA_ROPE_BLOCK_SIZE.
    // NOTE(review): with the kernel mapping each y-thread to col = 2*idx, one block of
    // 2*CUDA_ROPE_BLOCK_SIZE threads spans 4*CUDA_ROPE_BLOCK_SIZE columns, so this count
    // looks about 2x larger than required; surplus threads are expected to leave through
    // the kernel's `col >= ncols` guard.
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);

    // Despite its historical name, num_blocks_x now sizes the grid's y dimension.
    const dim3 block_nums(nrows, num_blocks_x, 1);

    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}
4760
4760
@@ -4767,15 +4767,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
4767
4767
}
4768
4768
4769
4769
// Host-side launcher for the diag_mask_inf_f32 kernel; the launch is enqueued on `stream`.
//
// Launch layout (post-patch state of this diff):
//   block = (1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1)  -> threads of a block are laid out along y
//   grid  = (nrows_x, block_num_y, 1)              -> grid x indexes rows, grid y tiles the columns
// The kernel is expected to derive `row` from the x index and `col` from the y index.
// The CUDA grid limit is 2^31-1 blocks along x but only 65535 along y, so placing the
// rows on x presumably avoids hitting the y limit for large nrows_x (confirm against the
// commit message).
//
// Parameters:
//   x, dst            input / output buffers handed straight to the kernel
//   ncols_x, nrows_x  extent of the data; nrows_x determines the grid x size
//   rows_per_channel, n_past
//                     forwarded unchanged to the kernel
//   stream            CUDA stream the kernel is launched on
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);

    // Ceil-divide so that a trailing partial chunk of columns still gets a block.
    // Despite its historical name, block_num_x now sizes the grid's y dimension.
    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;

    const dim3 block_nums(nrows_x, block_num_x, 1);

    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
4775
4775
4776
4776
// Host-side launcher for the soft_max_f32 kernel; the launch is enqueued on `stream`.
//
// Launch layout (post-patch state of this diff):
//   block = (1, WARP_SIZE, 1)  -> WARP_SIZE threads per block, laid out along y
//   grid  = (nrows_x, 1, 1)    -> one block per row, indexed along grid x
// The kernel is expected to take its row from the x index and to use blockDim.y /
// threadIdx.y as the per-row block size and thread id.
// The CUDA grid limit is 2^31-1 blocks along x but only 65535 along y, so placing the
// rows on x presumably avoids hitting the y limit for large nrows_x (confirm against the
// commit message).
//
// Parameters:
//   x, dst    input / output buffers handed straight to the kernel
//   ncols_x   number of columns, forwarded to the kernel
//   nrows_x   number of rows; determines the grid x size
//   stream    CUDA stream the kernel is launched on
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
    const dim3 block_dims(1, WARP_SIZE, 1);
    const dim3 block_nums(nrows_x, 1, 1);

    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}
4781
4781
0 commit comments