Skip to content

Commit 8f2b220

Browse files
committed
[ET-VK][LlaMa] Split SDPA + KV cache operator into SDPA operator and KV cache update operator + Add RemoveAsserts pass and apply it during LlaMa export
**Note**: This diff is a combination of D68919676 (#8068) and D68919678 (no pull request). I decided to combine the two because of problems with `ghexport`, which was having some problems exporting the second diff, as well as the fact that both diffs are needed for `export_llama` to work so it makes more sense to just have a single diff. ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. ## Context Recently, some assertion ops were added to the Llama source code. Unfortunately, this causes issues for the Vulkan delegate because runtime assertions are not yet supported in Vulkan and the assertion ops cause graph breaks due to not being supported. To prevent graph breaks when delegating to Vulkan, apply a pass to remove assertion ops during the llama export. Differential Revision: [D68922404](https://our.internmc.facebook.com/intern/diff/D68922404/) [ghstack-poisoned]
1 parent afc5a50 commit 8f2b220

File tree

8 files changed

+176
-34
lines changed

8 files changed

+176
-34
lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@ runtime.python_library(
3030
]
3131
)
3232

33+
runtime.python_library(
34+
name = "remove_asserts",
35+
srcs = ["remove_asserts.py"],
36+
visibility = [
37+
"//executorch/backends/...",
38+
],
39+
deps = [
40+
"//caffe2:torch",
41+
"//executorch/exir:pass_base",
42+
"//executorch/exir/dialects:lib",
43+
],
44+
)
45+
3346
runtime.python_library(
3447
name = "remove_local_scalar_dense",
3548
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
8396
deps = [
8497
":insert_prepack_nodes",
8598
":int4_weight_only_quantizer",
99+
":remove_asserts",
86100
":remove_local_scalar_dense",
87101
":remove_redundant_ops",
88102
":tag_memory_meta_pass"

backends/vulkan/_passes/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
33
VkInt4WeightOnlyQuantizer,
44
)
5+
from executorch.backends.vulkan._passes.remove_asserts import (
6+
remove_asserts,
7+
RemoveAssertsTransform,
8+
)
59
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
610
RemoveLocalScalarDenseOpsTransform,
711
)
@@ -13,6 +17,8 @@
1317
__all__ = [
1418
"insert_prepack_nodes",
1519
"VkInt4WeightOnlyQuantizer",
20+
"remove_asserts",
21+
"RemoveAssertsTransform",
1622
"RemoveLocalScalarDenseOpsTransform",
1723
"RemoveRedundantOpsTransform",
1824
"TagMemoryMetaPass",

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
6060
)
6161
# This pass assumes that the SpecPropPass() has already been applied
6262
assert "spec" in node.meta
63+
# Mutable buffers will not be marked as constant, but it might as well be
64+
# for the purposes of memory planning. Mark it as a constant tensor so that
65+
# it is handled correctly by the memory planning pass.
66+
if not node.meta["spec"].const:
67+
assert is_param_node(program, node)
68+
node.meta["spec"].const = True
6369
# Validate that the original node is marked as a constant. Constant tensors
6470
# do not participate in memory planning.
6571
assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
6874
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
6975
# memory object.
7076
prepack_node.meta["spec"].mem_obj_id = -1
71-
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
77+
node.replace_all_uses_with(
78+
prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
79+
)
7280

7381
program.graph.eliminate_dead_code()
7482
return program
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Set, Union
10+
11+
import torch
12+
13+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
14+
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.program._program import _get_updated_graph_signature
16+
17+
from torch.export.exported_program import ExportedProgram
18+
19+
OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
20+
21+
22+
class RemoveAssertsTransform(ExportPass):
23+
"""
24+
Remove operators which perform assertions. These are not possible to execute in
25+
Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these
26+
operators.
27+
"""
28+
29+
assert_ops: Set[OpType] = {
30+
torch.ops.aten._assert_scalar.default,
31+
torch.ops.aten.sym_constrain_range_for_size.default,
32+
}
33+
34+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
35+
for node in graph_module.graph.nodes:
36+
if node.target in self.assert_ops:
37+
graph_module.graph.erase_node(node)
38+
39+
graph_module.graph.eliminate_dead_code()
40+
graph_module.recompile()
41+
return PassResult(graph_module, True)
42+
43+
44+
def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
45+
graph_module = edge_program.graph_module
46+
RemoveAssertsTransform()(graph_module)
47+
48+
edge_program._graph_signature = _get_updated_graph_signature(
49+
edge_program.graph_signature, graph_module
50+
)
51+
edge_program._validate()
52+
return edge_program

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323

2424
from executorch.exir.pass_base import ExportPass, PassResult
2525

26-
from torch.fx.passes.tools_common import NodeList
27-
from torch.fx.passes.utils.fuser_utils import topo_sort
28-
2926
logger: logging.Logger = logging.getLogger("")
3027
logger.setLevel(logging.INFO)
3128

@@ -220,9 +217,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
220217

221218
# noqa
222219
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
223-
sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
224-
225-
for node in sorted_nodes:
220+
for node in graph_module.graph.nodes:
226221
if not self.should_annotate(node) or self.should_delay_annotation(node):
227222
continue
228223

backends/vulkan/op_registry.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):
478478

479479

480480
@update_features("llama::sdpa_with_kv_cache")
481-
def register_sdpa_op(features: OpFeatures):
481+
def register_sdpa_with_kv_cache_op(features: OpFeatures):
482482
features.texture_impl = TextureImplFeatures(
483483
valid_packed_dims={PackedDim.WIDTH},
484484
)
@@ -489,6 +489,20 @@ def register_sdpa_op(features: OpFeatures):
489489
return features
490490

491491

492+
# TODO(ssjia) allow registration after remove assertions pass is implemented
493+
@update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
494+
def register_sdpa_ops(features: OpFeatures):
495+
features.texture_impl = TextureImplFeatures(
496+
valid_packed_dims={PackedDim.WIDTH},
497+
)
498+
features.resize_fn = False
499+
features.buffer_impl = False
500+
features.texture_impl = TextureImplFeatures(
501+
valid_packed_dims={PackedDim.WIDTH},
502+
)
503+
return features
504+
505+
492506
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
493507
def register_rotary_emb_op(features: OpFeatures):
494508
features.texture_impl = TextureImplFeatures(

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,32 @@ void resize_sdpa_out(
176176
graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
177177
}
178178

179-
void sdpa_with_kv_cache_impl(
180-
ComputeGraph& graph,
181-
const std::vector<ValueRef>& args) {
179+
void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
180+
int arg_idx = 0;
181+
const ValueRef value = args[arg_idx++];
182+
const ValueRef cache = args[arg_idx++];
183+
const ValueRef input_pos_symint = args[arg_idx++];
184+
const ValueRef out = args[arg_idx++];
185+
186+
// Unused variables
187+
(void)out;
188+
189+
VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
190+
VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
191+
VK_CHECK_COND(
192+
graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
193+
VK_CHECK_COND(
194+
graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
195+
196+
add_kv_cache_update_node(graph, input_pos_symint, value, cache);
197+
}
198+
199+
void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
182200
int arg_idx = 0;
183201
const ValueRef q_projected = args[arg_idx++];
184-
const ValueRef k_projected = args[arg_idx++];
185-
const ValueRef v_projected = args[arg_idx++];
186-
const ValueRef k_cache_data = args[arg_idx++];
187-
const ValueRef v_cache_data = args[arg_idx++];
202+
const ValueRef k_cache = args[arg_idx++];
203+
const ValueRef v_cache = args[arg_idx++];
188204
const ValueRef input_pos_symint = args[arg_idx++];
189-
const ValueRef sequence_len = args[arg_idx++];
190205
const ValueRef attn_mask = args[arg_idx++];
191206
const ValueRef dropout_p = args[arg_idx++];
192207
const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
195210
// Output tensors
196211
const ValueRef out = args[arg_idx++];
197212

198-
// Unused variables
199-
(void)sequence_len;
200-
201213
// Batches must be 1
202214
VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
203-
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
204-
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
215+
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
216+
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
205217
// k and v projected must have the same shape
206-
VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
218+
VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
207219
// head dim must match between tensors
208220
VK_CHECK_COND(
209221
graph.size_at<int32_t>(-1, q_projected) ==
210-
graph.size_at<int32_t>(-1, k_projected));
222+
graph.size_at<int32_t>(-1, k_cache));
211223
// All tensors must have the packed dim be the width (head) dimension
212224
VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
213-
VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
214-
VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
225+
VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
226+
VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
215227
// Some variables are not supported yet
216228
VK_CHECK_COND(
217229
graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
222234
graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
223235
VK_CHECK_COND(graph.val_is_none(attn_mask));
224236

225-
const ValueRef k_cache =
226-
prepack_standard_like(graph, k_cache_data, q_projected);
227-
const ValueRef v_cache =
228-
prepack_standard_like(graph, v_cache_data, q_projected);
229-
230237
const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
231238

232-
add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
233-
add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
234-
235239
// Slice caches from 0 to input_pos + sequence_len
236240
const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
237241
const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
@@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(
257261

258262
// Repeat interleave
259263
const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
260-
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
264+
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
261265

262266
const ValueRef num_repeats =
263267
graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
331335
new ExecuteNode(resize_sdpa_out, {q_projected, out}));
332336
}
333337

338+
void sdpa_with_kv_cache_impl(
339+
ComputeGraph& graph,
340+
const std::vector<ValueRef>& args) {
341+
int arg_idx = 0;
342+
const ValueRef q_projected = args[arg_idx++];
343+
const ValueRef k_projected = args[arg_idx++];
344+
const ValueRef v_projected = args[arg_idx++];
345+
const ValueRef k_cache_data = args[arg_idx++];
346+
const ValueRef v_cache_data = args[arg_idx++];
347+
const ValueRef input_pos_symint = args[arg_idx++];
348+
const ValueRef sequence_len = args[arg_idx++];
349+
const ValueRef attn_mask = args[arg_idx++];
350+
const ValueRef dropout_p = args[arg_idx++];
351+
const ValueRef is_causal = args[arg_idx++];
352+
const ValueRef scale = args[arg_idx++];
353+
354+
// Output tensors
355+
const ValueRef out = args[arg_idx++];
356+
357+
(void)sequence_len;
358+
359+
const ValueRef k_cache =
360+
prepack_standard_like(graph, k_cache_data, q_projected);
361+
const ValueRef v_cache =
362+
prepack_standard_like(graph, v_cache_data, q_projected);
363+
364+
update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
365+
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
366+
367+
sdpa_impl(
368+
graph,
369+
{q_projected,
370+
k_cache,
371+
v_cache,
372+
input_pos_symint,
373+
attn_mask,
374+
dropout_p,
375+
is_causal,
376+
scale,
377+
out});
378+
}
379+
334380
REGISTER_OPERATORS {
335381
VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
382+
VK_REGISTER_OP(update_cache.default, update_cache_impl);
383+
VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
336384
}
337385

338386
} // namespace vkcompute

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import pkg_resources
2323
import torch
24+
25+
from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
2426
from executorch.devtools.backend_debug import get_delegation_info
2527

2628
from executorch.devtools.etrecord import generate_etrecord
@@ -727,6 +729,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
727729
)
728730
modelname = f"vulkan_{modelname}"
729731

732+
# Need to remove asserts from the graph to prevent graph breaks
733+
remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
734+
730735
if args.mps:
731736
partitioners.append(get_mps_partitioner(args.use_kv_cache))
732737
modelname = f"mps_{modelname}"

0 commit comments

Comments
 (0)