@@ -212,6 +212,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
@@ -266,6 +267,20 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] * y[i%ky];
 }
 
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = x[i];
+    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
 static __global__ void silu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1733,6 +1748,11 @@ static void mul_f32_cuda(const float * x, const float * y, float * dst, const in
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
+static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -2327,6 +2347,28 @@ inline void ggml_cuda_op_mul(
     (void) i02;
 }
 
+inline void ggml_cuda_op_gelu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_silu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2986,6 +3028,11 @@ void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
 }
 
+void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
+}
+
 void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
@@ -3382,6 +3429,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_GELU:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_gelu;
+            break;
         case GGML_OP_SILU:
             if (!any_on_device) {
                 return false;
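For reference, the new gelu_f32 kernel computes the tanh approximation of GELU, 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2))), one element per thread. Below is a minimal standalone harness to sanity-check such a kernel against a CPU reference; it is hypothetical and not part of this change (the file name, block size constant, and all host-side scaffolding are assumptions), and it only mirrors the kernel shown in the diff above.

// gelu_check.cu -- hypothetical standalone harness, not part of the diff above.
// Runs the tanh-approximation GELU kernel over a sweep of inputs and compares
// the device output against a CPU reference computed with the same formula.
#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

static const float GELU_COEF_A    = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

// Element-wise kernel: one thread per element, same formula as in the diff.
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    const float xi = x[i];
    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}

// CPU reference using the same tanh approximation.
static float gelu_ref(float xi) {
    return 0.5f*xi*(1.0f + std::tanh(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}

int main() {
    const int k = 1 << 16;
    std::vector<float> h_x(k), h_y(k);
    for (int i = 0; i < k; ++i) {
        h_x[i] = -8.0f + 16.0f*i/(k - 1);   // sweep the range [-8, 8]
    }

    float * d_x = nullptr;
    float * d_y = nullptr;
    cudaMalloc(&d_x, k*sizeof(float));      // error checks omitted for brevity
    cudaMalloc(&d_y, k*sizeof(float));
    cudaMemcpy(d_x, h_x.data(), k*sizeof(float), cudaMemcpyHostToDevice);

    const int block_size = 256;             // mirrors CUDA_GELU_BLOCK_SIZE
    const int num_blocks = (k + block_size - 1) / block_size;
    gelu_f32<<<num_blocks, block_size>>>(d_x, d_y, k);
    cudaMemcpy(h_y.data(), d_y, k*sizeof(float), cudaMemcpyDeviceToHost);

    float max_err = 0.0f;
    for (int i = 0; i < k; ++i) {
        const float err = std::fabs(gelu_ref(h_x[i]) - h_y[i]);
        if (err > max_err) {
            max_err = err;
        }
    }
    std::printf("max abs diff vs CPU reference: %g\n", max_err);

    cudaFree(d_x);
    cudaFree(d_y);
    return 0;
}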