Commit f7d278f

ggml : revert CUDA broadcast changes from #2183 (#2191)
Parent: 20d7740


ggml-cuda.cu

Lines changed: 23 additions & 12 deletions
@@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= kx) {
+    if (i >= k) {
         return;
     }
-    dst[i] = x[i] + y[i%ky];
+    dst[i] = x[i] + y[i];
 }
 
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
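The hunk above removes the modulo-based broadcast that #2183 added: with the second extent ky gone, add_f32 is again a purely elementwise kernel over k floats, one thread per element. A minimal standalone sketch of the restored behaviour (the file name, driver, and test values are mine, not part of the commit):

    // demo_add.cu -- illustrative driver around the reverted kernel
    #include <cstdio>
    #include <cuda_runtime.h>

    #define CUDA_ADD_BLOCK_SIZE 256 // same constant name the launcher uses in ggml-cuda.cu

    static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i >= k) {
            return; // surplus threads in the last block do nothing
        }
        dst[i] = x[i] + y[i]; // elementwise; no y[i%ky] wrap-around after the revert
    }

    int main() {
        const int k = 1000;
        float *x, *y, *dst;
        cudaMallocManaged(&x,   k*sizeof(float));
        cudaMallocManaged(&y,   k*sizeof(float));
        cudaMallocManaged(&dst, k*sizeof(float));
        for (int i = 0; i < k; i++) { x[i] = 1.0f; y[i] = 2.0f; }

        const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
        add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE>>>(x, y, dst, k);
        cudaDeviceSynchronize();

        printf("dst[0] = %.1f, dst[%d] = %.1f\n", dst[0], k - 1, dst[k - 1]); // both 3.0
        cudaFree(x); cudaFree(y); cudaFree(dst);
        return 0;
    }

Under the #2183 version, y[i%ky] let a shorter y (for example a bias row) wrap around; after the revert any such shape mismatch has to be handled before reaching this kernel.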
@@ -1718,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
-    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
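The restored launcher keeps the usual integer ceiling division, so every element is covered by at least one thread and the i >= k guard in the kernel masks the surplus threads of the last block. A few worked values (hypothetical, not taken from the commit), checkable at compile time:

    // ceil_div(k, bs) = (k + bs - 1) / bs, as in add_f32_cuda above
    static_assert((1000 + 256 - 1) / 256 == 4, "1000 elements -> 4 blocks of 256, 24 threads masked");
    static_assert((1024 + 256 - 1) / 256 == 4, "exact multiples do not over-allocate a block");
    static_assert((   1 + 256 - 1) / 256 == 1, "any non-zero k still gets one block");

Plain integer division (1000/256 == 3) would leave the final 232 elements unwritten.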
@@ -2272,7 +2272,10 @@ inline void ggml_cuda_op_add(
 
     GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    // TODO: support broadcasting
+    GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
 
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
@@ -2281,7 +2284,7 @@ inline void ggml_cuda_op_add(
 
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
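Both hunks above belong to ggml_cuda_op_add: with broadcasting reverted, the op now explicitly asserts that both operands have the same total element count, where ggml_nelements is the product of a tensor's four ne[] dimensions. A small sketch of what the check rejects (sketch_tensor and sketch_nelements are illustrative stand-ins of mine, not ggml's real definitions from ggml.h):

    #include <cstdio>
    #include <cstdint>

    struct sketch_tensor {
        int64_t ne[4]; // element counts per dimension, as in ggml
    };

    static int64_t sketch_nelements(const sketch_tensor & t) {
        return t.ne[0]*t.ne[1]*t.ne[2]*t.ne[3]; // total element count
    }

    int main() {
        sketch_tensor src0 = { {4096, 32, 1, 1} }; // e.g. a full activation: 131072 elements
        sketch_tensor src1 = { {4096,  1, 1, 1} }; // a row that #2183 would broadcast: 4096 elements
        // After the revert, the CUDA add path requires equal element counts,
        // so a pair like this now trips GGML_ASSERT instead of broadcasting:
        std::printf("equal sizes: %s\n",
                    sketch_nelements(src0) == sketch_nelements(src1) ? "yes" : "no"); // prints "no"
        return 0;
    }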
@@ -2302,14 +2305,22 @@ inline void ggml_cuda_op_mul(
 
     GGML_ASSERT(src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
-
     const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
 
-    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
+    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
+        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
+
+        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
+        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
+        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
+
+        // compute
+        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
+    }
 
     (void) dst;
     (void) src0_ddq_i;
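The mul path, by contrast, goes back to the pre-#2183 per-row loop: row i01 of src0 is multiplied by row i01 % ne11 of src1, so a src1 with fewer rows repeats cyclically down src0, at the cost of one mul_f32_cuda launch per row instead of a single fused kernel. A CPU sketch of the same indexing (shapes and values are made up for illustration; i1 selects the higher-dimension slice and is fixed at 0 here):

    #include <cstdio>
    #include <cstdint>

    int main() {
        const int64_t ne00 = 4, ne10 = 4;   // row length of src0 and src1
        const int64_t nrows0 = 6, ne11 = 2; // src0 has 6 rows, src1 only 2
        const int64_t i1 = 0;               // single slice in this sketch

        float src0[6*4], src1[2*4], dst[6*4];
        for (int64_t i = 0; i < nrows0*ne00; i++) src0[i] = 1.0f;
        for (int64_t i = 0; i < ne11*ne10; i++)  src1[i] = float(i/ne10 + 1); // row 0 -> 1s, row 1 -> 2s

        for (int64_t i01 = 0; i01 < nrows0; i01++) {
            const int64_t i11 = i1*ne11 + i01%ne11; // same formula as the diff: broadcast src1 across src0
            for (int64_t j = 0; j < ne00; j++) {    // stands in for one mul_f32_cuda launch
                dst[i01*ne00 + j] = src0[i01*ne00 + j] * src1[i11*ne10 + j];
            }
        }

        // rows of dst come out 1,2,1,2,1,2 because src1's two rows repeat cyclically
        for (int64_t i01 = 0; i01 < nrows0; i01++) {
            std::printf("dst row %lld starts with %g\n", (long long) i01, dst[i01*ne00]);
        }
        return 0;
    }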
