
Commit 1123f7f

ggml-cuda : use graph allocator (#2684)
use a different function for no_alloc to avoid breaking backwards compat, fixes lora
remove 512 n_batch limit
fixed 2048 batch size
cleanup

Co-authored-by: Johannes Gäßler <[email protected]>
1 parent ef3f333 commit 1123f7f

4 files changed: 92 additions, 228 deletions

common/common.cpp

Lines changed: 0 additions & 1 deletion
@@ -289,7 +289,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;

ggml-cuda.cu

Lines changed: 56 additions & 19 deletions
@@ -3887,13 +3887,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
         return;
     }
 
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
 
     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3965,8 +3965,8 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }
 
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (col >= ncols) {
         return;
@@ -3982,9 +3982,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // values are also not normalized to the maximum value by subtracting it in the exponential function
 // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
 
     float tmp = 0.0;
 
@@ -4776,9 +4776,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
@@ -4800,15 +4800,15 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
 }
 
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
@@ -6313,7 +6313,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6322,14 +6322,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6381,16 +6386,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }
 
+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }
 
 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
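Note on the axis swap in the kernels above: the patch moves the row index onto the x axis of the grid and the column work onto y. A plausible reason (consistent with the "fixed 2048 batch size" note in the commit message) is that gridDim.y is limited to 65535 blocks while gridDim.x allows up to 2^31 - 1, and the row count of these kernels grows with the batch size once the 512 n_batch cap is removed. The standalone sketch below is illustrative only and is not part of the commit; the kernel and variable names are made up to demonstrate the launch geometry.

// Illustrative sketch: rows on the x grid axis (large limit), columns on y (capped at 65535).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void copy_rows_on_x(const float * x, float * dst, const int ncols) {
    const int row = blockDim.x*blockIdx.x + threadIdx.x; // blockDim.x == 1, so row == blockIdx.x
    const int col = blockDim.y*blockIdx.y + threadIdx.y; // column work split across the y axis
    if (col >= ncols) {
        return;
    }
    dst[row*ncols + col] = x[row*ncols + col];
}

int main() {
    const int nrows = 70000; // more rows than gridDim.y could address (65535)
    const int ncols = 128;
    float *x, *dst;
    cudaMalloc(&x,   (size_t)nrows*ncols*sizeof(float));
    cudaMalloc(&dst, (size_t)nrows*ncols*sizeof(float));

    const dim3 block_dims(1, 128, 1);                   // threads along y, as in the patched kernels
    const dim3 block_nums(nrows, (ncols + 127)/128, 1); // rows on x, column blocks on y
    copy_rows_on_x<<<block_nums, block_dims>>>(x, dst, ncols);
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaFree(x);
    cudaFree(dst);
    return 0;
}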

ggml-cuda.h

Lines changed: 5 additions & 0 deletions
@@ -16,9 +16,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void ggml_cuda_set_main_device(int main_device);
 GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
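The two new entry points split buffer assignment into a metadata pass and an address pass, which is what lets a graph allocator plan scratch usage before any device addresses are fixed. The helper below is a hedged sketch of the intended call order only: `assign_planned_scratch`, its `tensors` list, and the `offsets` it receives are hypothetical stand-ins for the real wiring in llama.cpp and the graph allocator, which are not part of this excerpt.

#include <cstddef>
#include <vector>
#include "ggml.h"
#include "ggml-cuda.h"

// Hypothetical helper (not from the commit): bind a planned graph to the CUDA scratch buffer.
// offsets[i] stands in for whatever scratch offset the graph allocator chose for tensors[i].
static void assign_planned_scratch(const std::vector<ggml_tensor *> & tensors,
                                   const std::vector<size_t> & offsets) {
    // pass 1: mark tensors as GPU-backed without touching the scratch buffer
    for (ggml_tensor * t : tensors) {
        ggml_cuda_assign_buffers_no_alloc(t);
    }
    // pass 2: point each tensor at its planned offset inside the CUDA scratch buffer
    for (size_t i = 0; i < tensors.size(); ++i) {
        ggml_cuda_assign_scratch_offset(tensors[i], offsets[i]);
    }
}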
