Skip to content

[ET-VK][LlaMa] Split SDPA + KV cache operator into SDPA operator and KV cache update operator #8060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion backends/vulkan/_passes/insert_prepack_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
)
# This pass assumes that the SpecPropPass() has already been applied
assert "spec" in node.meta
# Mutable buffers will not be marked as constant, but it might as well be
# for the purposes of memory planning. Mark it as a constant tensor so that
# it is handled correctly by the memory planning pass.
if not node.meta["spec"].const:
assert is_param_node(program, node)
node.meta["spec"].const = True
# Validate that the original node is marked as a constant. Constant tensors
# do not participate in memory planning.
assert node.meta["spec"].const
Expand All @@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
# memory object.
prepack_node.meta["spec"].mem_obj_id = -1
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
node.replace_all_uses_with(
prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
)

program.graph.eliminate_dead_code()
return program
4 changes: 1 addition & 3 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:

# noqa
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

for node in sorted_nodes:
for node in graph_module.graph.nodes:
if not self.should_annotate(node) or self.should_delay_annotation(node):
continue

Expand Down
16 changes: 15 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
def register_sdpa_with_kv_cache_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
Expand All @@ -489,6 +489,20 @@ def register_sdpa_op(features: OpFeatures):
return features


# TODO(ssjia) allow registration after remove assertions pass is implemented
# @update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
def register_sdpa_ops(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
features.resize_fn = False
features.buffer_impl = False
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
return features


@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
Expand Down
100 changes: 74 additions & 26 deletions backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,32 @@ void resize_sdpa_out(
graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
}

void sdpa_with_kv_cache_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef value = args[arg_idx++];
const ValueRef cache = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef out = args[arg_idx++];

// Unused variables
(void)out;

VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
VK_CHECK_COND(
graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
VK_CHECK_COND(
graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));

add_kv_cache_update_node(graph, input_pos_symint, value, cache);
}

void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
const ValueRef k_projected = args[arg_idx++];
const ValueRef v_projected = args[arg_idx++];
const ValueRef k_cache_data = args[arg_idx++];
const ValueRef v_cache_data = args[arg_idx++];
const ValueRef k_cache = args[arg_idx++];
const ValueRef v_cache = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
const ValueRef dropout_p = args[arg_idx++];
const ValueRef is_causal = args[arg_idx++];
Expand All @@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
// Output tensors
const ValueRef out = args[arg_idx++];

// Unused variables
(void)sequence_len;

// Batches must be 1
VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
// k and v projected must have the same shape
VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
// head dim must match between tensors
VK_CHECK_COND(
graph.size_at<int32_t>(-1, q_projected) ==
graph.size_at<int32_t>(-1, k_projected));
graph.size_at<int32_t>(-1, k_cache));
// All tensors must have the packed dim be the width (head) dimension
VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
// Some variables are not supported yet
VK_CHECK_COND(
graph.val_is_none(dropout_p) ||
Expand All @@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
VK_CHECK_COND(graph.val_is_none(attn_mask));

const ValueRef k_cache =
prepack_standard_like(graph, k_cache_data, q_projected);
const ValueRef v_cache =
prepack_standard_like(graph, v_cache_data, q_projected);

const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);

add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);

// Slice caches from 0 to input_pos + sequence_len
const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
Expand All @@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(

// Repeat interleave
const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);

const ValueRef num_repeats =
graph.add_scalar<int64_t>(num_heads / num_kv_heads);
Expand Down Expand Up @@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
new ExecuteNode(resize_sdpa_out, {q_projected, out}));
}

void sdpa_with_kv_cache_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
const ValueRef k_projected = args[arg_idx++];
const ValueRef v_projected = args[arg_idx++];
const ValueRef k_cache_data = args[arg_idx++];
const ValueRef v_cache_data = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
const ValueRef dropout_p = args[arg_idx++];
const ValueRef is_causal = args[arg_idx++];
const ValueRef scale = args[arg_idx++];

// Output tensors
const ValueRef out = args[arg_idx++];

(void)sequence_len;

const ValueRef k_cache =
prepack_standard_like(graph, k_cache_data, q_projected);
const ValueRef v_cache =
prepack_standard_like(graph, v_cache_data, q_projected);

update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});

sdpa_impl(
graph,
{q_projected,
k_cache,
v_cache,
input_pos_symint,
attn_mask,
dropout_p,
is_causal,
scale,
out});
}

REGISTER_OPERATORS {
VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
VK_REGISTER_OP(update_cache.default, update_cache_impl);
VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
}

} // namespace vkcompute
Loading