Skip to content

Commit dd8ba93

Browse files
ggml: aarch64: Implement SVE F32 kernels for Mamba Sequential Scan Algorithm (#13882)

* F32-Mamba-Seq_Scan-SVE
* Fix formatting
* ggml : missing space

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 66c9206 commit dd8ba93

File tree

2 files changed

+110
-30
lines changed

2 files changed

+110
-30
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 74 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
76337633
const int ir1 = MIN(ir0 + dr, nr);
76347634
const int ir = ir1 - ir0;
76357635

7636-
for (int i3 = 0; i3 < n_s; ++i3) {
7637-
for (int i2 = 0; i2 < n_t; ++i2) {
7638-
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7639-
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7640-
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7641-
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7642-
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7643-
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7644-
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7645-
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7646-
7647-
// use the output as the source for the next token-wise iterations
7648-
if (i2 > 0) { s0 = s; }
7649-
7650-
// d_inner
7651-
for (int i1 = 0; i1 < ir; ++i1) {
7652-
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7653-
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7654-
float x_dt = x[i1] * dt_soft_plus;
7655-
float sumf = 0.0f;
7656-
// d_state
7657-
for (int i0 = 0; i0 < nc; ++i0) {
7658-
int i = i0 + i1*nc;
7659-
// state = prev_state * dA + dB * x
7660-
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7661-
// y = rowwise_dotprod(state, C)
7662-
sumf += state * C[i0];
7663-
s[i] = state;
7636+
#ifdef __ARM_FEATURE_SVE
7637+
for (int i3 = 0; i3 < n_s; ++i3) {
7638+
for (int i2 = 0; i2 < n_t; ++i2) {
7639+
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7640+
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7641+
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7642+
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7643+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7644+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7645+
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7646+
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7647+
7648+
// use the output as the source for the next token-wise iterations
7649+
if (i2 > 0) { s0 = s; }
7650+
7651+
// d_inner
7652+
for (int i1 = 0; i1 < ir; ++i1) {
7653+
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7654+
float x_dt = x[i1] * dt_soft_plus;
7655+
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
7656+
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
7657+
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
7658+
7659+
for (int64_t k = 0; k < nc; k += svcntw()) {
7660+
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
7661+
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
7662+
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
7663+
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
7664+
7665+
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
7666+
t1 = exp_ps_sve(svptrue_b32(), t1);
7667+
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
7668+
7669+
vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
7670+
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
7671+
7672+
GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
7673+
}
7674+
y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
76647675
}
7665-
y[i1] = sumf;
76667676
}
76677677
}
7668-
}
7678+
#else
7679+
for (int i3 = 0; i3 < n_s; ++i3) {
7680+
for (int i2 = 0; i2 < n_t; ++i2) {
7681+
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7682+
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7683+
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7684+
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7685+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7686+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7687+
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7688+
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7689+
7690+
// use the output as the source for the next token-wise iterations
7691+
if (i2 > 0) { s0 = s; }
7692+
7693+
// d_inner
7694+
for (int i1 = 0; i1 < ir; ++i1) {
7695+
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7696+
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7697+
float x_dt = x[i1] * dt_soft_plus;
7698+
float sumf = 0.0f;
7699+
// d_state
7700+
for (int i0 = 0; i0 < nc; ++i0) {
7701+
int i = i0 + i1*nc;
7702+
// state = prev_state * dA + dB * x
7703+
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7704+
// y = rowwise_dotprod(state, C)
7705+
sumf += state * C[i0];
7706+
s[i] = state;
7707+
}
7708+
y[i1] = sumf;
7709+
}
7710+
}
7711+
}
7712+
#endif
76697713
}
76707714

76717715
void ggml_compute_forward_ssm_scan(

ggml/src/ggml-cpu/vec.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
647647
#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
648648
#endif
649649

650+
/* Fast vectorized expf() approximation for SVE, built around the FEXPA
   instruction. Borrowed from the GitHub repository:
   https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
    // Constants of the approximation (values taken from the original source)
    const svfloat32_t log2_e      = svdup_n_f32(1.4426950409f); // 1/ln(2)
    const svfloat32_t ln2         = svdup_n_f32(0.6931473921f);
    const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f); // ~ln(2)^2 / 2
    const svuint32_t  not_mask17  = svdup_n_u32(~((1u << 17) - 1));
    const svfloat32_t one         = svdup_n_f32(1.0f);
    const svfloat32_t zero_f32    = svdup_n_f32(0.0f);          // fill for inactive lanes
    const svint32_t   zero_s32    = svdup_n_s32(0);             // fill for inactive lanes

    // exp(x) = 2^y with y = x * log2(e)
    svfloat32_t y     = svmul_f32_m(pg, src, log2_e);
    // Split y into integer part n and fractional remainder
    svfloat32_t y_flr = svrintm_f32_m(zero_f32, pg, y);          // floor(y), round toward -inf (float)
    svint32_t   n     = svcvt_s32_f32_m(zero_s32, pg, y_flr);    // floor(y) as int

    svfloat32_t b = svsub_f32_m(pg, y, y_flr);                   // a = y - floor(y), in [0, 1)
    b             = svadd_f32_m(pg, b, one);                     // b = a + 1, in [1, 2)

    // FEXPA table lookup on the top mantissa bits of b
    svuint32_t  idx = svlsr_n_u32_m(pg, svreinterpret_u32_f32(b), 17); // v = b >> 17 (u32)
    svfloat32_t c   = svexpa_f32(idx);                                 // c = fexpa(v)
    c               = svscale_f32_m(pg, c, n);                         // fexpa(v) * 2^n

    // z = residue of b below the 17-bit table index (and_(t2.d, t1.d, not_mask17.d))
    svfloat32_t z = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(b), not_mask17));
    z             = svsub_f32_m(pg, b, z);

    // Quadratic correction: 1 + ln2*z + half_ln2_sq*z^2
    svfloat32_t corr = svmla_f32_m(pg, ln2, z, half_ln2_sq);     // ln2 + half_ln2_sq * z
    corr             = svmla_f32_m(pg, one, z, corr);            // 1 + (ln2 * z) + (half_ln2_sq * z * z)

    return svmul_f32_m(pg, corr, c);                             // final result
}
#endif
685+
650686
#if defined(__ARM_NEON) && defined(__aarch64__)
651687

652688
// adapted from arm limited optimized routine

0 commit comments

Comments
 (0)