
Commit 2bfb39a

ggml-cuda : use graph allocator
use a different function for no_alloc to avoid breaking backwards compat, fixes lora
remove 512 n_batch limit
fixed 2048 batch size
cleanup

Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 5e9ff54 commit 2bfb39a

4 files changed: +92 -228 lines

examples/common.cpp

Lines changed: 0 additions & 1 deletion
@@ -301,7 +301,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;

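A small illustration, not part of the commit, of what removing this clamp changes for the -b/--batch-size flag: the value given on the command line is now used as-is instead of being silently capped at 512.

// Hypothetical standalone sketch: old vs. new handling of `-b 2048`.
#include <algorithm>
#include <string>

int main() {
    const int requested    = std::stoi("2048");        // value the user passed with -b
    const int old_behavior = std::min(512, requested);  // before this commit: clamped to 512
    const int new_behavior = requested;                 // after this commit: 2048 is kept
    return (old_behavior == 512 && new_behavior == 2048) ? 0 : 1;
}
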
ggml-cuda.cu

Lines changed: 56 additions & 19 deletions
@@ -3886,13 +3886,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
         return;
     }

-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;

     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3941,8 +3941,8 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
 }

 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;

     if (col >= ncols) {
         return;
@@ -3958,9 +3958,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // values are also not normalized to the maximum value by subtracting it in the exponential function
 // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;

     float tmp = 0.0;

@@ -4752,9 +4752,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }

@@ -4767,15 +4767,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
 }

 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }

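The hunks above swap the grid axes: the row index now comes from the grid's x dimension and the column work from y. A likely reason is that CUDA caps gridDim.y and gridDim.z at 65535 blocks while gridDim.x allows up to 2^31 - 1, so the dimension that grows with the number of rows being processed should sit on x; this lines up with the "fixed 2048 batch size" note in the commit message. Below is a minimal, hypothetical sketch (not from the commit) of the same launch pattern:

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical kernel using the same index mapping as the updated rope/diag_mask/soft_max
// kernels: rows come from the x grid dimension, columns from the y dimension.
__global__ void fill_rows(float * dst, const int ncols) {
    const int row = blockDim.x*blockIdx.x + threadIdx.x; // rows on x: up to 2^31 - 1 blocks
    const int col = blockDim.y*blockIdx.y + threadIdx.y; // cols on y: small, 65535 blocks is plenty

    if (col >= ncols) {
        return;
    }
    dst[row*ncols + col] = (float) row;
}

int main() {
    const int nrows = 100000; // more rows than the 65535 block limit of gridDim.y
    const int ncols = 64;

    float * dst;
    cudaMalloc(&dst, (size_t) nrows*ncols*sizeof(float));

    const dim3 block_dims(1, 32, 1);                    // one row per block, 32 columns per thread
    const dim3 block_nums(nrows, (ncols + 31) / 32, 1); // the large dimension goes on x

    fill_rows<<<block_nums, block_dims>>>(dst, ncols);
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaDeviceSynchronize();
    cudaFree(dst);
    return 0;
}
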
@@ -6240,7 +6240,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6249,14 +6249,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }

     tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;

     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6308,16 +6313,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }

+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }

 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }

 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }

 void ggml_cuda_set_main_device(int main_device) {

ggml-cuda.h

Lines changed: 5 additions & 0 deletions
@@ -16,9 +16,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void ggml_cuda_set_main_device(int main_device);
 GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);

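The two new entry points split buffer assignment into a planning step and an offset-assignment step. A hedged usage sketch follows; the real integration is in llama.cpp (changed by this commit but not shown on this page), and planned_offset below is a hypothetical stand-in for whatever offset the graph allocator computes for the tensor:

#include "ggml.h"
#include "ggml-cuda.h"

// Hypothetical helper, not from the commit: two-phase assignment of a tensor to the
// CUDA scratch buffer when its offset is planned ahead of time by the graph allocator.
static void assign_with_planned_offset(struct ggml_tensor * t, size_t planned_offset) {
    // Phase 1, at graph construction time: mark the tensor as GPU-resident without
    // reserving scratch memory yet (the no_alloc path added in this commit).
    ggml_cuda_assign_buffers_no_alloc(t);

    // Phase 2, once the allocator has decided where the tensor lives: point it at
    // its slot inside the shared scratch buffer.
    ggml_cuda_assign_scratch_offset(t, planned_offset);
}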