
Commit 5ac4690

ggml : ggml_soft_max support F16/F32 mask/pos
ggml-ci
1 parent c11d05f

File tree: 6 files changed (+107, -34 lines)

ggml-cuda/softmax.cu

Lines changed: 38 additions & 10 deletions
@@ -1,7 +1,19 @@
 #include "softmax.cuh"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+#include <cuda/std/type_traits>
+
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid = threadIdx.x;
@@ -43,7 +55,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
+        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +126,8 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
     }
 }
 
-static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -167,15 +180,19 @@ static void soft_max_f32_cuda(const float * x, const half * mask, const half * p
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
-    const half * src1_d = src1 ? (const half *)src1->data : nullptr;
+    const void * src1_d = src1 ? (const void *)src1->data : nullptr;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,14 +205,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // positions tensor
-    half * src2_dd = nullptr;
+    void * src2_d = nullptr;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        src2_dd = (half *)src2->data;
+        src2_d = (void *)src2->data;
     }
 
-    soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+        const half * src2_dd = (const half *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+        const float * src2_dd = (const float *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
 }
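
The CUDA change keeps a single kernel template and resolves the mask/pos element type at instantiation time: t2f32 folds both half and float loads into a float inside the kernel, and the host picks the half or float instantiation once per launch from the tensor types. A rough, CPU-only C++ sketch of the same pattern (illustrative only; the f16 stand-in below just wraps a float and is not CUDA's half):

// stand-in for a 16-bit float type; it stores a float only to keep the sketch runnable
struct f16 { float v; };

template <typename T>
static float to_f32(T x) { return (float) x; }

template <>
float to_f32<f16>(f16 x) { return x.v; } // a real half type would decode its bits here

template <typename T>
static void add_mask(float * vals, const T * mask, int n) {
    for (int i = 0; i < n; ++i) {
        vals[i] += mask ? to_f32(mask[i]) : 0.0f;
    }
}

// host-side dispatch mirroring ggml_cuda_op_soft_max: decide once, call the matching instantiation
static void add_mask_any(float * vals, const void * mask, bool mask_is_f16, int n) {
    if (mask_is_f16) {
        add_mask(vals, (const f16 *) mask, n);
    } else {
        add_mask(vals, (const float *) mask, n);
    }
}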

ggml-metal.m

Lines changed: 22 additions & 7 deletions
@@ -46,8 +46,10 @@
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -492,8 +494,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
@@ -1346,22 +1350,33 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
-                    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
+                    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+                    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                     int nth = 32; // SIMD width
 
                     id<MTLComputePipelineState> pipeline = nil;
 
+                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
                     if (ne00%4 == 0) {
                         while (nth < ne00/4 && nth < 256) {
                             nth *= 2;
                         }
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
+                        if (use_f16) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+                        }
                     } else {
                         while (nth < ne00 && nth < 1024) {
                             nth *= 2;
                         }
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
+                        if (use_f16) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+                        }
                     }
 
                     float scale;

ggml-metal.metal

Lines changed: 14 additions & 4 deletions
@@ -352,6 +352,7 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
+template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,
         device const char * src1,
@@ -376,8 +377,8 @@ kernel void kernel_soft_max(
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
-    device const half * ppos = src2 != src0 ? (device const half *) src2 : nullptr;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+    device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
     device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     float slope = 0.0f;
@@ -456,6 +457,7 @@ kernel void kernel_soft_max(
     }
 }
 
+template<typename T>
 kernel void kernel_soft_max_4(
         device const char * src0,
         device const char * src1,
@@ -480,8 +482,8 @@ kernel void kernel_soft_max_4(
    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
-    device const half4 * ppos = src2 != src0 ? (device const half4 *) src2 : nullptr;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+    device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
     device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
     float slope = 0.0f;
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
     }
 }
 
+typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device float * dst,

ggml.c

Lines changed: 29 additions & 9 deletions
@@ -5473,18 +5473,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
     GGML_ASSERT(ggml_is_contiguous(a));
 
     if (mask) {
-        GGML_ASSERT(mask->type == GGML_TYPE_F16);
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
     if (pos) {
         GGML_ASSERT(ggml_is_vector(pos));
-        GGML_ASSERT(pos->type == GGML_TYPE_F16);
+        GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
         GGML_ASSERT(pos->ne[0] == a->ne[0]);
     }
 
+    if (pos && mask) {
+        GGML_ASSERT(pos->type == mask->type);
+    }
+
     if (max_bias > 0.0f) {
         GGML_ASSERT(pos);
     }
@@ -12410,20 +12414,30 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-    ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    float * pos_f32 = src2 ? (float *) src2->data : src0->data;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
 
         ggml_vec_cpy_f32 (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
-        if (mp) {
-            for (int i = 0; i < nc; ++i) {
-                wp[i] += GGML_FP16_TO_FP32(mp[i]);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += mp_f32[i];
+                }
            }
         }
 
@@ -12432,8 +12446,14 @@ static void ggml_compute_forward_soft_max_f32(
             const uint32_t h = (i1/ne01)%ne02; // head
             const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
 
-            for (int i = 0; i < nc; i++) {
-                wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]);
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*pos_f32[i];
+                }
             }
         }
 
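
For context on the pos path exercised above: the per-head slope comes from the ALiBi formulation, and the m0/m1/n_head_log2 terms are computed once from max_bias earlier in ggml_compute_forward_soft_max_f32 (not shown in this diff). A small standalone sketch, assuming the usual definitions used by ggml:

#include <math.h>
#include <stdint.h>

// per-head ALiBi slope; the final expression matches the slope line in the hunk above,
// while the m0/m1/n_head_log2 definitions are assumed from the unchanged surrounding code
static float alibi_slope(float max_bias, uint32_t n_head, uint32_t h) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}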

llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -6710,14 +6710,14 @@ struct llm_build_context {
         }
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
-        return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16);
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
     struct ggml_tensor * build_inp_KQ_pos() {
         lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
         cb(lctx.inp_KQ_pos, "KQ_pos", -1);
         ggml_set_input(lctx.inp_KQ_pos);
-        return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16);
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;
     }
 
     struct ggml_tensor * build_inp_mean() {
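
In llama.cpp the KQ mask and KQ positions now stay in F32 unless the flash-attention path is enabled, since the FA kernels still expect F16 inputs while soft_max can consume F32 directly. A minimal sketch of the shared pattern behind the two changed return statements (the helper name is hypothetical; flash_attn stands for the llm_build_context flag used above):

#include "ggml.h"

// hypothetical helper: only cast to F16 when the tensor will feed flash attention
static struct ggml_tensor * cast_for_flash_attn(
        struct ggml_context * ctx0, struct ggml_tensor * t, bool flash_attn) {
    return flash_attn ? ggml_cast(ctx0, t, GGML_TYPE_F16) : t;
}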

tests/test-backend-ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -1120,11 +1120,11 @@ struct test_soft_max : public test_case {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * mask = nullptr;
         if (this->mask) {
-            mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]);
+            mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
         }
         ggml_tensor * pos = nullptr;
         if (max_bias > 0.0f) {
-            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]);
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
         }
         ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
         return out;
