Commit 206e01d

li-plus and ggerganov authored
cuda : support broadcast add & mul (#2192)
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 4304bd3 commit 206e01d

File tree

1 file changed: +12 −21 lines


ggml-cuda.cu

Lines changed: 12 additions & 21 deletions
@@ -252,13 +252,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i >= kx) {
         return;
     }
-    dst[i] = x[i] + y[i];
+    dst[i] = x[i] + y[i%ky];
 }
 
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
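This first hunk is the core of the change: add_f32 no longer assumes x and y have the same length. The second operand is indexed modulo ky, so a smaller src1 (a bias row, for example) is tiled across src0. A minimal CPU reference of the same indexing, purely for illustration (add_broadcast_ref is a hypothetical helper, not part of this commit; it assumes ky divides kx, as ggml's broadcasting does):

```cpp
#include <cassert>
#include <vector>

// CPU sketch of the new add_f32 semantics: y of length ky is tiled across
// x of length kx via the same i % ky indexing used by the kernel.
static std::vector<float> add_broadcast_ref(const std::vector<float> & x,
                                            const std::vector<float> & y) {
    const int kx = (int) x.size();
    const int ky = (int) y.size();
    assert(ky > 0 && kx % ky == 0); // src1 must tile src0 evenly

    std::vector<float> dst(kx);
    for (int i = 0; i < kx; i++) {
        dst[i] = x[i] + y[i % ky]; // matches dst[i] = x[i] + y[i%ky] on the GPU
    }
    return dst;
}
```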
@@ -1996,9 +1996,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
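The launcher gains the ky parameter and forwards it; the grid is still sized from kx alone, using the usual round-up division so every output element gets a thread even when kx is not a multiple of CUDA_ADD_BLOCK_SIZE (threads past the end exit through the i >= kx guard). A hypothetical standalone launch, assuming a block size of 256 and device buffers d_x, d_y, d_dst that are already allocated and filled:

```cpp
// Sketch only: broadcast-add a 64-float row (d_y) onto a 64x64 matrix (d_x).
const int kx = 64 * 64;                    // elements of src0
const int ky = 64;                         // elements of src1 (one row)
const int block_size = 256;                // stand-in for CUDA_ADD_BLOCK_SIZE
const int num_blocks = (kx + block_size - 1) / block_size; // rounds up -> 16
add_f32<<<num_blocks, block_size, 0, 0>>>(d_x, d_y, d_dst, kx, ky);
```

For a size that does not divide evenly, say kx = 1000, the same formula gives (1000 + 255) / 256 = 4 blocks, i.e. 1024 threads, of which the last 24 return early.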
@@ -2610,17 +2610,15 @@ inline void ggml_cuda_op_add(
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
-    // TODO: support broadcasting
-    GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
-
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
-    // const int64_t ne10 = src1->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
 
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
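With the kernel broadcasting, ggml_cuda_op_add can drop the TODO and the equal-size assertion: kx is the number of src0 elements in this device slice (ne00*i01_diff) and ky is the number of elements in src1's 2-D plane (ne10*ne11). At the graph level, this is what lets an add of a smaller tensor run on the CUDA backend; a hedged sketch of standard ggml API usage (not code from this commit):

```cpp
// Adding a per-row bias of shape [ne0, 1] to a matrix of shape [ne0, ne1];
// on CUDA the bias is now repeated across rows by the i % ky indexing
// instead of tripping the removed
// GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1)).
struct ggml_tensor * a    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1);
struct ggml_tensor * bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, 1);
struct ggml_tensor * out  = ggml_add(ctx, a, bias);
```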
@@ -2644,19 +2642,12 @@ inline void ggml_cuda_op_mul(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
-
-        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
-
-        // compute
-        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-    }
+    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
 
     (void) dst;
     (void) src0_ddq_i;
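ggml_cuda_op_mul gets the symmetric treatment: the old code looped over rows and issued one mul_f32_cuda launch per row, computing the broadcast offset for src1 on the host; now a single launch covers the whole ne00*i01_diff slice and the broadcast happens on the device, saving i01_diff - 1 kernel launches per slice. This works because mul_f32 (already present in ggml-cuda.cu before this commit) uses the same modulo indexing as the new add_f32; for reference, it looks essentially like this:

```cpp
// The pre-existing broadcast kernel that the flattened launch relies on.
static __global__ void mul_f32(const float * x, const float * y, float * dst,
                               const int kx, const int ky) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= kx) {
        return;
    }
    dst[i] = x[i] * y[i%ky]; // src1 repeats with period ky
}
```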
