Commit 2bd1b02

fixed 2048 batch size
1 parent 188a285 commit 2bd1b02

1 file changed: ggml-cuda.cu (13 additions & 13 deletions)
@@ -3886,13 +3886,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
         return;
     }
 
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
 
     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3941,8 +3941,8 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
 }
 
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (col >= ncols) {
         return;
@@ -3958,9 +3958,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // values are also not normalized to the maximum value by subtracting it in the exponential function
 // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
 
     float tmp = 0.0;
 
@@ -4752,9 +4752,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
@@ -4767,15 +4767,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
 }
 
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }

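Note on the change: the commit message only says it fixes the 2048 batch size, but the likely mechanism (an inference, not stated in the commit) is a CUDA launch limit. On devices of compute capability 3.0 and newer, gridDim.x may hold up to 2^31 - 1 blocks, while gridDim.y is capped at 65535. With a 2048-token batch the row count handled by these kernels can exceed 65535, so a launch that indexes rows through the y grid dimension fails; after this commit rows go through x and column blocks through y. The following standalone sketch illustrates that convention; fill_rows and its parameters are hypothetical and not part of ggml-cuda.cu.

// Minimal sketch, assuming compute capability >= 3.0 (gridDim.x limit 2^31 - 1).
// fill_rows is a made-up kernel for illustration only.
#include <cstdio>
#include <cuda_runtime.h>

static __global__ void fill_rows(float * dst, const int ncols) {
    const int col = blockDim.y*blockIdx.y + threadIdx.y; // columns on y, as in rope_f32 above
    const int row = blockDim.x*blockIdx.x + threadIdx.x; // rows on x, so the row count is not capped at 65535

    if (col >= ncols) {
        return;
    }
    dst[row*ncols + col] = (float) row;
}

int main() {
    const int ncols      = 128;
    const int nrows      = 100000; // > 65535: too many blocks for gridDim.y, fine for gridDim.x
    const int block_cols = 32;

    float * dst;
    cudaMalloc(&dst, (size_t) nrows*ncols*sizeof(float));

    // Same convention as the launch helpers above: rows on x, column blocks on y.
    const dim3 block_dims(1, block_cols, 1);
    const dim3 block_nums(nrows, (ncols + block_cols - 1) / block_cols, 1);
    fill_rows<<<block_nums, block_dims>>>(dst, ncols);

    // With the pre-commit layout, block_nums = (num_blocks_x, nrows, 1), this launch
    // would fail with cudaErrorInvalidConfiguration because nrows > 65535.
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaFree(dst);
    return 0;
}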