
Commit ae39cb0

rwkv6: support avx2 avx512 armv8 armv9

1 parent 35405cd commit ae39cb0

File tree

1 file changed: +316 −0 lines changed

ggml/src/ggml.c

Lines changed: 316 additions & 0 deletions
@@ -16726,6 +16726,320 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     size_t h_stride = C / H;
     size_t h_stride_2d = head_size * head_size;
 
+    #ifdef __AVX2__
+    // AVX2 uses 256-bit vectors = 8 float32s
+    const int vec_size = 8;
+    const size_t vec_count = head_size / vec_size;
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Broadcast scalar values to vectors
+                __m256 k_vec = _mm256_set1_ps(k_val);
+                __m256 r_vec = _mm256_set1_ps(r_val);
+                __m256 time_faaaa_vec = _mm256_set1_ps(time_faaaa_val);
+                __m256 time_decay_vec = _mm256_set1_ps(time_decay_val);
+
+                // Vector processing for chunks of 8 floats
+                for (size_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * vec_size;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load 8 elements at once
+                    __m256 v_vec = _mm256_loadu_ps(&v[t_h_j_offset]);
+                    __m256 prev_state_vec = _mm256_loadu_ps(&state_prev[h_2d_i_j_offset]);
+                    __m256 dst_vec = _mm256_loadu_ps(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    __m256 kv_vec = _mm256_mul_ps(v_vec, k_vec);
+
+                    // Compute temp = kv * time_faaaa + prev_state
+                    __m256 kv_time_vec = _mm256_mul_ps(kv_vec, time_faaaa_vec);
+                    __m256 temp_vec = _mm256_add_ps(kv_time_vec, prev_state_vec);
+
+                    // Update dst: dst += temp * r
+                    __m256 result_vec = _mm256_mul_ps(temp_vec, r_vec);
+                    dst_vec = _mm256_add_ps(dst_vec, result_vec);
+                    _mm256_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
+
+                    // Update state: state = prev_state * time_decay + kv
+                    __m256 decay_state_vec = _mm256_mul_ps(prev_state_vec, time_decay_vec);
+                    __m256 new_state_vec = _mm256_add_ps(decay_state_vec, kv_vec);
+                    _mm256_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
+                }
+
+                // Handle remaining elements (never entered when head_size is a multiple of vec_size)
+                for (size_t j = vec_count * vec_size; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+
+    #elif __AVX512F__
+    // AVX-512 uses 512-bit vectors = 16 float32s
+    const int vec_size = 16;
+    const size_t vec_count = head_size / vec_size;
+    const size_t vec_remain = head_size % vec_size;
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                // Load scalar values
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Broadcast scalar values to ZMM registers (512-bit)
+                __m512 k_vec = _mm512_set1_ps(k_val);
+                __m512 r_vec = _mm512_set1_ps(r_val);
+                __m512 time_faaaa_vec = _mm512_set1_ps(time_faaaa_val);
+                __m512 time_decay_vec = _mm512_set1_ps(time_decay_val);
+
+                // Use prefetch to reduce cache misses
+                #define PREFETCH_OFFSET 2
+                if (i + PREFETCH_OFFSET < head_size) {
+                    _mm_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], _MM_HINT_T0);
+                    _mm_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], _MM_HINT_T0);
+                }
+
+                // Vector processing for chunks of 16 floats
+                for (size_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * vec_size;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load 16 elements at once
+                    __m512 v_vec = _mm512_loadu_ps(&v[t_h_j_offset]);
+                    __m512 prev_state_vec = _mm512_loadu_ps(&state_prev[h_2d_i_j_offset]);
+                    __m512 dst_vec = _mm512_loadu_ps(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    __m512 kv_vec = _mm512_mul_ps(v_vec, k_vec);
+
+                    // Compute temp = kv * time_faaaa + prev_state using FMA
+                    __m512 temp_vec = _mm512_fmadd_ps(kv_vec, time_faaaa_vec, prev_state_vec);
+
+                    // Update dst: dst += temp * r using FMA
+                    dst_vec = _mm512_fmadd_ps(temp_vec, r_vec, dst_vec);
+                    _mm512_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
+
+                    // Update state: state = prev_state * time_decay + kv using FMA
+                    __m512 new_state_vec = _mm512_fmadd_ps(prev_state_vec, time_decay_vec, kv_vec);
+                    _mm512_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
+                }
+
+                // Handle remaining elements (never entered when head_size is a multiple of vec_size)
+                for (size_t j = vec_count * vec_size; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+
+
+    #elif __ARM_FEATURE_SVE
+    // Get vector length for this CPU
+    const size_t vec_size = svcntw();
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Create predicate for active lanes
+                svbool_t pg = svwhilelt_b32(0, head_size);
+
+                // Process vectors until done
+                size_t j = 0;
+                while (svptest_first(svptrue_b32(), pg)) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    // Load vectors
+                    svfloat32_t v_vec = svld1_f32(pg, &v[t_h_j_offset]);
+                    svfloat32_t prev_state_vec = svld1_f32(pg, &state_prev[h_2d_i_j_offset]);
+                    svfloat32_t dst_vec = svld1_f32(pg, &dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    svfloat32_t kv_vec = svmul_n_f32_x(pg, v_vec, k_val);
+
+                    // Compute temp = kv * time_faaaa + prev_state
+                    svfloat32_t temp_vec = svmad_n_f32_x(pg, kv_vec, time_faaaa_val, prev_state_vec);
+
+                    // Update dst: dst += temp * r
+                    svfloat32_t result_vec = svmad_n_f32_x(pg, temp_vec, r_val, dst_vec);
+                    svst1_f32(pg, &dst_data[t_h_j_offset], result_vec);
+
+                    // Update state: state = prev_state * time_decay + kv
+                    svfloat32_t new_state_vec = svmad_n_f32_x(pg, prev_state_vec, time_decay_val, kv_vec);
+                    svst1_f32(pg, &state_cur[h_2d_i_j_offset], new_state_vec);
+
+                    j += vec_size;
+                    pg = svwhilelt_b32(j, head_size);
+                }
+            }
+        }
+    }
+
+    #elif __ARM_NEON
+    // NEON uses 128-bit vectors = 4 float32s
+    const int vec_size = 4;
+    const size_t vec_count = head_size / vec_size;
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                // Broadcast scalar values to vectors
+                float32x4_t k_vec = vdupq_n_f32(k_val);
+                float32x4_t r_vec = vdupq_n_f32(r_val);
+                float32x4_t time_faaaa_vec = vdupq_n_f32(time_faaaa_val);
+                float32x4_t time_decay_vec = vdupq_n_f32(time_decay_val);
+
+                // Use prefetch to reduce cache misses
+                #ifdef __ARM_FEATURE_PREFETCH
+                #define PREFETCH_OFFSET 2
+                if (i + PREFETCH_OFFSET < head_size) {
+                    __builtin_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], 0, 3);
+                    __builtin_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], 0, 3);
+                }
+                #endif
+
+                // Vector processing for chunks of 4 floats
+                for (size_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * vec_size;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load 4 elements at once
+                    float32x4_t v_vec = vld1q_f32(&v[t_h_j_offset]);
+                    float32x4_t prev_state_vec = vld1q_f32(&state_prev[h_2d_i_j_offset]);
+                    float32x4_t dst_vec = vld1q_f32(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    float32x4_t kv_vec = vmulq_f32(v_vec, k_vec);
+
+                    // Compute temp = kv * time_faaaa + prev_state, using FMA where available
+                    #ifdef __ARM_FEATURE_FMA
+                    float32x4_t temp_vec = vfmaq_f32(prev_state_vec, kv_vec, time_faaaa_vec);
+                    // Update dst: dst += temp * r
+                    dst_vec = vfmaq_f32(dst_vec, temp_vec, r_vec);
+                    // Update state: state = prev_state * time_decay + kv
+                    float32x4_t new_state_vec = vfmaq_f32(kv_vec, prev_state_vec, time_decay_vec);
+                    #else
+                    float32x4_t kv_time = vmulq_f32(kv_vec, time_faaaa_vec);
+                    float32x4_t temp_vec = vaddq_f32(kv_time, prev_state_vec);
+                    float32x4_t result_vec = vmulq_f32(temp_vec, r_vec);
+                    dst_vec = vaddq_f32(dst_vec, result_vec);
+                    float32x4_t decay_state_vec = vmulq_f32(prev_state_vec, time_decay_vec);
+                    float32x4_t new_state_vec = vaddq_f32(decay_state_vec, kv_vec);
+                    #endif
+
+                    vst1q_f32(&dst_data[t_h_j_offset], dst_vec);
+                    vst1q_f32(&state_cur[h_2d_i_j_offset], new_state_vec);
+                }
+
+                // Handle remaining elements
+                for (size_t j = vec_count * vec_size; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+
+
+    #else
     // basically fused operations:
     // dst = r @ (time_faaaa * (k @ v) + state),
     // state = time_decay * state + (k @ v),
@@ -16765,7 +17079,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                 }
             }
         }
+
     }
+    #endif
 }
 
 
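Note: all four SIMD branches above implement the same per-element recurrence that the scalar remainder loops (and the pre-existing #else path, "dst = r @ (time_faaaa * (k @ v) + state), state = time_decay * state + (k @ v)") spell out. As a reference point for reading the intrinsics, here is a minimal standalone sketch of that inner update; the function name and flat-array parameters are hypothetical, not part of ggml's API:

#include <stddef.h>

// Hypothetical scalar reference for one (t, h, i) step of the wkv6 kernel,
// mirroring the remainder loops above:
//   kv          = k[i] * v[j]
//   temp        = kv * time_faaaa[i] + state_prev[i][j]
//   dst[j]     += temp * r[i]
//   state[i][j] = state_prev[i][j] * time_decay[i] + kv
static void rwkv6_row_update_ref(size_t head_size,
                                 float k_val, float r_val,
                                 float time_faaaa_val, float time_decay_val,
                                 const float * v,          // row of v for this timestep/head
                                 const float * state_prev, // row i of the previous state
                                 float * state_cur,        // row i of the updated state
                                 float * dst) {            // head_size output accumulators
    for (size_t j = 0; j < head_size; j++) {
        float kv_val   = v[j] * k_val;
        float temp_val = kv_val * time_faaaa_val + state_prev[j];
        dst[j]        += temp_val * r_val;
        state_cur[j]   = state_prev[j] * time_decay_val + kv_val;
    }
}

Each SIMD branch is this loop with the j dimension processed vec_size lanes at a time: k_val, r_val, time_faaaa_val and time_decay_val are broadcast to all lanes, while v, state_prev and dst_data are loaded as vectors. RWKV6 head sizes are typically an exact multiple of the vector width (e.g. 64), so the scalar remainder loops are normally never entered; the SVE branch needs no remainder loop at all because the svwhilelt/svptest_first predicate masks off inactive lanes in the final iteration.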
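A toy invocation of the sketch, with made-up sizes and values purely for illustration:

#include <stdio.h>

int main(void) {
    float v[4]          = {1.0f, 2.0f, 3.0f, 4.0f};
    float state_prev[4] = {0.1f, 0.2f, 0.3f, 0.4f};
    float state_cur[4]  = {0};
    float dst[4]        = {0};

    // One (t, h, i) step with head_size = 4 and arbitrary scalars.
    rwkv6_row_update_ref(4, /*k_val=*/0.5f, /*r_val=*/1.0f,
                         /*time_faaaa_val=*/0.25f, /*time_decay_val=*/0.9f,
                         v, state_prev, state_cur, dst);

    for (int j = 0; j < 4; j++) {
        printf("dst[%d] = %.4f  state[%d] = %.4f\n", j, dst[j], j, state_cur[j]);
    }
    return 0;
}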