
Commit e27f93e

refactor CUDA implementation for ACC

1 parent: 2f0b30d

File tree: 2 files changed, +41 −27 lines

ggml/src/ggml-cuda/acc.cu: 40 additions & 26 deletions

```diff
@@ -1,47 +1,61 @@
 #include "acc.cuh"

-static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset) {
-    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
+    const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+    const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
+    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+
     if (i >= ne) {
         return;
     }
-    int src1_idx = i - offset;
-    int oz = src1_idx / nb2;
-    int oy = (src1_idx - (oz * nb2)) / nb1;
-    int ox = src1_idx % nb1;
-    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
-        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
-    } else {
-        dst[i] = x[i];
+
+    int64_t src1_idx = i - offset;
+
+    int64_t tmp = src1_idx;
+    const int64_t i13 = tmp / s13;
+    tmp -= i13 * s13;
+    const int64_t i12 = tmp / s12;
+    tmp -= i12 * s12;
+    const int64_t i11 = tmp / s11;
+    tmp -= i11 * s11;
+    const int64_t i10 = tmp;
+
+    float val = x[i];
+    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
+        val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
     }
+    dst[i] = val;
 }

-static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
-    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
-    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
+    const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+    const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
+    const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
 }

 void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float * dst_d = (float *) dst->data;
+
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported

-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
+    GGML_ASSERT(ggml_is_contiguously_allocated(dst));
+
+    const int64_t s1     = dst->op_params[0] / sizeof(float);
+    const int64_t s2     = dst->op_params[1] / sizeof(float);
+    const int64_t s3     = dst->op_params[2] / sizeof(float);
+    const int64_t offset = dst->op_params[3] / sizeof(float);

-    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
+    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
 }
```
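The heart of the refactor is the kernel's index math: instead of a div/mod against two strides for at most 3D tensors, the flat destination index is now peeled apart against all four dst strides (s11, s12, s13, in elements), and the recovered coordinates are bounds-checked against src1's extents, with everything widened to int64_t. A thread whose element lies outside the strided view (negative src1_idx, or a coordinate past an ne1x bound, e.g. inside a row gap) simply copies x[i] through. Below is a standalone CPU sketch of the same arithmetic with hypothetical sizes; it is not part of the commit:

```cpp
// CPU sketch of acc_f32's index decomposition (hypothetical sizes).
// dst is a contiguous 6x4 f32 tensor; src1 is 3x2 and is accumulated
// into the view starting at element 1, i.e. columns 1..3 of rows 0..1.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne10 = 3, ne11 = 2, ne12 = 1, ne13 = 1; // src1 extents
    const int64_t s11 = 6;      // dst row stride in elements   (op_params[0] / sizeof(float))
    const int64_t s12 = 6 * 4;  // dst plane stride in elements (op_params[1] / sizeof(float))
    const int64_t s13 = 6 * 4;  // dst batch stride in elements (op_params[2] / sizeof(float))
    const int64_t offset = 1;   // view offset in elements      (op_params[3] / sizeof(float))

    for (int64_t i = 0; i < 6 * 4; ++i) { // one iteration per dst element, like one thread each
        int64_t tmp = i - offset;         // position relative to the view origin
        const int64_t src1_idx = tmp;
        const int64_t i13 = tmp / s13; tmp -= i13 * s13;
        const int64_t i12 = tmp / s12; tmp -= i12 * s12;
        const int64_t i11 = tmp / s11; tmp -= i11 * s11;
        const int64_t i10 = tmp;
        // The kernel computes val = x[i] and adds src1 only on a hit;
        // here we just report which dst elements receive a contribution.
        if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
            printf("dst[%2lld] += src1[%lld, %lld]\n",
                   (long long) i, (long long) i10, (long long) i11);
        }
    }
    return 0;
}
```

Because the strides used for the decomposition are dst's, not src1's, a flat index inside a row gap yields i10 >= ne10 and fails the bounds check, which is what makes one flat loop over dst correct for a strided view.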

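For context on where s1/s2/s3 and offset come from: for GGML_OP_ACC, dst->op_params carry the byte strides and byte offset of the view of src0 into which src1 is accumulated, which is why the host code divides them by sizeof(float). A sketch of graph-building code that produces such an op, assuming the standard ggml.h API (the function name and shapes here are hypothetical):

```cpp
// Hypothetical graph-building sketch, not part of this commit.
#include "ggml.h"

struct ggml_tensor * build_acc(struct ggml_context * ctx,
                               struct ggml_tensor * a,   // e.g. F32, ne = [6, 4, 1, 1]
                               struct ggml_tensor * b) { // e.g. F32, ne = [3, 2, 1, 1]
    // Accumulate b into the view of a that starts one float in and keeps
    // a's own row/plane/batch strides. All four parameters are in bytes
    // and land in dst->op_params[0..3], consumed by ggml_cuda_op_acc above.
    return ggml_acc(ctx, a, b,
                    a->nb[1],       // nb1: row stride of the view
                    a->nb[2],       // nb2: plane stride
                    a->nb[3],       // nb3: batch stride
                    sizeof(float)); // offset of the view into a
}
```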
ggml/src/ggml-cuda/sum.cu: 1 addition & 1 deletion

```diff
@@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguously_allocated(src0));

     const float * src0_d = (const float *) src0->data;
     float * dst_d = (float *) dst->data;
```