@@ -16726,6 +16726,320 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
    size_t h_stride = C / H;
    size_t h_stride_2d = head_size * head_size;
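+    // All SIMD branches below implement, per element (vectorized over j):
+    //   dst[t,h,j]      += r[t,h,i] * (time_faaaa[h,i] * k[t,h,i] * v[t,h,j] + state_prev[h,i,j])
+    //   state_cur[h,i,j] = time_decay[t,h,i] * state_prev[h,i,j] + k[t,h,i] * v[t,h,j]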
+    #if defined(__AVX2__) && !defined(__AVX512F__)
+    // AVX2 uses 256-bit vectors = 8 float32
+    const int vec_size = 8;
+    const size_t vec_count = head_size / vec_size;
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
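+        // The first token of each sequence reads its previous state from src[5];
+        // later tokens reuse the state written in the previous step.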
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Broadcast scalar values to vectors
+                __m256 k_vec = _mm256_set1_ps(k_val);
+                __m256 r_vec = _mm256_set1_ps(r_val);
+                __m256 time_faaaa_vec = _mm256_set1_ps(time_faaaa_val);
+                __m256 time_decay_vec = _mm256_set1_ps(time_decay_val);
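+                // These scalars depend only on (t, h, i), so they are broadcast
+                // once per row and reused across the whole j loop below.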
16760
+
16761
+ // Vector processing for chunks of 8 floats
16762
+ for (size_t j = 0; j < vec_count; j++) {
16763
+ size_t base_j = j * vec_size;
16764
+ size_t t_h_j_offset = t_h_offset + base_j;
16765
+ size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
16766
+
16767
+ // Load 8 elements at once
16768
+ __m256 v_vec = _mm256_loadu_ps(&v[t_h_j_offset]);
16769
+ __m256 prev_state_vec = _mm256_loadu_ps(&state_prev[h_2d_i_j_offset]);
16770
+ __m256 dst_vec = _mm256_loadu_ps(&dst_data[t_h_j_offset]);
16771
+
16772
+ // Compute kv = v * k
16773
+ __m256 kv_vec = _mm256_mul_ps(v_vec, k_vec);
16774
+
16775
+ // Compute temp = kv * time_faaaa + prev_state
16776
+ __m256 kv_time_vec = _mm256_mul_ps(kv_vec, time_faaaa_vec);
16777
+ __m256 temp_vec = _mm256_add_ps(kv_time_vec, prev_state_vec);
16778
+
16779
+ // Update dst: dst += temp * r
16780
+ __m256 result_vec = _mm256_mul_ps(temp_vec, r_vec);
16781
+ dst_vec = _mm256_add_ps(dst_vec, result_vec);
16782
+ _mm256_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
16783
+
16784
+ // Update state: state = prev_state * time_decay + kv
16785
+ __m256 decay_state_vec = _mm256_mul_ps(prev_state_vec, time_decay_vec);
16786
+ __m256 new_state_vec = _mm256_add_ps(decay_state_vec, kv_vec);
16787
+ _mm256_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
16788
+ }
16789
+
16790
+                // Handle remaining elements (head_size is expected to be a multiple
+                // of vec_size, so this scalar tail is normally empty)
+                for (size_t j = vec_count * vec_size; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+
+    #elif defined(__AVX512F__)
+    // AVX-512 uses 512-bit vectors = 16 float32
+    const int vec_size = 16;
+    const size_t vec_count = head_size / vec_size;
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                // Load scalar values
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Broadcast scalar values to ZMM registers (512-bit)
+                __m512 k_vec = _mm512_set1_ps(k_val);
+                __m512 r_vec = _mm512_set1_ps(r_val);
+                __m512 time_faaaa_vec = _mm512_set1_ps(time_faaaa_val);
+                __m512 time_decay_vec = _mm512_set1_ps(time_decay_val);
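+                // _mm512_fmadd_ps(a, b, c) computes a*b + c in one instruction and
+                // is used below for the accumulation and the state update.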
+
+                // Use prefetch to reduce cache misses
+                #define PREFETCH_OFFSET 2
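+                // Hint the next part of v and the state row used PREFETCH_OFFSET
+                // iterations ahead into all cache levels (_MM_HINT_T0)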
+                if (i + PREFETCH_OFFSET < head_size) {
+                    _mm_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], _MM_HINT_T0);
+                    _mm_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], _MM_HINT_T0);
+                }
+
+                // Vector processing for chunks of 16 floats
+                for (size_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * vec_size;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load 16 elements at once
+                    __m512 v_vec = _mm512_loadu_ps(&v[t_h_j_offset]);
+                    __m512 prev_state_vec = _mm512_loadu_ps(&state_prev[h_2d_i_j_offset]);
+                    __m512 dst_vec = _mm512_loadu_ps(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    __m512 kv_vec = _mm512_mul_ps(v_vec, k_vec);
+
+                    // Compute temp = kv * time_faaaa + prev_state using FMA
+                    __m512 temp_vec = _mm512_fmadd_ps(kv_vec, time_faaaa_vec, prev_state_vec);
+
+                    // Update dst: dst += temp * r using FMA
+                    dst_vec = _mm512_fmadd_ps(temp_vec, r_vec, dst_vec);
+                    _mm512_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
+
+                    // Update state: state = prev_state * time_decay + kv using FMA
+                    __m512 new_state_vec = _mm512_fmadd_ps(prev_state_vec, time_decay_vec, kv_vec);
+                    _mm512_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
+                }
+
+                // Handle remaining elements (head_size is expected to be a multiple
+                // of vec_size, so this scalar tail is normally empty)
+                for (size_t j = vec_count * vec_size; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+
+    #elif defined(__ARM_FEATURE_SVE)
+    // Get vector length for this CPU
+    const size_t vec_size = svcntw();
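+    // svcntw() returns the number of 32-bit lanes per SVE vector, which is
+    // hardware dependent; the tail is handled below by predication rather than
+    // by a separate scalar remainder loop.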
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Create predicate for the active lanes of the first chunk
+                size_t j = 0;
+                svbool_t pg = svwhilelt_b32(j, head_size);
+
+                // Process vector-sized chunks until no lanes remain active
+                while (svptest_first(svptrue_b32(), pg)) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    // Load vectors
+                    svfloat32_t v_vec = svld1_f32(pg, &v[t_h_j_offset]);
+                    svfloat32_t prev_state_vec = svld1_f32(pg, &state_prev[h_2d_i_j_offset]);
+                    svfloat32_t dst_vec = svld1_f32(pg, &dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    svfloat32_t kv_vec = svmul_n_f32_x(pg, v_vec, k_val);
+
+                    // Compute temp = kv * time_faaaa + prev_state
+                    svfloat32_t temp_vec = svmad_n_f32_x(pg, kv_vec, time_faaaa_val, prev_state_vec);
+
+                    // Update dst: dst += temp * r
+                    svfloat32_t result_vec = svmad_n_f32_x(pg, temp_vec, r_val, dst_vec);
+                    svst1_f32(pg, &dst_data[t_h_j_offset], result_vec);
+
+                    // Update state: state = prev_state * time_decay + kv
+                    svfloat32_t new_state_vec = svmad_n_f32_x(pg, prev_state_vec, time_decay_val, kv_vec);
+                    svst1_f32(pg, &state_cur[h_2d_i_j_offset], new_state_vec);
+
+                    j += vec_size;
+                    pg = svwhilelt_b32(j, head_size);
+                }
+            }
+        }
+    }
+
+    #elif defined(__ARM_NEON)
+    // NEON uses 128-bit vectors = 4 float32s
+    const int vec_size = 4;
+    const size_t vec_count = head_size / vec_size;
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Broadcast scalar values to vectors
+                float32x4_t k_vec = vdupq_n_f32(k_val);
+                float32x4_t r_vec = vdupq_n_f32(r_val);
+                float32x4_t time_faaaa_vec = vdupq_n_f32(time_faaaa_val);
+                float32x4_t time_decay_vec = vdupq_n_f32(time_decay_val);
+
+                // Use prefetch to reduce cache misses
+                #ifdef __ARM_FEATURE_PREFETCH
+                #define PREFETCH_OFFSET 2
+                if (i + PREFETCH_OFFSET < head_size) {
+                    __builtin_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], 0, 3);
+                    __builtin_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], 0, 3);
+                }
+                #endif
+
+                // Vector processing for chunks of 4 floats
+                for (size_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * vec_size;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load 4 elements at once
+                    float32x4_t v_vec = vld1q_f32(&v[t_h_j_offset]);
+                    float32x4_t prev_state_vec = vld1q_f32(&state_prev[h_2d_i_j_offset]);
+                    float32x4_t dst_vec = vld1q_f32(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    float32x4_t kv_vec = vmulq_f32(v_vec, k_vec);
+
+                    // Compute temp = kv * time_faaaa + prev_state
+                    #ifdef __ARM_FEATURE_FMA
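+                    // vfmaq_f32(acc, a, b) returns acc + a*b lane-wise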
+                    float32x4_t temp_vec = vfmaq_f32(prev_state_vec, kv_vec, time_faaaa_vec);
+                    // Update dst: dst += temp * r
+                    dst_vec = vfmaq_f32(dst_vec, temp_vec, r_vec);
+                    // Update state: state = prev_state * time_decay + kv
+                    float32x4_t new_state_vec = vfmaq_f32(kv_vec, prev_state_vec, time_decay_vec);
+                    #else
+                    float32x4_t kv_time = vmulq_f32(kv_vec, time_faaaa_vec);
+                    float32x4_t temp_vec = vaddq_f32(kv_time, prev_state_vec);
+                    float32x4_t result_vec = vmulq_f32(temp_vec, r_vec);
+                    dst_vec = vaddq_f32(dst_vec, result_vec);
+                    float32x4_t decay_state_vec = vmulq_f32(prev_state_vec, time_decay_vec);
+                    float32x4_t new_state_vec = vaddq_f32(decay_state_vec, kv_vec);
+                    #endif
+
+                    vst1q_f32(&dst_data[t_h_j_offset], dst_vec);
+                    vst1q_f32(&state_cur[h_2d_i_j_offset], new_state_vec);
+                }
+
+                // Handle remaining elements
+                for (size_t j = vec_count * vec_size; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+
+    #else
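+    // Fallback: the pre-existing scalar implementation below is kept unchanged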
    // basically fused operations:
    // dst = r @ (time_faaaa * (k @ v) + state),
    // state = time_decay * state + (k @ v),
@@ -16765,7 +17079,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                }
            }
        }
+
    }
+    #endif
}