[ET-VK][LlaMa] Split SDPA + KV cache operator into SDPA operator and KV cache update operator + Add RemoveAsserts pass and apply it during LlaMa export #8074

14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -30,6 +30,19 @@ runtime.python_library(
]
)

runtime.python_library(
name = "remove_asserts",
srcs = ["remove_asserts.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "remove_local_scalar_dense",
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
deps = [
":insert_prepack_nodes",
":int4_weight_only_quantizer",
":remove_asserts",
":remove_local_scalar_dense",
":remove_redundant_ops",
":tag_memory_meta_pass"
6 changes: 6 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -2,6 +2,10 @@
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
VkInt4WeightOnlyQuantizer,
)
from executorch.backends.vulkan._passes.remove_asserts import (
remove_asserts,
RemoveAssertsTransform,
)
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)
@@ -13,6 +17,8 @@
__all__ = [
"insert_prepack_nodes",
"VkInt4WeightOnlyQuantizer",
"remove_asserts",
"RemoveAssertsTransform",
"RemoveLocalScalarDenseOpsTransform",
"RemoveRedundantOpsTransform",
"TagMemoryMetaPass",
10 changes: 9 additions & 1 deletion backends/vulkan/_passes/insert_prepack_nodes.py
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
)
# This pass assumes that the SpecPropPass() has already been applied
assert "spec" in node.meta
        # Mutable buffers will not be marked as constant, but for the purposes
        # of memory planning they can be treated as such. Mark the tensor as
        # constant so that it is handled correctly by the memory planning pass.
if not node.meta["spec"].const:
assert is_param_node(program, node)
node.meta["spec"].const = True
# Validate that the original node is marked as a constant. Constant tensors
# do not participate in memory planning.
assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
# memory object.
prepack_node.meta["spec"].mem_obj_id = -1
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
node.replace_all_uses_with(
prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
)

program.graph.eliminate_dead_code()
return program
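
The updated predicate stops `replace_all_uses_with` from rewriting the use of the original node inside the graph's `output` node: mutable buffers are returned as graph outputs, and that use must keep pointing at the original buffer node. A minimal, self-contained `torch.fx` sketch of predicate-based use replacement (the traced function and the stand-in `torch.clone` "prepack" node are illustrative, not from this diff):

```python
import torch
import torch.fx


def f(x):
    y = x + 1
    return y, x  # x is consumed by both the add and the graph output


gm = torch.fx.symbolic_trace(f)
x_node = next(n for n in gm.graph.nodes if n.op == "placeholder")

# Insert a stand-in "prepack" node that consumes the placeholder.
with gm.graph.inserting_after(x_node):
    prepack = gm.graph.call_function(torch.clone, (x_node,))

# Replace uses of x everywhere except in the prepack node itself and in the
# output node, mirroring the filter added in this diff.
x_node.replace_all_uses_with(
    prepack, lambda user: user is not prepack and user.op != "output"
)
gm.recompile()
print(gm.code)  # the add now reads from clone; the output still returns x
```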
52 changes: 52 additions & 0 deletions backends/vulkan/_passes/remove_asserts.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Set, Union

import torch

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.program._program import _get_updated_graph_signature

from torch.export.exported_program import ExportedProgram

OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]


class RemoveAssertsTransform(ExportPass):
"""
Remove operators which perform assertions. These are not possible to execute in
Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these
operators.
"""

assert_ops: Set[OpType] = {
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.target in self.assert_ops:
graph_module.graph.erase_node(node)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)


def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
graph_module = edge_program.graph_module
RemoveAssertsTransform()(graph_module)

edge_program._graph_signature = _get_updated_graph_signature(
edge_program.graph_signature, graph_module
)
edge_program._validate()
return edge_program
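
For context, a hedged usage sketch of the new pass on a small exported program; the toy module, the dynamic-shape setup, and the `to_edge` call are illustrative assumptions rather than part of this diff:

```python
import torch
from torch.export import Dim, export

from executorch.backends.vulkan._passes import remove_asserts
from executorch.exir import to_edge


class Gate(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With a dynamic batch dimension, this check typically lowers to
        # assertion ops such as aten._assert_scalar and
        # sym_constrain_range_for_size.
        torch._check(x.size(0) >= 1)
        return x + 1


batch = Dim("batch", min=1, max=8)
ep = export(Gate(), (torch.randn(4),), dynamic_shapes={"x": {0: batch}})
edge = to_edge(ep)

# Strip assertion operators so the Vulkan delegate never sees them.
edge_program = remove_asserts(edge.exported_program())
```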
7 changes: 1 addition & 6 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -23,9 +23,6 @@

from executorch.exir.pass_base import ExportPass, PassResult

from torch.fx.passes.tools_common import NodeList
from torch.fx.passes.utils.fuser_utils import topo_sort

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)

@@ -220,9 +217,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:

# noqa
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

for node in sorted_nodes:
for node in graph_module.graph.nodes:
if not self.should_annotate(node) or self.should_delay_annotation(node):
continue

16 changes: 15 additions & 1 deletion backends/vulkan/op_registry.py
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
def register_sdpa_with_kv_cache_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
@@ -489,6 +489,20 @@ def register_sdpa_op(features: OpFeatures):
return features


# TODO(ssjia) allow registration after remove assertions pass is implemented
@update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
def register_sdpa_ops(features: OpFeatures):
    features.resize_fn = False
    features.buffer_impl = False
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
    )
    return features


@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
100 changes: 74 additions & 26 deletions backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -176,17 +176,32 @@ void resize_sdpa_out(
graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
}

void sdpa_with_kv_cache_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef value = args[arg_idx++];
const ValueRef cache = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef out = args[arg_idx++];

// Unused variables
(void)out;

VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
VK_CHECK_COND(
graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
VK_CHECK_COND(
graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));

add_kv_cache_update_node(graph, input_pos_symint, value, cache);
}

void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
const ValueRef k_projected = args[arg_idx++];
const ValueRef v_projected = args[arg_idx++];
const ValueRef k_cache_data = args[arg_idx++];
const ValueRef v_cache_data = args[arg_idx++];
const ValueRef k_cache = args[arg_idx++];
const ValueRef v_cache = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
const ValueRef dropout_p = args[arg_idx++];
const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
// Output tensors
const ValueRef out = args[arg_idx++];

// Unused variables
(void)sequence_len;

// Batches must be 1
VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
// k and v projected must have the same shape
VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
// head dim must match between tensors
VK_CHECK_COND(
graph.size_at<int32_t>(-1, q_projected) ==
graph.size_at<int32_t>(-1, k_projected));
graph.size_at<int32_t>(-1, k_cache));
// All tensors must have the packed dim be the width (head) dimension
VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
// Some variables are not supported yet
VK_CHECK_COND(
graph.val_is_none(dropout_p) ||
Expand All @@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
VK_CHECK_COND(graph.val_is_none(attn_mask));

const ValueRef k_cache =
prepack_standard_like(graph, k_cache_data, q_projected);
const ValueRef v_cache =
prepack_standard_like(graph, v_cache_data, q_projected);

const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);

add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);

// Slice caches from 0 to input_pos + sequence_len
const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
Expand All @@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(

// Repeat interleave
const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);

const ValueRef num_repeats =
graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
new ExecuteNode(resize_sdpa_out, {q_projected, out}));
}

void sdpa_with_kv_cache_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
const ValueRef k_projected = args[arg_idx++];
const ValueRef v_projected = args[arg_idx++];
const ValueRef k_cache_data = args[arg_idx++];
const ValueRef v_cache_data = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
const ValueRef dropout_p = args[arg_idx++];
const ValueRef is_causal = args[arg_idx++];
const ValueRef scale = args[arg_idx++];

// Output tensors
const ValueRef out = args[arg_idx++];

(void)sequence_len;

const ValueRef k_cache =
prepack_standard_like(graph, k_cache_data, q_projected);
const ValueRef v_cache =
prepack_standard_like(graph, v_cache_data, q_projected);

update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});

sdpa_impl(
graph,
{q_projected,
k_cache,
v_cache,
input_pos_symint,
attn_mask,
dropout_p,
is_causal,
scale,
out});
}

REGISTER_OPERATORS {
VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
VK_REGISTER_OP(update_cache.default, update_cache_impl);
VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
}

} // namespace vkcompute
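
At the operator level, the legacy fused op is now just a composition of the two new ops, as the wrapper above shows. A rough pure-Python sketch of the decode-step semantics (plain tensors stand in for the Vulkan compute graph; the (batch, seq_len, n_heads, head_dim) layout follows this diff, while causal masking and the GQA repeat-interleave are deliberately omitted):

```python
import torch
import torch.nn.functional as F


def update_cache(value: torch.Tensor, cache: torch.Tensor, input_pos: int) -> None:
    # Write the projected K or V tokens into the cache starting at input_pos.
    seq_len = value.shape[1]
    cache[:, input_pos : input_pos + seq_len] = value


def custom_sdpa(q, k_cache, v_cache, input_pos):
    # Attend over the cache slice [0, input_pos + seq_len).
    valid = input_pos + q.shape[1]
    k = k_cache[:, :valid].transpose(1, 2)  # -> (batch, heads, seq, head_dim)
    v = v_cache[:, :valid].transpose(1, 2)
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k, v)
    return out.transpose(1, 2)


def sdpa_with_kv_cache(q, k, v, k_cache, v_cache, input_pos):
    # The fused op decomposes into two cache updates followed by attention.
    update_cache(k, k_cache, input_pos)
    update_cache(v, v_cache, input_pos)
    return custom_sdpa(q, k_cache, v_cache, input_pos)
```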
5 changes: 5 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -21,6 +21,8 @@

import pkg_resources
import torch

from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
from executorch.devtools.backend_debug import get_delegation_info

from executorch.devtools.etrecord import generate_etrecord
@@ -727,6 +729,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
)
modelname = f"vulkan_{modelname}"

# Need to remove asserts from the graph to prevent graph breaks
remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

if args.mps:
partitioners.append(get_mps_partitioner(args.use_kv_cache))
modelname = f"mps_{modelname}"