Skip to content

Commit 2fd1ce1

Browse files
committed
rwkv6: update cuda file name
1 parent ae39cb0 commit 2fd1ce1

File tree

4 files changed

+19
-10
lines changed

4 files changed

+19
-10
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "ggml-cuda/tsembd.cuh"
3737
#include "ggml-cuda/unary.cuh"
3838
#include "ggml-cuda/upscale.cuh"
39-
#include "ggml-cuda/rwkv-wkv.cuh"
39+
#include "ggml-cuda/wkv6.cuh"
4040

4141
#include <algorithm>
4242
#include <array>

ggml/src/ggml-cuda/rwkv-wkv.cu renamed to ggml/src/ggml-cuda/wkv6.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "common.cuh"
2-
#include "rwkv-wkv.cuh"
2+
#include "wkv6.cuh"
33

44
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
55
const int tid = threadIdx.x;
File renamed without changes.

ggml/src/ggml.c

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3077,7 +3077,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
30773077
"WIN_UNPART",
30783078
"GET_REL_POS",
30793079
"ADD_REL_POS",
3080-
"RWKV_WKV",
3080+
"RWKV_WKV6",
30813081

30823082
"UNARY",
30833083

@@ -16709,11 +16709,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1670916709
float * dst_data = (float *) dst->data;
1671016710
float * state = ((float *) dst->data) + C * T;
1671116711

16712-
if (params->ith != 0) {
16712+
if ((size_t)params->ith >= H) {
1671316713
return;
1671416714
}
1671516715

16716-
memset(dst_data, 0, T * C * sizeof(float));
16716+
size_t h_start = (H * params->ith) / params->nth;
16717+
size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
16718+
(H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
1671716719

1671816720
float * k = (float *) dst->src[0]->data;
1671916721
float * v = (float *) dst->src[1]->data;
@@ -16726,6 +16728,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1672616728
size_t h_stride = C / H;
1672716729
size_t h_stride_2d = head_size * head_size;
1672816730

16731+
if (params->ith == 0) {
16732+
memset(dst_data, 0, T * C * sizeof(float));
16733+
}
16734+
ggml_barrier(params->threadpool);
16735+
16736+
16737+
1672916738
#ifdef __AVX2__
1673016739
// AVX2 uses 256-bit vectors = 8 float32
1673116740
const int vec_size = 8;
@@ -16737,7 +16746,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1673716746
float * state_cur = state + state_offset;
1673816747
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1673916748

16740-
for (size_t h = 0; h < H; h++) {
16749+
for (size_t h = h_start; h < h_end; h++) {
1674116750
size_t h_offset = h * h_stride;
1674216751
size_t t_h_offset = t_offset + h_offset;
1674316752
size_t h_2d_offset = h * h_stride_2d;
@@ -16815,7 +16824,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1681516824
float * state_cur = state + state_offset;
1681616825
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1681716826

16818-
for (size_t h = 0; h < H; h++) {
16827+
for (size_t h = h_start; h < h_end; h++) {
1681916828
size_t h_offset = h * h_stride;
1682016829
size_t t_h_offset = t_offset + h_offset;
1682116830
size_t h_2d_offset = h * h_stride_2d;
@@ -16897,7 +16906,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1689716906
float * state_cur = state + state_offset;
1689816907
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1689916908

16900-
for (size_t h = 0; h < H; h++) {
16909+
for (size_t h = h_start; h < h_end; h++) {
1690116910
size_t h_offset = h * h_stride;
1690216911
size_t t_h_offset = t_offset + h_offset;
1690316912
size_t h_2d_offset = h * h_stride_2d;
@@ -16958,7 +16967,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1695816967
float * state_cur = state + state_offset;
1695916968
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1696016969

16961-
for (size_t h = 0; h < H; h++) {
16970+
for (size_t h = h_start; h < h_end; h++) {
1696216971
size_t h_offset = h * h_stride;
1696316972
size_t t_h_offset = t_offset + h_offset;
1696416973
size_t h_2d_offset = h * h_stride_2d;
@@ -17050,7 +17059,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1705017059
float * state_cur = state + state_offset;
1705117060
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1705217061

17053-
for (size_t h = 0; h < H; h++) {
17062+
for (size_t h = h_start; h < h_end; h++) {
1705417063
size_t h_offset = h * h_stride;
1705517064
size_t t_h_offset = t_offset + h_offset;
1705617065
size_t h_2d_offset = h * h_stride_2d;

0 commit comments

Comments
 (0)