Commit 971bd95

ggml: CUDA unary op EXP
Signed-off-by: Molly Sophia <[email protected]>
1 parent bdf314f commit 971bd95

3 files changed: +35 -0 lines changed


ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -2220,6 +2220,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_HARDSWISH:
                     ggml_cuda_op_hardswish(ctx, dst);
                     break;
+                case GGML_UNARY_OP_EXP:
+                    ggml_cuda_op_exp(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -2749,6 +2752,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
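
With the dispatch and `ggml_backend_cuda_supports_op` entries above in place, the op can be exercised end to end through the public graph API. Below is a minimal sketch (not part of the commit), assuming `ggml_exp()` exists on the graph/CPU side from the same change set and that the build has CUDA enabled; the harness itself (`main`, the element count, the input value) is purely illustrative and error handling is omitted.

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

#include <cstdio>
#include <vector>

int main() {
    const int64_t n = 1000;

    // no_alloc = true: tensor metadata only, data lives in the backend buffer
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    struct ggml_tensor * y = ggml_exp(ctx, x); // lowers to GGML_UNARY_OP_EXP

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    std::vector<float> src(n, 1.0f);
    ggml_backend_tensor_set(x, src.data(), 0, n * sizeof(float));

    ggml_backend_graph_compute(backend, gf); // runs ggml_cuda_op_exp on the GPU

    std::vector<float> out(n);
    ggml_backend_tensor_get(y, out.data(), 0, n * sizeof(float));
    printf("exp(1) ~= %f\n", out[0]); // expect ~2.718282

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}
```

Note that `supports_op` only accepts the op for contiguous sources, so a non-contiguous view would fall back to another backend rather than reach `ggml_cuda_op_exp`.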

ggml/src/ggml-cuda/unary.cu

Lines changed: 28 additions & 0 deletions
@@ -75,6 +75,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
     dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = expf(x[i]);
+}
+
 static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -159,6 +168,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
     hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
+    exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
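
The launch helper above follows the same pattern as the other unary ops: a ceiling division assigns one thread per element and rounds the block count up. The standalone check below is illustrative only (the element count is made up, not taken from the commit) and just spells out that arithmetic.

```cpp
#include <cstdio>

int main() {
    const int k = 1000;                             // example element count
    const int block = 256;                          // CUDA_EXP_BLOCK_SIZE
    const int num_blocks = (k + block - 1) / block; // ceil(1000 / 256) = 4
    printf("%d blocks -> %d threads for %d elements\n",
           num_blocks, num_blocks * block, k);      // 4 blocks -> 1024 threads
    // the trailing 24 threads exit early via the `if (i >= k)` guard in exp_f32
    return 0;
}
```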
@@ -296,6 +310,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;

ggml/src/ggml-cuda/unary.cuh

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,7 @@
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_EXP_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_SQRT_BLOCK_SIZE 256
@@ -26,6 +27,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments
