@@ -176,17 +176,32 @@ void resize_sdpa_out(
   graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
 }

-void sdpa_with_kv_cache_impl(
-    ComputeGraph& graph,
-    const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef value = args[arg_idx++];
+  const ValueRef cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variables
+  (void)out;
+
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+  add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef q_projected = args[arg_idx++];
-  const ValueRef k_projected = args[arg_idx++];
-  const ValueRef v_projected = args[arg_idx++];
-  const ValueRef k_cache_data = args[arg_idx++];
-  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
   const ValueRef input_pos_symint = args[arg_idx++];
-  const ValueRef sequence_len = args[arg_idx++];
   const ValueRef attn_mask = args[arg_idx++];
   const ValueRef dropout_p = args[arg_idx++];
   const ValueRef is_causal = args[arg_idx++];
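
Note: the hunk above splits the monolithic op in two. `update_cache_impl` takes `{value, cache, input_pos_symint, out}`, validates shapes, and delegates the copy to `add_kv_cache_update_node`; the trailing `out` slot exists only to satisfy the operator schema. A minimal sketch of the shape contract it enforces, as plain C++ over 4-D size vectors (the helper name and the `{batch, seq_len, num_kv_heads, head_dim}` ordering are assumptions inferred from the checks, not spelled out in the diff):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical helper mirroring the VK_CHECK_COND calls above, assuming
    // sizes are ordered {batch, seq_len, num_kv_heads, head_dim}.
    void check_cache_update_shapes(
        const std::vector<int64_t>& value, // {1, seq_len, n_kv_heads, head_dim}
        const std::vector<int64_t>& cache) { // {1, max_seq_len, n_kv_heads, head_dim}
      assert(value.size() == 4 && cache.size() == 4);
      assert(value[0] == 1 && cache[0] == 1); // dim -4: batch must be 1
      assert(value[3] == cache[3]);           // dim -1: head dim must match
      assert(value[2] == cache[2]);           // dim -2: KV head count must match
      // dim -3 (seq_len) may differ: the cache is preallocated to max_seq_len.
    }
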
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
   // Output tensors
   const ValueRef out = args[arg_idx++];

-  // Unused variables
-  (void)sequence_len;
-
   // Batches must be 1
   VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
   // k and v projected must have the same shape
-  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+  VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
   // head dim must match between tensors
   VK_CHECK_COND(
       graph.size_at<int32_t>(-1, q_projected) ==
-      graph.size_at<int32_t>(-1, k_projected));
+      graph.size_at<int32_t>(-1, k_cache));
   // All tensors must have the packed dim be the width (head) dimension
   VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
   // Some variables are not supported yet
   VK_CHECK_COND(
       graph.val_is_none(dropout_p) ||
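
Note: the checks index dimensions from the end, so `-1` is the innermost (width-packed head) dimension and `-4` is the batch dimension of a 4-D tensor. A minimal sketch of that resolution, assuming `size_at` normalizes negative dims the way PyTorch does (illustrative only):

    #include <cstdint>
    #include <vector>

    // Negative-dim resolution as assumed by size_at<int32_t>(-1, t) above:
    // a dim d < 0 refers to d + ndim, so -1 is the last dim.
    int64_t size_at(const std::vector<int64_t>& sizes, int64_t dim) {
      const int64_t ndim = static_cast<int64_t>(sizes.size());
      if (dim < 0) {
        dim += ndim;
      }
      return sizes.at(static_cast<size_t>(dim));
    }
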
@@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
       graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
   VK_CHECK_COND(graph.val_is_none(attn_mask));

-  const ValueRef k_cache =
-      prepack_standard_like(graph, k_cache_data, q_projected);
-  const ValueRef v_cache =
-      prepack_standard_like(graph, v_cache_data, q_projected);
-
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);

-  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
-  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
   // Slice caches from 0 to input_pos + sequence_len
   const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
   const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
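
Note: the sliced views let the attention nodes read only the cache rows written so far, `[0, input_pos + seq_len)`, without copying. A rough CPU analogy of what such a view denotes, with hypothetical types (the real code builds a `graph.add_tensor_view` and resizes it at runtime):

    #include <cstdint>

    // Hypothetical zero-copy view over the valid prefix of a KV cache laid
    // out as max_seq_len rows of row_stride floats each.
    struct CacheView {
      const float* data;  // start of the cache buffer (row 0)
      int64_t valid_rows; // input_pos + seq_len rows are readable
      int64_t row_stride; // floats per row: num_kv_heads * head_dim
    };

    CacheView slice_cache(
        const float* cache,
        int64_t input_pos,
        int64_t seq_len,
        int64_t row_stride) {
      // A view, not a copy: attention kernels iterate rows [0, valid_rows).
      return CacheView{cache, input_pos + seq_len, row_stride};
    }
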
@@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(

   // Repeat interleave
   const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
-  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);

   const ValueRef num_repeats =
       graph.add_scalar<int64_t>(num_heads / num_kv_heads);
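
Note: `num_heads / num_kv_heads` is the grouped-query-attention expansion factor: each KV head is shared by that many query heads, so the caches are repeat-interleaved along the heads dimension to line up with the queries. A plain C++ sketch of the resulting head mapping (illustrative only; e.g. 32 query heads over 8 KV heads gives num_repeats = 4):

    #include <cstdint>
    #include <vector>

    // For each query head, the index of the KV head it attends with after
    // repeat-interleaving: {0,0,0,0, 1,1,1,1, ...} when num_repeats == 4.
    std::vector<int64_t> kv_head_for_each_query_head(
        int64_t num_heads, int64_t num_kv_heads) {
      const int64_t num_repeats = num_heads / num_kv_heads;
      std::vector<int64_t> mapping;
      mapping.reserve(static_cast<size_t>(num_heads));
      for (int64_t kv = 0; kv < num_kv_heads; ++kv) {
        for (int64_t r = 0; r < num_repeats; ++r) {
          mapping.push_back(kv);
        }
      }
      return mapping;
    }
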
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
       new ExecuteNode(resize_sdpa_out, {q_projected, out}));
 }

+void sdpa_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  (void)sequence_len;
+
+  const ValueRef k_cache =
+      prepack_standard_like(graph, k_cache_data, q_projected);
+  const ValueRef v_cache =
+      prepack_standard_like(graph, v_cache_data, q_projected);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  sdpa_impl(
+      graph,
+      {q_projected,
+       k_cache,
+       v_cache,
+       input_pos_symint,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
+       out});
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+  VK_REGISTER_OP(update_cache.default, update_cache_impl);
+  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
 }

 } // namespace vkcompute
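
Note: with this split, exporters can emit `update_cache.default` and `llama.custom_sdpa.default` as separate graph ops, while `sdpa_with_kv_cache.default` survives as a thin wrapper (prepack the caches, run two cache updates, then SDPA) for backward compatibility. For reference, the math `sdpa_impl` assembles out of graph nodes is standard scaled dot-product attention over the valid cache prefix; a minimal single-head CPU sketch, assuming the conventional default scale 1/sqrt(head_dim) (illustrative only, not the Vulkan implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Single-head scaled dot-product attention over the first ctx_len cached
    // rows (ctx_len = input_pos + seq_len). The Vulkan backend expresses the
    // same computation as graph nodes (matmul, softmax, matmul).
    std::vector<float> sdpa_single_head(
        const std::vector<float>& q,                    // head_dim
        const std::vector<std::vector<float>>& k_cache, // ctx_len x head_dim
        const std::vector<std::vector<float>>& v_cache, // ctx_len x head_dim
        int64_t ctx_len) {
      const size_t head_dim = q.size();
      const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));

      // Scaled attention logits q . k_i, tracking the max for stability.
      std::vector<float> logits(static_cast<size_t>(ctx_len));
      float max_logit = -INFINITY;
      for (int64_t i = 0; i < ctx_len; ++i) {
        float dot = 0.0f;
        for (size_t d = 0; d < head_dim; ++d) {
          dot += q[d] * k_cache[i][d];
        }
        logits[i] = dot * scale;
        max_logit = std::max(max_logit, logits[i]);
      }

      // Softmax over the logits, then a weighted sum of the value rows.
      float denom = 0.0f;
      for (int64_t i = 0; i < ctx_len; ++i) {
        logits[i] = std::exp(logits[i] - max_logit);
        denom += logits[i];
      }
      std::vector<float> out(head_dim, 0.0f);
      for (int64_t i = 0; i < ctx_len; ++i) {
        const float w = logits[i] / denom;
        for (size_t d = 0; d < head_dim; ++d) {
          out[d] += w * v_cache[i][d];
        }
      }
      return out;
    }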