/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
${layout_declare_ubo(B, "ivec3", "xkout_limits")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;

#include "indexing_utils.h"

/*
 * This shader computes rotary positional embeddings which are used in the Llama
 * model architecture. There are 4 input tensors with the following shapes.
 * Note that head_dim = embedding_dim / num_heads
 *
 * 1. xq (batch_size, sequence_len, num_heads, head_dim)
 * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
 * 3. freqs_cos (sequence_len, head_dim / 2)
 * 4. freqs_sin (sequence_len, head_dim / 2)
 *
 * Two output tensors are produced, with the same shapes as xq and xk
 * respectively.
 *
 * The computation of rotary positional embeddings can be summarized with the
 * following equations:
 *
 * xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i]
 * xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i]
 *
 * Essentially, each row along head_dim of the xq and xk tensors is split into
 * its even and odd elements (xq[2i] and xq[2i + 1] respectively). The even
 * output elements combine the even input elements scaled by freqs_cos with the
 * odd input elements scaled by freqs_sin; the odd output elements swap which
 * factor multiplies which, as shown in the equations above. Throughout the
 * implementation the even components have the _r suffix and the odd components
 * have the _i suffix; this is a reference to complex numbers, which can be
 * used to represent rotations.
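 *
 * For example, with head_dim = 8 a row's even elements are indices
 * {0, 2, 4, 6} and its odd elements are {1, 3, 5, 7}, so the first output
 * pair is:
 *
 *   xq_out[0] = xq[0] * freqs_cos[0] - xq[1] * freqs_sin[0]
 *   xq_out[1] = xq[0] * freqs_sin[0] + xq[1] * freqs_cos[0]
 *
 * which is the complex product (xq[0] + i * xq[1]) * (freqs_cos[0] + i * freqs_sin[0]).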
 *
 * Note that this implementation assumes that all input tensors have the width
 * dim as the packed dim.
 */
void main() {
  // Each thread will write to two output locations to maximize data re-use.
  // One texel loaded from the freqs_cos/freqs_sin tensors can be used to
  // calculate two output texels.
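  // The two output texels are horizontally adjacent along the width (packed)
  // axis, so together they cover 8 consecutive elements of head_dim.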
  const ivec3 x_pos_1 = ivec3(
      gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
  const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);

  if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
    return;
  }

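  // freqs_cos and freqs_sin have shape (sequence_len, head_dim / 2), so the
  // lookup uses the x coordinate (position along head_dim) and the z
  // coordinate (position along sequence_len). The head index (y) does not
  // participate because the same frequencies are shared across all heads.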
  const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);

  VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
  VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);

  // Compute xqout

  VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
  VEC4_T x_tex_2 = load_texel(xq, x_pos_2);

  // Separate into even and odd elements
  VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
  VEC4_T x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);

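  // Apply the rotation; this is equivalent to multiplying the complex number
  // (x_r + i * x_i) by (cos_tex + i * sin_tex).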
  VEC4_T xout_r = x_r * cos_tex - x_i * sin_tex;
  VEC4_T xout_i = x_r * sin_tex + x_i * cos_tex;

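  // Re-interleave the rotated even and odd elements back into texel order.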
  VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
  VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

  write_texel(xqout, x_pos_1, xout_tex_1);
  write_texel(xqout, x_pos_2, xout_tex_2);

  // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
  // may have a larger height dim than xk and xkout. Only compute xkout if this
  // invocation is still within bounds.
  if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
    return;
  }

  // Compute xkout

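  // The cos/sin texels loaded above are re-used; only the input and output
  // tensors differ.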
  x_tex_1 = load_texel(xk, x_pos_1);
  x_tex_2 = load_texel(xk, x_pos_2);

  x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
  x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);

  xout_r = x_r * cos_tex - x_i * sin_tex;
  xout_i = x_r * sin_tex + x_i * cos_tex;

  xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
  xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

  write_texel(xkout, x_pos_1, xout_tex_1);
  write_texel(xkout, x_pos_2, xout_tex_2);
}