@@ -3077,7 +3077,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3077
3077
"WIN_UNPART",
3078
3078
"GET_REL_POS",
3079
3079
"ADD_REL_POS",
3080
- "RWKV_WKV ",
3080
+ "RWKV_WKV6 ",
3081
3081
3082
3082
"UNARY",
3083
3083
@@ -16709,11 +16709,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
16709
16709
float * dst_data = (float *) dst->data;
16710
16710
float * state = ((float *) dst->data) + C * T;
16711
16711
16712
- if (params->ith != 0 ) {
16712
+ if ((size_t) params->ith >= H ) {
16713
16713
return;
16714
16714
}
16715
16715
16716
- memset(dst_data, 0, T * C * sizeof(float));
16716
+ size_t h_start = (H * params->ith) / params->nth;
16717
+ size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
16718
+ (H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
16717
16719
16718
16720
float * k = (float *) dst->src[0]->data;
16719
16721
float * v = (float *) dst->src[1]->data;
@@ -16726,6 +16728,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
16726
16728
size_t h_stride = C / H;
16727
16729
size_t h_stride_2d = head_size * head_size;
16728
16730
16731
+ if (params->ith == 0) {
16732
+ memset(dst_data, 0, T * C * sizeof(float));
16733
+ }
16734
+ ggml_barrier(params->threadpool);
16735
+
16736
+
16737
+
16729
16738
#ifdef __AVX2__
16730
16739
// AVX2 uses 256-bit vectors = 8 float32
16731
16740
const int vec_size = 8;
@@ -16737,7 +16746,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
16737
16746
float * state_cur = state + state_offset;
16738
16747
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
16739
16748
16740
- for (size_t h = 0 ; h < H ; h++) {
16749
+ for (size_t h = h_start ; h < h_end ; h++) {
16741
16750
size_t h_offset = h * h_stride;
16742
16751
size_t t_h_offset = t_offset + h_offset;
16743
16752
size_t h_2d_offset = h * h_stride_2d;
@@ -16815,7 +16824,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
16815
16824
float * state_cur = state + state_offset;
16816
16825
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
16817
16826
16818
- for (size_t h = 0 ; h < H ; h++) {
16827
+ for (size_t h = h_start ; h < h_end ; h++) {
16819
16828
size_t h_offset = h * h_stride;
16820
16829
size_t t_h_offset = t_offset + h_offset;
16821
16830
size_t h_2d_offset = h * h_stride_2d;
@@ -16897,7 +16906,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
16897
16906
float * state_cur = state + state_offset;
16898
16907
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
16899
16908
16900
- for (size_t h = 0 ; h < H ; h++) {
16909
+ for (size_t h = h_start ; h < h_end ; h++) {
16901
16910
size_t h_offset = h * h_stride;
16902
16911
size_t t_h_offset = t_offset + h_offset;
16903
16912
size_t h_2d_offset = h * h_stride_2d;
@@ -16958,7 +16967,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
16958
16967
float * state_cur = state + state_offset;
16959
16968
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
16960
16969
16961
- for (size_t h = 0 ; h < H ; h++) {
16970
+ for (size_t h = h_start ; h < h_end ; h++) {
16962
16971
size_t h_offset = h * h_stride;
16963
16972
size_t t_h_offset = t_offset + h_offset;
16964
16973
size_t h_2d_offset = h * h_stride_2d;
@@ -17050,7 +17059,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
17050
17059
float * state_cur = state + state_offset;
17051
17060
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
17052
17061
17053
- for (size_t h = 0 ; h < H ; h++) {
17062
+ for (size_t h = h_start ; h < h_end ; h++) {
17054
17063
size_t h_offset = h * h_stride;
17055
17064
size_t t_h_offset = t_offset + h_offset;
17056
17065
size_t h_2d_offset = h * h_stride_2d;
0 commit comments