Skip to content

Commit 4c32832

Browse files
authored
ggml : add ggml_gelu_erf() CUDA kernel (#13719)
* ggml : add ggml_gelu_erf() CUDA kernel
* missing semicolon
1 parent c3a2624 commit 4c32832

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
21922192
case GGML_UNARY_OP_SILU:
21932193
ggml_cuda_op_silu(ctx, dst);
21942194
break;
2195+
case GGML_UNARY_OP_GELU_ERF:
2196+
ggml_cuda_op_gelu_erf(ctx, dst);
2197+
break;
21952198
case GGML_UNARY_OP_GELU_QUICK:
21962199
ggml_cuda_op_gelu_quick(ctx, dst);
21972200
break;
@@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
29772980
case GGML_UNARY_OP_SIGMOID:
29782981
case GGML_UNARY_OP_HARDSIGMOID:
29792982
case GGML_UNARY_OP_HARDSWISH:
2983+
case GGML_UNARY_OP_GELU_ERF:
29802984
case GGML_UNARY_OP_GELU_QUICK:
29812985
case GGML_UNARY_OP_TANH:
29822986
case GGML_UNARY_OP_EXP:

ggml/src/ggml-cuda/unary.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
2323
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
2424
}
2525

26+
// Exact (erf-based) GELU for a single value:
//     gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// All arithmetic is done in single precision.
static __device__ __forceinline__ float op_gelu_erf(float x) {
    // 1/sqrt(2), so the division inside erf becomes a multiplication.
    const float inv_sqrt2 = 0.70710678118654752440084436210484f;

    const float half_x   = 0.5f*x;
    const float erf_term = erff(x*inv_sqrt2);

    // Same evaluation order as (0.5f*x) * (1.0f + erf(...)).
    return half_x*(1.0f + erf_term);
}
31+
2632
static __device__ __forceinline__ float op_gelu_quick(float x) {
2733
const float GELU_QUICK_COEF = -1.702f;
2834

@@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
134140
ggml_cuda_op_unary<op_gelu>(ctx, dst);
135141
}
136142

143+
// Entry point for the erf-based GELU unary op.
// Forwards ctx and dst unchanged to ggml_cuda_op_unary, instantiated with
// the op_gelu_erf element function.
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
}
146+
137147
// Entry point for the "quick" GELU unary op.
// Forwards ctx and dst unchanged to ggml_cuda_op_unary, instantiated with
// the op_gelu_quick element function.
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
}

ggml/src/ggml-cuda/unary.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3030

3131
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3232

33+
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
34+
3335
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3436

3537
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)