Commit 680e6f9

cuda : add gelu support
1 parent 4e7464e commit 680e6f9

File tree

1 file changed (+53, -0 lines)


ggml-cuda.cu

Lines changed: 53 additions & 0 deletions
@@ -212,6 +212,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
@@ -266,6 +267,20 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] * y[i%ky];
 }
 
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = x[i];
+    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
 static __global__ void silu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
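Note: the kernel evaluates the tanh approximation of GELU, GELU(x) ≈ 0.5·x·(1 + tanh(sqrt(2/π)·(x + 0.044715·x³))); the factored form in the code, SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi), is the same polynomial with x pulled out front. For validating the kernel, a host-side reference of the same expression might look like this (a test sketch, not part of the commit):

    // host-side reference for the same tanh approximation (hypothetical test helper);
    // reuses the two constants defined above
    static float gelu_ref_host(float x) {
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }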
@@ -1733,6 +1748,11 @@ static void mul_f32_cuda(const float * x, const float * y, float * dst, const in
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
+static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -2327,6 +2347,28 @@ inline void ggml_cuda_op_mul(
     (void) i02;
 }
 
+inline void ggml_cuda_op_gelu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_silu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
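GELU is elementwise, so the wrapper flattens the [ne00, i01_high - i01_low] slab it is handed into a single 1-D launch over ne00*i01_diff floats; the remaining parameters exist only to satisfy the shared per-op callback signature and are silenced with (void) casts. Inferred from the parameter lists above, that callback type presumably looks like the following (illustrative; the exact typedef in ggml-cuda.cu may differ):

    // callback signature inferred from ggml_cuda_op_gelu/ggml_cuda_op_silu above;
    // illustrative only, the real typedef may differ
    typedef void (*ggml_cuda_op_t)(
        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
        float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02,
        int64_t i01_low, int64_t i01_high, int i1, cudaStream_t & cudaStream_main);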
@@ -2986,6 +3028,11 @@ void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
 }
 
+void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
+}
+
 void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
@@ -3382,6 +3429,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_GELU:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_gelu;
+            break;
         case GGML_OP_SILU:
            if (!any_on_device) {
                return false;
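With the new case, ggml_cuda_compute_forward routes GGML_OP_GELU nodes to ggml_cuda_gelu when at least one of the node's tensors is resident on the device, and returns false otherwise so the CPU backend handles the op. A reduced sketch of the dispatch pattern (not the verbatim function; the function-pointer type name is illustrative):

    // reduced sketch of the dispatch in ggml_cuda_compute_forward;
    // the real switch covers many more ops
    ggml_cuda_func_t func;                     // hypothetical function-pointer typedef
    switch (tensor->op) {
        case GGML_OP_GELU:
            if (!any_on_device) { return false; }
            func = ggml_cuda_gelu;
            break;
        case GGML_OP_SILU:
            if (!any_on_device) { return false; }
            func = ggml_cuda_silu;
            break;
        default:
            return false;
    }
    func(tensor->src0, tensor->src1, tensor);
    return true;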
