1
#include <sycl/sycl.hpp>
#include "wkv6.hpp"
3
+
4
+ constexpr int WKV_BLOCK_SIZE = 64 ; // Matching CUDA_WKV_BLOCK_SIZE
5
+
6
+ // Helper function for the main kernel
7
+ static void rwkv_wkv_f32_kernel (
8
+ const int B, const int T, const int C, const int H,
9
+ const float * k, const float * v, const float * r,
10
+ const float * tf, const float * td, const float * s,
11
+ float * dst, const sycl::nd_item<3 >& item_ct1, float * shared_mem) {
12
+
13
+ const int tid = item_ct1.get_local_id (2 );
14
+ const int bid = item_ct1.get_group (2 );
15
+
16
+ const int head_size = WKV_BLOCK_SIZE;
17
+ const int batch_i = bid / H;
18
+ const int head_i = bid % H;
19
+ const int state_size = C * head_size;
20
+ const int n_seq_tokens = T / B;
21
+
22
+ // Set up shared memory pointers
23
+ float * _k = shared_mem;
24
+ float * _r = _k + head_size;
25
+ float * _tf = _r + head_size;
26
+ float * _td = _tf + head_size;
27
+
28
+ // Local state array
29
+ float state[WKV_BLOCK_SIZE];
30
+
31
+ // Load initial state
32
+ #pragma unroll
33
+ for (int i = 0 ; i < head_size; i++) {
34
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
35
+ }
36
+
37
+ // Sync threads before shared memory operations
38
+ item_ct1.barrier (sycl::access::fence_space::local_space);
39
+
40
+ // Load time-mixing parameters
41
+ _tf[tid] = tf[head_i * head_size + tid];
42
+ item_ct1.barrier (sycl::access::fence_space::local_space);
43
+
44
+ // Main sequence processing loop
45
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
46
+ t < (batch_i + 1 ) * n_seq_tokens * C + head_i * head_size + tid;
47
+ t += C) {
48
+
49
+ item_ct1.barrier (sycl::access::fence_space::local_space);
50
+
51
+ // Load current timestep data to shared memory
52
+ _k[tid] = k[t];
53
+ _r[tid] = r[t];
54
+ _td[tid] = td[t];
55
+
56
+ item_ct1.barrier (sycl::access::fence_space::local_space);
57
+
58
+ const float _v = v[t];
59
+ float y = 0 ;
60
+
61
+ // Process in chunks of 4 for better vectorization
62
+ #pragma unroll
63
+ for (int j = 0 ; j < head_size; j += 4 ) {
64
+ // Load data in vec4 chunks
65
+ sycl::float4 k4 (_k[j], _k[j+1 ], _k[j+2 ], _k[j+3 ]);
66
+ sycl::float4 r4 (_r[j], _r[j+1 ], _r[j+2 ], _r[j+3 ]);
67
+ sycl::float4 tf4 (_tf[j], _tf[j+1 ], _tf[j+2 ], _tf[j+3 ]);
68
+ sycl::float4 td4 (_td[j], _td[j+1 ], _td[j+2 ], _td[j+3 ]);
69
+ sycl::float4 s4 (state[j], state[j+1 ], state[j+2 ], state[j+3 ]);
70
+
71
+ // Compute key-value product
72
+ sycl::float4 kv4 = k4 * _v;
73
+
74
+ // Accumulate weighted sum
75
+ y += sycl::dot (r4, tf4 * kv4 + s4);
76
+
77
+ // Update state
78
+ s4 = s4 * td4 + kv4;
79
+
80
+ // Store updated state
81
+ state[j] = s4.x ();
82
+ state[j+1 ] = s4.y ();
83
+ state[j+2 ] = s4.z ();
84
+ state[j+3 ] = s4.w ();
85
+ }
86
+
87
+ dst[t] = y;
88
+ }
89
+
90
+ // Save final state
91
+ #pragma unroll
92
+ for (int i = 0 ; i < head_size; i++) {
93
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
94
+ }
95
+ }
96
+
97
+ void ggml_sycl_op_rwkv_wkv6 (ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
98
+ const ggml_tensor* src1, ggml_tensor* dst) {
99
+
100
+ const float * k_d = (const float *)dst->src [0 ]->data ;
101
+ const float * v_d = (const float *)dst->src [1 ]->data ;
102
+ const float * r_d = (const float *)dst->src [2 ]->data ;
103
+ const float * tf_d = (const float *)dst->src [3 ]->data ;
104
+ const float * td_d = (const float *)dst->src [4 ]->data ;
105
+ const float * s_d = (const float *)dst->src [5 ]->data ;
106
+ float * dst_d = (float *)dst->data ;
107
+
108
+ const int64_t B = dst->src [5 ]->ne [1 ];
109
+ const int64_t T = dst->src [0 ]->ne [3 ];
110
+ const int64_t C = dst->ne [0 ];
111
+ const int64_t H = dst->src [0 ]->ne [2 ];
112
+
113
+ GGML_ASSERT (dst->src [5 ]->type == GGML_TYPE_F32);
114
+ GGML_ASSERT (C % H == 0 );
115
+ GGML_ASSERT (C / H == WKV_BLOCK_SIZE);
116
+
117
+ dpct::queue_ptr stream = ctx.stream ();
118
+
119
+ // Calculate execution configuration
120
+ const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof (float ); // For k, r, tf, td
121
+ sycl::range<3 > block_dims (1 , 1 , C / H);
122
+ sycl::range<3 > grid_dims (1 , 1 , B * H);
123
+
124
+ // Submit kernel
125
+ stream->submit ([&](sycl::handler& cgh) {
126
+ sycl::local_accessor<float , 1 > shared_mem_acc (shared_mem_size, cgh);
127
+
128
+ cgh.parallel_for (
129
+ sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
130
+ [=](sycl::nd_item<3 > item_ct1) {
131
+ rwkv_wkv_f32_kernel (
132
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
133
+ item_ct1, shared_mem_acc.get_pointer ()
134
+ );
135
+ });
136
+ });
137
+ }