@@ -13667,6 +13667,184 @@ static void ggml_compute_forward_gla(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
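+    // dst packs two things back to back: the wkv output (C floats per token, T tokens),
+    // followed by the updated per-sequence state (head_size * C floats per sequence)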
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
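+    // work is split across threads by head: thread ith handles heads [h_start, h_end);
+    // threads beyond the head count have nothing to do and return early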
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                      (HEADS * (ith + 1)) / nth : HEADS;
+
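+    // per-token inputs of the wkv7 recurrence: r (receptance), w (decay), k (key),
+    // v (value), and a/b, which parameterize the sa * b state-transition term below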
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // same as C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+    #if defined(GGML_SIMD)
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
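+        // the first token of each sequence reads its initial state from src[6];
+        // subsequent tokens read back the state written at the previous step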
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
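+                // sa = dot(a[t], row ii of the previous state); it scales the b term
+                // in the state update below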
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }
+
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
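+                // update row ii of the state and accumulate the output:
+                //   state[ii][j] = v[ii]*k[j] + state[ii][j]*w[j] + sa*b[j]
+                //   dst[ii]      = sum_j state[ii][j] * r[j]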
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                    }
+                }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // scalar tail; there shouldn't be any left-overs, though
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
+            }
+        }
+    }
+    #else
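+    // scalar reference path: same recurrence as the SIMD branch above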
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                int64_t t_h_i_offset = t_h_offset + i;
+                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float v_val = v[t_h_i_offset];
+
+                float sa = 0, result = 0;
+                for (int64_t j = 0; j < head_size; j++) {
+                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                }
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    result += state_cur[h_2d_i_j_offset] * r_val;
+                }
+                dst_data[t_h_i_offset] = result;
+            }
+        }
+    }
+    #endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -14424,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -14716,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: