Skip to content

Commit 92f44ff

Browse files
authored
metal : add GELU implementation (#1770)
Co-authored-by: Adam Treat <[email protected]>
1 parent 245fc3c commit 92f44ff

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

ggml-metal.m

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
GGML_METAL_DECL_KERNEL(scale);
4646
GGML_METAL_DECL_KERNEL(silu);
4747
GGML_METAL_DECL_KERNEL(relu);
48+
GGML_METAL_DECL_KERNEL(gelu);
4849
GGML_METAL_DECL_KERNEL(soft_max);
4950
GGML_METAL_DECL_KERNEL(diag_mask_inf);
5051
GGML_METAL_DECL_KERNEL(get_rows_f16);
@@ -135,6 +136,7 @@
135136
GGML_METAL_ADD_KERNEL(scale);
136137
GGML_METAL_ADD_KERNEL(silu);
137138
GGML_METAL_ADD_KERNEL(relu);
139+
GGML_METAL_ADD_KERNEL(gelu);
138140
GGML_METAL_ADD_KERNEL(soft_max);
139141
GGML_METAL_ADD_KERNEL(diag_mask_inf);
140142
GGML_METAL_ADD_KERNEL(get_rows_f16);
@@ -420,6 +422,20 @@ void ggml_metal_graph_compute(
420422

421423
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
422424
} break;
425+
case GGML_OP_GELU:
426+
{
427+
if (encoder == nil) {
428+
encoder = [command_buffer computeCommandEncoder];
429+
}
430+
431+
[encoder setComputePipelineState:ctx->pipeline_gelu];
432+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
433+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
434+
435+
const int64_t n = ggml_nelements(dst);
436+
437+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
438+
} break;
423439
case GGML_OP_SOFT_MAX:
424440
{
425441
if (encoder == nil) {

ggml-metal.metal

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ kernel void kernel_relu(
8181
dst[tpig] = max(0.0f, src0[tpig]);
8282
}
8383

84+
constant float GELU_COEF_A = 0.044715f;
85+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
86+
87+
kernel void kernel_gelu(
88+
device const float * src0,
89+
device float * dst,
90+
uint tpig[[thread_position_in_grid]]) {
91+
float x = src0[tpig];
92+
dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
93+
}
94+
8495
kernel void kernel_soft_max(
8596
device const float * src0,
8697
device float * dst,

0 commit comments

Comments
 (0)