#include <sycl/sycl.hpp>
-#include "wkv6.hpp"
+#include "wkv.hpp"

constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE

// Helper function for the main kernel
-static void rwkv_wkv_f32_kernel(
+static void rwkv_wkv6_f32_kernel(
    const int B, const int T, const int C, const int H,
    const float * k, const float * v, const float * r,
    const float * tf, const float * td, const float * s,
@@ -95,6 +95,88 @@ static void rwkv_wkv_f32_kernel(
    }
}

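+// RWKV v7 (WKV7) kernel: one work-group per (batch, head) pair, one work-item per
+// channel of the head. Each work-item keeps one row of the head's
+// head_size x head_size state matrix in private memory and, per token, computes
+//   sa       = dot(a, state_row)
+//   state[j] = state[j] * w[j] + k[j] * v + sa * b[j]
+//   y       += r[j] * state[j]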
+static void rwkv_wkv7_f32_kernel(
+    const int B, const int T, const int C, const int H,
+    const float * r, const float * w, const float * k, const float * v,
+    const float * a, const float * b, const float * s,
+    float * dst, const sycl::nd_item<3>& item_ct1, float * shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float * _r = shared_mem;
+    float * _w = _r + head_size;
+    float * _k = _w + head_size;
+    float * _a = _k + head_size;
+    float * _b = _a + head_size;
+
+    float state[WKV_BLOCK_SIZE];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
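+    // Iterate over this sequence's tokens; t indexes this work-item's channel within the current token.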
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0, sa = 0;
+        sycl::float4 a4, s4;
+
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa += sycl::dot(a4, s4);
+        }
+
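+        // Update the state row and accumulate the output, four columns at a time.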
+        sycl::float4 r4, w4, k4, b4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            sycl::float4 kv4 = k4 * _v;
+
+            s4 = s4 * w4 + kv4 + sa * b4;
+            y += sycl::dot(r4, s4);
+
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
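+    // Append the updated state row after the T * C per-token outputs in dst.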
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

    const ggml_tensor *src0 = dst->src[0];
@@ -131,7 +213,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
        cgh.parallel_for(
            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1) {
-                rwkv_wkv_f32_kernel(
+                rwkv_wkv6_f32_kernel(
                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                    item_ct1, (float *)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                );
@@ -141,3 +223,51 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    GGML_UNUSED(src0);
    GGML_UNUSED(src1);
}
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    const float * r_d = (const float *)dst->src[0]->data;
+    const float * w_d = (const float *)dst->src[1]->data;
+    const float * k_d = (const float *)dst->src[2]->data;
+    const float * v_d = (const float *)dst->src[3]->data;
+    const float * a_d = (const float *)dst->src[4]->data;
+    const float * b_d = (const float *)dst->src[5]->data;
+    const float * s_d = (const float *)dst->src[6]->data;
+    float * dst_d = (float *)dst->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE);
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = WKV_BLOCK_SIZE * 5 * sizeof(float); // For r, w, k, a, b
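+    // Launch geometry: C / H (= head_size) work-items per group, B * H groups.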
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    stream->submit([&](sycl::handler& cgh) {
+        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rwkv_wkv7_f32_kernel(
+                    B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                    item_ct1, (float *)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                );
+            });
+    });
+
+    GGML_UNUSED(src0);
+    GGML_UNUSED(src1);
+}