Skip to content

Commit 6e871c3

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Implement SDPA + KV-Cache operator (#5799)
Summary: Pull Request resolved: #5799 ## Context As title, this diff adds an implementation for a fused SDPA + KV-Cache update operator which will be used in LLaMA models. Currently the SDPA portion of the operator is implemented via it's consituent operators, but a future optimization opportunity would be to implement a single flash attention shader. ## Reference Implementation For future reference, a reference implementation of the SDPA + KV cache update mechanism is shown below. This reference implementation was originally used to check intermediate outputs but in the end I decided to compare against the `sdpa_with_kv_cache` operator in `extension/llm` for simplicity. ``` at::Tensor convert_boolean_attn_mask( const at::Tensor& attn_mask, caffe2::TypeMeta dtype) { // Convert boolean mask to additive mask; need to invert mask to indicate what // to mask *out*. if (attn_mask.dtype() == at::kBool) { return at::where( attn_mask.logical_not(), -std::numeric_limits<double>::infinity(), at::scalar_tensor( 0.0, at::TensorOptions().dtype(dtype).device(attn_mask.device()))); } // Otherwise, attn_mask represents an additive attention tensor return attn_mask; } at::Tensor construct_attention_mask( const at::Tensor& q, const at::Tensor& k_cache, const int start_pos) { const int max_seq_len = k_cache.size(1); const int seq_len = q.size(1); at::Tensor attn_mask_base = at::ones({max_seq_len, start_pos + seq_len}, q.options().dtype(at::kBool)) .tril(); at::Tensor attn_mask_sliced = at::slice(attn_mask_base, 0, start_pos, start_pos + seq_len); attn_mask_sliced = convert_boolean_attn_mask(attn_mask_sliced, q.dtype()); return attn_mask_sliced; } std::vector<at::Tensor> sdpa_reference_impl( const at::Tensor& q_projected, const at::Tensor& k_projected, const at::Tensor& v_projected, at::Tensor& key_cache, at::Tensor& value_cache, const int64_t start_pos, const int64_t seq_len, const c10::optional<at::Tensor> __attn_mask_ignored, const double dropout_p, const bool is_causal, const c10::optional<double> scale) { at::Tensor attn_mask = construct_attention_mask(q_projected, key_cache, start_pos); at::Tensor key_cache_updated = at::slice_scatter( key_cache, k_projected, 1, start_pos, start_pos + k_projected.size(1)); at::Tensor value_cache_updated = at::slice_scatter( value_cache, v_projected, 1, start_pos, start_pos + v_projected.size(1)); at::Tensor key_cache_sliced = at::slice(key_cache_updated, 1, 0, start_pos + q_projected.size(1)); at::Tensor value_cache_sliced = at::slice(value_cache_updated, 1, 0, start_pos + q_projected.size(1)); at::Tensor q_transposed = q_projected.transpose(1, 2); at::Tensor k_transposed = key_cache_sliced.transpose(1, 2); at::Tensor v_transposed = value_cache_sliced.transpose(1, 2); // Skip doing repeat_interleave; assume that num_attention_heads == // num_kv_heads float scale_factor = 1.0 / sqrt(q_transposed.size(-1)); at::Tensor k_transposed_2 = k_transposed.transpose(-2, -1); at::Tensor attn_weight_prescale = at::matmul(q_transposed, k_transposed_2); at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask; at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1); at::Tensor out = at::matmul(attn_weight_softmax, v_transposed); return { out.transpose(1, 2), key_cache_sliced, value_cache_sliced, q_transposed, k_transposed, v_transposed, k_transposed_2, attn_weight_prescale, attn_weight, attn_weight_softmax, out, }; } ``` ghstack-source-id: 246640547 Reviewed By: kimishpatel Differential Revision: D63724114 fbshipit-source-id: c85afc2f8eade8e0ac6e348eabbe608e5a0efce6
1 parent 0186c7f commit 6e871c3

File tree

9 files changed

+1128
-11
lines changed

9 files changed

+1128
-11
lines changed

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ class ComputeGraph final {
254254
#undef GET_AND_CHECK_VAL_AS_TYPE_FNS
255255

256256
inline bool val_is_none(const ValueRef idx) {
257-
return values_.at(idx).isNone();
257+
return idx == kDummyValueRef ? true : values_.at(idx).isNone();
258258
}
259259

260260
inline TypeTag get_val_type(const ValueRef idx) {
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#version 450 core
2+
3+
#define PRECISION ${PRECISION}
4+
5+
#define T ${buffer_scalar_type(DTYPE)}
6+
7+
${define_active_storage_type(STORAGE)}
8+
${define_required_extensions(DTYPE)}
9+
10+
layout(std430) buffer;
11+
12+
#include "indexing_utils.h"
13+
14+
${layout_declare_tensor(B, "w", "cache", DTYPE, STORAGE)}
15+
${layout_declare_tensor(B, "r", "projected", DTYPE, STORAGE)}
16+
$if STORAGE == "buffer":
17+
${layout_declare_ubo(B, "int", "projected_numel")}
18+
${layout_declare_ubo(B, "ivec4", "cache_strides")}
19+
${layout_declare_ubo(B, "int", "input_pos")}
20+
$else:
21+
${layout_declare_ubo(B, "ivec3", "projected_limits")}
22+
${layout_declare_ubo(B, "int", "input_pos")}
23+
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
/*
28+
* t_cache will have sizes of (max_batch_size, max_seq_len, n_heads, head_dim).
29+
* t_projected will have sizes of (batch_size, seq_len, n_heads, head_dim).
30+
*
31+
* The cache update inserts the values of t_projected into t_cache at the index
32+
* specified by input_pos at the seq_len dimension. It is equivalent to calling
33+
34+
* t_cache = t_cache.slice_scatter(
35+
* t_projected, dim=1, start=input_pos, end=input_pos+seq_len)
36+
*
37+
* Note that this shader is implemented assuming that max_batch_size is 1.
38+
*/
39+
40+
#ifdef USING_BUFFER
41+
42+
/***************************
43+
** Buffer Implementation **
44+
***************************/
45+
46+
void main() {
47+
int projected_bufi = int(gl_GlobalInvocationID.x);
48+
// Bump cache index forward by input_pos elements along the seq_len dimension.
49+
// cache_strides contains the strides of the cache tensor.
50+
int cache_bufi = input_pos * cache_strides.z + projected_bufi;
51+
if (projected_bufi >= projected_numel) {
52+
return;
53+
}
54+
cache[cache_bufi] = projected[projected_bufi];
55+
}
56+
57+
#else
58+
59+
/****************************
60+
** Texture Implementation **
61+
****************************/
62+
63+
// Note that this shader assumes the that tensors are width packed, i.e.
64+
// packed_dim = 0
65+
void main() {
66+
const ivec3 projected_pos = ivec3(gl_GlobalInvocationID);
67+
68+
if (any(greaterThanEqual(projected_pos, projected_limits))) {
69+
return;
70+
}
71+
72+
const ivec3 cache_pos = ivec3(
73+
projected_pos.x,
74+
projected_pos.y,
75+
projected_pos.z + input_pos);
76+
77+
write_texel(cache, cache_pos, load_texel(projected, projected_pos));
78+
}
79+
80+
#endif // USING_BUFFER
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
kv_cache_update:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: buffer
11+
generate_variant_forall:
12+
STORAGE:
13+
- VALUE: buffer
14+
- VALUE: texture3d
15+
DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
shader_variants:
19+
- NAME: kv_cache_update
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${buffer_scalar_type(DTYPE)}
14+
15+
${define_active_storage_type(STORAGE)}
16+
${define_required_extensions(DTYPE)}
17+
18+
#extension GL_EXT_control_flow_attributes : require
19+
20+
layout(std430) buffer;
21+
22+
${layout_declare_tensor(B, "rw", "attn_weight", DTYPE, STORAGE)}
23+
24+
$if STORAGE == "buffer":
25+
${layout_declare_ubo(B, "ivec4", "attn_weight_sizes")}
26+
${layout_declare_ubo(B, "ivec4", "attn_weight_strides")}
27+
$else:
28+
${layout_declare_ubo(B, "ivec3", "attn_weight_limits")}
29+
30+
${layout_declare_ubo(B, "int", "input_pos")}
31+
${layout_declare_ubo(B, "float", "scale")}
32+
33+
34+
#include "indexing_utils.h"
35+
36+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
37+
38+
// Negative infinity is represented by having sign bit be 1, all exponent bits
39+
// be 1, all mantissa bits be 0.
40+
#define NEGATIVE_INF_BITS 0xFF800000
41+
const float negative_infinity = NEGATIVE_INF_BITS;
42+
43+
#ifdef USING_BUFFER
44+
45+
/*
46+
* This implementations applies a scale and mask to the attention weight tensor
47+
* of an SDPA block. The sizes of the attention weight is
48+
* (batch_size, n_heads, seq_len, input_pos + seq_len)
49+
* Conceptually the weights represent the relationship between each token in the
50+
* sequence with each token preceding it.
51+
*
52+
* The scale applied is 1.0 / sqrt(head_dim_length)
53+
*
54+
* The mask applied is a bit more complicated. Imagine you create a square
55+
* matrix of size (input_pos + seq_len, input_pos + seq_len), and then set the
56+
* lower triangular section of the matrix to -inf. Then, slice the matrix along
57+
* the row dimension starting from input_pos to input_pos + seq_len. You end up
58+
* with a partial mask with size (seq_len, input_pos + seq_len). This is the
59+
* mask that is applied to the attention weight.
60+
*
61+
* In the shader, instead of generating the mask, the index of the elment is
62+
* inspected to determine if it would have been masked. Given an element at
63+
* tensor index (n, c, h, w), it would be masked if w < h + input_pos.
64+
*/
65+
66+
/***************************
67+
** Buffer Implementation **
68+
***************************/
69+
70+
void main() {
71+
const ivec4 attn_weight_idx = ivec4(
72+
gl_GlobalInvocationID.x,
73+
gl_GlobalInvocationID.y,
74+
gl_GlobalInvocationID.z,
75+
0);
76+
77+
if (any(greaterThanEqual(attn_weight_idx, attn_weight_sizes))) {
78+
return;
79+
}
80+
81+
const T scale_conv = T(scale);
82+
83+
const int attn_weight_id = tidx_to_bufi(attn_weight_idx, attn_weight_strides);
84+
if (attn_weight_idx.x <= attn_weight_idx.y + input_pos) {
85+
attn_weight[attn_weight_id] = attn_weight[attn_weight_id] * scale_conv;
86+
} else {
87+
attn_weight[attn_weight_id] = T(negative_infinity);
88+
}
89+
}
90+
91+
#else
92+
93+
/****************************
94+
** Texture Implementation **
95+
****************************/
96+
97+
/*
98+
* This implementation assumes that the attention weight is width packed, i.e.
99+
* the packed dim of the attn_weight is 0.
100+
*/
101+
void main() {
102+
const ivec3 attn_weight_pos = ivec3(gl_GlobalInvocationID);
103+
104+
if (any(greaterThanEqual(attn_weight_pos, attn_weight_limits))) {
105+
return;
106+
}
107+
108+
vec4 outtex = imageLoad(attn_weight, attn_weight_pos) * scale;
109+
110+
// Mask out the upper triangular of attn_weight to -inf
111+
[[unroll]] for (int i = 0; i < 4; ++i) {
112+
if (attn_weight_pos.x * 4 + i > attn_weight_pos.y + input_pos) {
113+
outtex[i] = negative_infinity;
114+
}
115+
}
116+
117+
write_texel(attn_weight, attn_weight_pos, outtex);
118+
}
119+
120+
#endif // USING_BUFFER
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
sdpa_attn_weight_scale_and_mask:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
STORAGE: buffer
5+
generate_variant_forall:
6+
STORAGE:
7+
- VALUE: buffer
8+
- VALUE: texture3d
9+
DTYPE:
10+
- VALUE: half
11+
- VALUE: float
12+
shader_variants:
13+
- NAME: sdpa_attn_weight_scale_and_mask

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,10 @@ void resize_matmul_node(
4848
const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())
4949
: utils::val_at(-1, mat2->sizes());
5050

51-
std::vector<int64_t> new_out_sizes(3);
52-
if (mat1->sizes().size() == 2) {
53-
new_out_sizes.resize(2);
54-
new_out_sizes.at(0) = out_cols;
55-
new_out_sizes.at(1) = out_rows;
56-
} else {
57-
new_out_sizes.at(0) = mat1->sizes().at(0);
58-
new_out_sizes.at(1) = out_cols;
59-
new_out_sizes.at(2) = out_rows;
60-
}
51+
const int64_t out_dim = out->dim();
52+
std::vector<int64_t> new_out_sizes(mat1->sizes());
53+
new_out_sizes.at(out_dim - 1) = out_rows;
54+
new_out_sizes.at(out_dim - 2) = out_cols;
6155

6256
out->virtual_resize(new_out_sizes);
6357
}

0 commit comments

Comments
 (0)