Commit cd95c74

[ET-VK][LlaMa] Split SDPA + KV cache operator into SDPA operator and KV cache update operator
## Context

#7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa`, to decouple the cache update step from the actual SDPA computation. As a result of this interface change, SDPA is no longer delegated to Vulkan. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators.

Note that in this diff the new operators are not yet partitioned, because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops, which allows the new operators to be partitioned.

Differential Revision: [D68916952](https://our.internmc.facebook.com/intern/diff/D68916952/)

[ghstack-poisoned]
1 parent afc5a50 commit cd95c74
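
At the operator level, the split can be pictured as the fused op becoming a thin wrapper over a cache write followed by attention against the caches, which is exactly how the new `sdpa_with_kv_cache_impl` composes `update_cache_impl` and `sdpa_impl` in the SDPA.cpp changes below. The following is a rough eager-mode sketch; the helper names, the `[batch, seq_len, num_heads, head_dim]` layout, and the masking details are illustrative approximations drawn from this diff, not the actual `llama::update_cache` / `llama::custom_sdpa` schemas.

```python
import torch
import torch.nn.functional as F


def update_cache(value, cache, input_pos):
    # Write the new tokens' K or V projections into the cache at input_pos.
    seq_len = value.shape[1]
    cache[:, input_pos : input_pos + seq_len] = value


def custom_sdpa(q, k_cache, v_cache, input_pos):
    # Attend only over the cache entries populated so far: [0, input_pos + seq_len).
    valid = input_pos + q.shape[1]
    k = k_cache[:, :valid]
    v = v_cache[:, :valid]
    # Causal mask aligned to absolute positions: query i (at position
    # input_pos + i) may attend to keys 0 .. input_pos + i.
    q_pos = torch.arange(input_pos, valid).unsqueeze(1)  # [seq_len, 1]
    k_pos = torch.arange(valid).unsqueeze(0)             # [1, valid]
    mask = k_pos <= q_pos
    # [B, S, H, D] -> [B, H, S, D] for F.scaled_dot_product_attention.
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_t, k_t, v_t, attn_mask=mask)
    return out.transpose(1, 2)


def sdpa_with_kv_cache(q, k, v, k_cache, v_cache, input_pos):
    # The old fused op, expressed in terms of the two new ops.
    update_cache(k, k_cache, input_pos)
    update_cache(v, v_cache, input_pos)
    return custom_sdpa(q, k_cache, v_cache, input_pos)


# Tiny smoke test: prefill 5 tokens, then decode 1 token at position 5.
B, H, D, max_seq = 1, 4, 8, 16
k_cache = torch.zeros(B, max_seq, H, D)
v_cache = torch.zeros(B, max_seq, H, D)
prefill = sdpa_with_kv_cache(*torch.randn(3, B, 5, H, D), k_cache, v_cache, 0)
decode = sdpa_with_kv_cache(*torch.randn(3, B, 1, H, D), k_cache, v_cache, 5)
print(prefill.shape, decode.shape)  # torch.Size([1, 5, 4, 8]) torch.Size([1, 1, 4, 8])
```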

4 files changed: +99 -31 lines changed

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 9 additions & 1 deletion
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
             )
             # This pass assumes that the SpecPropPass() has already been applied
             assert "spec" in node.meta
+            # Mutable buffers will not be marked as constant, but it might as well be
+            # for the purposes of memory planning. Mark it as a constant tensor so that
+            # it is handled correctly by the memory planning pass.
+            if not node.meta["spec"].const:
+                assert is_param_node(program, node)
+                node.meta["spec"].const = True
             # Validate that the original node is marked as a constant. Constant tensors
             # do not participate in memory planning.
             assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
             # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
             # memory object.
             prepack_node.meta["spec"].mem_obj_id = -1
-            node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
+            node.replace_all_uses_with(
+                prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
+            )
 
     program.graph.eliminate_dead_code()
     return program
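
The filter passed to `replace_all_uses_with` above is torch.fx's `delete_user_cb` predicate, supplied positionally: a user is rewired to the prepack node only when the callback returns True, so the prepack node itself and the graph's `output` node keep referencing the original. A minimal, self-contained sketch of the same pattern on a toy `torch.fx` graph (the module and node names here are illustrative, not from this commit):

```python
import operator

import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y * 2, y  # y also feeds the output node directly


gm = torch.fx.symbolic_trace(Toy())

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target == operator.add:
        # Insert a stand-in "prepack"-style copy right after the original node.
        with gm.graph.inserting_after(node):
            copy_node = gm.graph.call_function(torch.clone, (node,))
        # Rewire users to the copy, but skip the copy itself (it must keep
        # consuming the original) and skip the output node, mirroring the
        # predicate added in this diff.
        node.replace_all_uses_with(
            copy_node,
            lambda user, c=copy_node: user != c and user.op != "output",
        )

gm.graph.lint()
gm.recompile()
print(gm.code)  # mul now consumes clone(add); the output still returns add
```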

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 1 addition & 3 deletions
@@ -220,9 +220,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
 
     # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
-
-        for node in sorted_nodes:
+        for node in graph_module.graph.nodes:
             if not self.should_annotate(node) or self.should_delay_annotation(node):
                 continue
 
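Dropping `topo_sort` here is safe because a `torch.fx.Graph` keeps its nodes in execution order, which is already a valid topological order for a well-formed graph (every node appears after the nodes it consumes), so iterating `graph_module.graph.nodes` directly preserves the old behavior. A quick illustration on a traced toy function (the function itself is just for demonstration):

```python
import torch
import torch.fx


def f(x):
    a = x + 1
    b = a * 2
    return b - a


gm = torch.fx.symbolic_trace(f)

# graph.nodes iterates in execution order: every node's inputs were yielded
# before the node itself, i.e. the iteration order is topologically sorted.
seen = set()
for node in gm.graph.nodes:
    assert all(inp in seen for inp in node.all_input_nodes)
    seen.add(node)
print([node.name for node in gm.graph.nodes])
# e.g. ['x', 'add', 'mul', 'sub', 'output']
```
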
backends/vulkan/op_registry.py

Lines changed: 15 additions & 1 deletion
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):
 
 
 @update_features("llama::sdpa_with_kv_cache")
-def register_sdpa_op(features: OpFeatures):
+def register_sdpa_with_kv_cache_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims={PackedDim.WIDTH},
     )
@@ -489,6 +489,20 @@ def register_sdpa_op(features: OpFeatures):
     return features
 
 
+# TODO(ssjia) allow registration after remove assertions pass is implemented
+# @update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
+def register_sdpa_ops(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.WIDTH},
+    )
+    features.resize_fn = False
+    features.buffer_impl = False
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.WIDTH},
+    )
+    return features
+
+
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 74 additions & 26 deletions
@@ -176,17 +176,32 @@ void resize_sdpa_out(
   graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
 }
 
-void sdpa_with_kv_cache_impl(
-    ComputeGraph& graph,
-    const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef value = args[arg_idx++];
+  const ValueRef cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variables
+  (void)out;
+
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+  add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef q_projected = args[arg_idx++];
-  const ValueRef k_projected = args[arg_idx++];
-  const ValueRef v_projected = args[arg_idx++];
-  const ValueRef k_cache_data = args[arg_idx++];
-  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
   const ValueRef input_pos_symint = args[arg_idx++];
-  const ValueRef sequence_len = args[arg_idx++];
   const ValueRef attn_mask = args[arg_idx++];
   const ValueRef dropout_p = args[arg_idx++];
   const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
   // Output tensors
   const ValueRef out = args[arg_idx++];
 
-  // Unused variables
-  (void)sequence_len;
-
   // Batches must be 1
   VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
   // k and v projected must have the same shape
-  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+  VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
   // head dim must match between tensors
   VK_CHECK_COND(
       graph.size_at<int32_t>(-1, q_projected) ==
-      graph.size_at<int32_t>(-1, k_projected));
+      graph.size_at<int32_t>(-1, k_cache));
   // All tensors must have the packed dim be the width (head) dimension
   VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
   // Some variables are not supported yet
   VK_CHECK_COND(
       graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
       graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
-  const ValueRef k_cache =
-      prepack_standard_like(graph, k_cache_data, q_projected);
-  const ValueRef v_cache =
-      prepack_standard_like(graph, v_cache_data, q_projected);
-
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
 
-  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
-  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
   // Slice caches from 0 to input_pos + sequence_len
   const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
   const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
@@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(
 
   // Repeat interleave
   const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
-  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
 
   const ValueRef num_repeats =
       graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
       new ExecuteNode(resize_sdpa_out, {q_projected, out}));
 }
 
+void sdpa_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  (void)sequence_len;
+
+  const ValueRef k_cache =
+      prepack_standard_like(graph, k_cache_data, q_projected);
+  const ValueRef v_cache =
+      prepack_standard_like(graph, v_cache_data, q_projected);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  sdpa_impl(
+      graph,
+      {q_projected,
+       k_cache,
+       v_cache,
+       input_pos_symint,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
+       out});
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+  VK_REGISTER_OP(update_cache.default, update_cache_impl);
+  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
 }
 
 } // namespace vkcompute
