Skip to content

Commit 10f51b9

Browse files
kirklandsignSS-JIA
andauthored
[ET-VK] Introduce rotary embedding custom op (#6423)
Pull Request resolved: #6392 ## Context As title; introduces a custom op to calculate rotary positional embeddings in LLMs. The custom op achieves the same result as the `apply_rotary_emb` Python function. Please see the documentation comments in the shader for more details. ghstack-source-id: 249175725 @exported-using-ghexport Differential Revision: [D64697588](https://our.internmc.facebook.com/intern/diff/D64697588/) Co-authored-by: Stephen Jia <[email protected]>
1 parent 0309854 commit 10f51b9

File tree

5 files changed

+439
-0
lines changed

5 files changed

+439
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
15+
${define_required_extensions(DTYPE)}
16+
17+
layout(std430) buffer;
18+
19+
${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
20+
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
21+
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
22+
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
23+
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
24+
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
25+
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
26+
${layout_declare_ubo(B, "ivec3", "xkout_limits")}
27+
28+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
29+
30+
layout(constant_id = 3) const int packed_dim = 0;
31+
32+
#include "indexing_utils.h"
33+
34+
/*
35+
* This shader computes rotary positional embeddings which are used in the Llama
36+
* model architecture. There are 4 input tensors with the following shapes.
37+
* Note that head_dim = embedding_dim / num_heads
38+
*
39+
* 1. xq (batch_size, sequence_len, num_heads, head_dim)
40+
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
41+
* 3. freqs_cos (sequence_len, head_dim / 2)
42+
* 4. freqs_cos (sequence_len, head_dim / 2)
43+
*
44+
* Two output tensors are produced, with the same shapes as xq and xk
45+
* respectively.
46+
*
47+
* The computation of rotary positional embeddings can be summarized with the
48+
* following equations:
49+
*
50+
* xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i]
51+
* xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i]
52+
*
53+
* Essentially, taking each row along head_dim of the xq and xk tensors, each
54+
* row is split into even and odd elements (xq[2i] and xq[2i + 1] respectively).
55+
* The even components of the output multiply the even components of the inputs
56+
* with the freqs_cos tensor, and the odd components of the inputs with the
57+
* freqs_sin tensor. The odd components of the output swap this. Throughout the
58+
* implementation the even components have the _r suffix and the odd components
59+
* have the _i suffix; this is a reference to complex numbers which can be used
60+
* to represent rotations.
61+
*
62+
* Note that this implementation assumes that all input tensors have the width
63+
* dim as the packed dim.
64+
*/
65+
void main() {
66+
// Each thread will write to two output locations to maximize data re-use.
67+
// One texel loaded from the freqs_cos/freqs_sin tensors can be used to
68+
// calculate two output texels.
69+
const ivec3 x_pos_1 = ivec3(
70+
gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
71+
const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);
72+
73+
if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
74+
return;
75+
}
76+
77+
const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
78+
79+
VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
80+
VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);
81+
82+
// Compute xqout
83+
84+
VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
85+
VEC4_T x_tex_2 = load_texel(xq, x_pos_2);
86+
87+
// Separate into even and odd elements
88+
VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
89+
VEC4_T x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
90+
91+
VEC4_T xout_r = x_r * cos_tex - x_i * sin_tex;
92+
VEC4_T xout_i = x_r * sin_tex + x_i * cos_tex;
93+
94+
VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
95+
VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
96+
97+
write_texel(xqout, x_pos_1, xout_tex_1);
98+
write_texel(xqout, x_pos_2, xout_tex_2);
99+
100+
// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
101+
// may have a larger height dim than xk and xkout. Only compute xkout if this
102+
// invocation is still within bounds.
103+
if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
104+
return;
105+
}
106+
107+
// Compute xkout
108+
109+
x_tex_1 = load_texel(xk, x_pos_1);
110+
x_tex_2 = load_texel(xk, x_pos_2);
111+
112+
x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
113+
x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
114+
115+
xout_r = x_r * cos_tex - x_i * sin_tex;
116+
xout_i = x_r * sin_tex + x_i * cos_tex;
117+
118+
xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
119+
xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
120+
121+
write_texel(xkout, x_pos_1, xout_tex_1);
122+
write_texel(xkout, x_pos_2, xout_tex_2);
123+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
rotary_embedding:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
STORAGE: texture3d
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: rotary_embedding
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
12+
13+
namespace vkcompute {
14+
15+
void resize_rotary_embedding_node(
16+
ComputeGraph* graph,
17+
const std::vector<ArgGroup>& args,
18+
const std::vector<ValueRef>& extra_args) {
19+
(void)extra_args;
20+
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
21+
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
22+
23+
std::vector<int64_t> in_sizes = in->sizes();
24+
// UNCOMMENT BELOW IF NEEDED
25+
// out->virtual_resize(in_sizes);
26+
}
27+
28+
void add_rotary_embedding_node(
29+
ComputeGraph& graph,
30+
const ValueRef xq,
31+
const ValueRef xk,
32+
const ValueRef freqs_cos,
33+
const ValueRef freqs_sin,
34+
const ValueRef xq_out,
35+
const ValueRef xk_out) {
36+
VK_CHECK_COND(graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, xk));
37+
VK_CHECK_COND(graph.size_at<int>(-3, xq) == graph.size_at<int>(-3, xk));
38+
VK_CHECK_COND(
39+
graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, freqs_cos) * 2);
40+
VK_CHECK_COND(graph.sizes_of(freqs_cos) == graph.sizes_of(freqs_sin));
41+
42+
VK_CHECK_COND(graph.packed_dim_of(xq) == WHCN::kWidthDim);
43+
VK_CHECK_COND(graph.packed_dim_of(xk) == WHCN::kWidthDim);
44+
VK_CHECK_COND(graph.packed_dim_of(freqs_cos) == WHCN::kWidthDim);
45+
VK_CHECK_COND(graph.packed_dim_of(freqs_sin) == WHCN::kWidthDim);
46+
VK_CHECK_COND(graph.has_standard_axis_map(xq));
47+
VK_CHECK_COND(graph.has_standard_axis_map(xk));
48+
VK_CHECK_COND(graph.has_standard_axis_map(freqs_cos));
49+
VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));
50+
51+
std::string kernel_name = "rotary_embedding";
52+
add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));
53+
54+
utils::uvec3 global_wg_size = graph.logical_limits_of(xq_out);
55+
global_wg_size[0] /= 2;
56+
const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
57+
58+
graph.execute_nodes().emplace_back(new DispatchNode(
59+
graph,
60+
// Shader
61+
VK_KERNEL_FROM_STR(kernel_name),
62+
// Workgroup sizes
63+
global_wg_size,
64+
local_wg_size,
65+
// Inputs and Outputs
66+
{{{xq_out, xk_out}, vkapi::kWrite},
67+
{{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
68+
// Parameter buffers
69+
{graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
70+
// Specialization Constants
71+
{},
72+
// Resizing Logic
73+
resize_rotary_embedding_node));
74+
}
75+
76+
void apply_rotary_emb(ComputeGraph& graph, const std::vector<ValueRef>& args) {
77+
const ValueListPtr out_tuple = graph.get_value_list(args[4]);
78+
const ValueRef xq_out = out_tuple->at(0);
79+
const ValueRef xk_out = out_tuple->at(1);
80+
81+
add_rotary_embedding_node(
82+
graph, args[0], args[1], args[2], args[3], xq_out, xk_out);
83+
}
84+
85+
REGISTER_OPERATORS {
86+
VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb);
87+
}
88+
89+
} // namespace vkcompute

0 commit comments

Comments
 (0)