@@ -20,25 +20,31 @@ def rotary_embedding(
):
r"""
Applies RotaryEmbedding (see https://huggingface.co/papers/2104.09864)
- on the `query ` or `key` before their multi-head attention computation.
+ on the `query` or `key` before their multi-head attention computation.
+
Args:
- - query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
- [batch size, sequence length, num_head/num_kv_head, head_dim]
- or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
- - sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
- - rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
- - head_dim (int) : head dim from the input shape.
- - rotary_half (bool) : if False. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
- so the offset is 1.
- if True, e.g., for llama, cos/sin is applied to the neighboring rotary_dim elements,
- so the offset is rotary_dim/2.
- - position_ids (torch.Tensor): Default is None and optional if sin/cos is provided. the according position_ids
- for the input. The shape should be [batch size, sequence length].
+ query, key (torch.Tensor): inputs to which the position embeddings are applied,
+ taking the shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
+ or [num_tokens, num_head/num_kv_head, head_dim] (which is also the output shape).
+ sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensors
+ generated to be applied on query/key.
+ rotary_ndims (int): the rotary dimension, e.g., 64 for GPT-J, head size for Llama.
+ head_dim (int): head dim from the input shape.
+ rotary_half (bool): if False, e.g., for GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
+ so the offset is 1.
+
+ if True, e.g., for Llama, cos/sin is applied to the neighboring rotary_dim elements,
+ so the offset is rotary_dim/2.
+
+ position_ids (torch.Tensor): Default is None and optional if sin/cos is provided.
+ The corresponding position_ids for the input. The shape should be [batch size, sequence length].
+
Return
- - query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
- or [num_tokens, num_head/num_kv_head, head_dim].
+ query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+ or [num_tokens, num_head/num_kv_head, head_dim].

"""
+
return RotaryEmbedding.apply_function(
    query, key, sin, cos, rotary_dim, rotary_half, position_ids
)
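
A minimal usage sketch for the function documented above (not taken from the diff): the import path `ipex.llm.functional` and the concrete sizes are assumptions for illustration; only the tensor shapes follow the docstring.

.. code-block:: python

    # Sketch only: assumes rotary_embedding is exposed as ipex.llm.functional.rotary_embedding
    # and that sin/cos were precomputed with shape [num_tokens, rotary_dim] as described above.
    import torch
    import intel_extension_for_pytorch as ipex  # assumed import path

    batch, seq_len, num_head, head_dim = 1, 32, 16, 256
    query = torch.randn(batch, seq_len, num_head, head_dim)
    key = torch.randn(batch, seq_len, num_head, head_dim)
    sin = torch.randn(seq_len, head_dim)  # [num_tokens, rotary_dim]
    cos = torch.randn(seq_len, head_dim)

    # rotary_half=True selects the Llama-style layout (offset of rotary_dim/2).
    query, key = ipex.llm.functional.rotary_embedding(
        query, key, sin, cos, head_dim, True
    )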
@@ -48,12 +54,14 @@ def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float):
r"""
Applies RMSNorm on the input (hidden states).
(see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76)
+
Args:
- - hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
- - weight (torch.Tensor): the weight to apply RMSnorm.
- - eps (float) : the variance_epsilon to apply RMSnorm.
+ hidden_states (torch.Tensor): the input tensor to apply RMSNorm.
+ weight (torch.Tensor): the weight to apply RMSNorm.
+ eps (float): the variance_epsilon to apply RMSNorm.

"""
+
return RMSNorm.apply_function(hidden_states, weight, eps)
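
A minimal usage sketch for rms_norm (not taken from the diff), assuming the helper is exposed as `ipex.llm.functional.rms_norm`; sizes are placeholders, and the argument order follows the docstring.

.. code-block:: python

    # Sketch only: argument order is (hidden_states, weight, eps), per the docstring above.
    import torch
    import intel_extension_for_pytorch as ipex  # assumed import path

    hidden_size = 4096
    hidden_states = torch.randn(1, 32, hidden_size)  # [batch, seq_len, hidden_size]
    weight = torch.ones(hidden_size)                 # per-channel RMSNorm weight
    out = ipex.llm.functional.rms_norm(hidden_states, weight, 1e-6)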
@@ -67,12 +75,14 @@ def fast_layer_norm(
r"""
Applies PyTorch LayerNorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)
on the input (hidden states).
+
Args:
- - hidden_states(torch.Tensor) : the input tensor to apply normalization.
- - normalized_shape (int or list) or torch.Size) input shape from an expected input of size.
- - weight (torch.Tensor): the weight to apply normalization.
- - bias (torch.Tensor): an additive bias for normalization.
- - eps (float): a value added to the denominator for numerical stability.
+ hidden_states (torch.Tensor): the input tensor to apply normalization.
+ normalized_shape (int or list or torch.Size): input shape from an
+ expected input of the given size.
+ weight (torch.Tensor): the weight to apply normalization.
+ bias (torch.Tensor): an additive bias for normalization.
+ eps (float): a value added to the denominator for numerical stability.

"""
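
A minimal usage sketch for fast_layer_norm (not taken from the diff), assuming the helper is exposed as `ipex.llm.functional.fast_layer_norm` and accepts the arguments in the order listed above; sizes are placeholders.

.. code-block:: python

    # Sketch only: (hidden_states, normalized_shape, weight, bias, eps), per the docstring.
    import torch
    import intel_extension_for_pytorch as ipex  # assumed import path

    hidden_size = 4096
    hidden_states = torch.randn(1, 32, hidden_size)
    weight = torch.ones(hidden_size)
    bias = torch.zeros(hidden_size)
    out = ipex.llm.functional.fast_layer_norm(
        hidden_states, [hidden_size], weight, bias, 1e-5
    )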
@@ -103,33 +113,49 @@ def indirect_access_kv_cache_attention(
buffers (key and value use different buffers) to store all key/value hidden states and beam index information.
It can use the beam index history to decide which beam should be used at a timestamp, and this information will
generate an offset to access the kv_cache buffer.
+
Data Format:
- - The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size],
- the hidden state of key/value which is the shape of [beam*batch, head_num, head_size] is stored token by token.
- All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
-
- forward
- - query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
- - key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
- - value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
- - scale_attn (float):scale used by the attention layer. should be the sqrt(head_size).
- - layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
- key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
- value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
- beam-idx: history beam idx, shape:(max_seq, beam*batch);
- seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
- - head_mask (torch.Tensor): Head mask tensor which is not supported by kernel yet.
- - attention_mask(torch.Tensor): Attention mask information.
- - text_max_length (int) : the max length of kv cache to be used for generation (allocate the pre-cache buffer).
+
+ The shape of the pre-allocated key (value) buffer is [max_seq, beam*batch, head_num, head_size];
+ the hidden state of key/value, which has the shape [beam*batch, head_num, head_size], is stored token by token.
+ All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
+
+ Args:
+ query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+ key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+ value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+ scale_attn (float): scale used by the attention layer; should be sqrt(head_size).
+ layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
+
+ - key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+ - value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+ - beam-idx: history beam idx, shape: (max_seq, beam*batch);
+
+ - seq_info: Sequence info tensor, shape: (1, 1, max_seq, max_seq).
+
+ head_mask (torch.Tensor): Head mask tensor, which is not yet supported by the kernel.
+ attention_mask (torch.Tensor): Attention mask information.
+ text_max_length (int): the max length of the kv cache to be used for generation
+ (allocates the pre-cache buffer).

Return:
- - attn_output: weighted value which is the output of scale dot product. shape (beam*batch, seq_len, head_num, head_size).
- - attn_weights: The output tensor of the first matmul in scale dot product which is not supported by kernel now.
- - new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
+ attn_output: weighted value which is the output of scaled dot product attention;
+ shape (beam*batch, seq_len, head_num, head_size).
+
+ attn_weights: the output tensor of the first matmul in scaled dot product attention,
+ which is not supported by the kernel for now.
+
+ new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).

Notes:
- - How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
- see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
+ How to reorder the KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on the llama model,
+ see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318):
+
+ .. highlight:: python
+ .. code-block:: python
+
    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
@@ -141,6 +167,7 @@ def _reorder_cache(
        return past_key_values

"""
+
return IndirectAccessKVCacheAttention.apply_function(
    query,
    key,
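
To make the Data Format section above concrete, here is a small sketch of the cache layout only (not taken from the diff, and not a full decoding loop); all sizes are hypothetical.

.. code-block:: python

    # Illustrative shapes only, following the Data Format description above.
    import torch

    max_seq, beam, batch, head_num, head_size = 2048, 4, 1, 32, 128

    # Pre-allocated key/value buffers: [max_seq, beam*batch, head_num, head_size].
    key_cache = torch.empty(max_seq, beam * batch, head_num, head_size)
    value_cache = torch.empty(max_seq, beam * batch, head_num, head_size)

    # Beam index history for every timestamp: [max_seq, beam*batch].
    beam_idx = torch.zeros(max_seq, beam * batch, dtype=torch.long)

    # Each decoding step produces key/value hidden states of shape
    # [beam*batch, head_num, head_size]; they are stored token by token
    # into the buffers above, and beam_idx records which beam each row came from.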
@@ -174,23 +201,30 @@ def varlen_attention(
):
r"""
Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value
- (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
- and accept the variant (different) sequence length among the query, key and value.
+ (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
+ and accepts variant (different) sequence lengths among the query, key and value.
+
+ This module does not have args for `module init`.
+
+ `forward()`

Args:
- module init: this module does not have args for module init
- forward:
- - query (torch.Tensor): shape [query_tokens, num_head, head_size], where tokens is total sequence length among batch size.
- - key (torch.Tensor): shape [key_tokens, num_head, head_size], where tokens is total sequence length among batch size.
- - value (torch.Tensor): shape [value_tokens, num_head, head_size], where tokens is total sequence length among batch size.
- - out (torch.Tensor): buffer to get the results, the shape is the same as query.
- - seqlen_q (torch.Tensor): shape [batch_size + 1], points the current query_tokens among total sequence length.
- - seqlen_k (torch.Tensor): shape [batch_size + 1], points the current key_tokens among total sequence length.
- - max_seqlen_q (int): max/total sequence length of query.
- - max_seqlen_k (int): max/total sequence length of key.
- - pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
- - softmax_scale (float): scaling factor applied is prior to softmax.
- - is_causal (bool): whether to apply causal attention masking, default is True.
+ query (torch.Tensor): shape [query_tokens, num_head, head_size],
+ where query_tokens is the total sequence length over the batch.
+ key (torch.Tensor): shape [key_tokens, num_head, head_size],
+ where key_tokens is the total sequence length over the batch.
+ value (torch.Tensor): shape [value_tokens, num_head, head_size],
+ where value_tokens is the total sequence length over the batch.
+ out (torch.Tensor): buffer to hold the results; the shape is the same as query.
+ seqlen_q (torch.Tensor): shape [batch_size + 1],
+ points to the current query_tokens within the total sequence length.
+ seqlen_k (torch.Tensor): shape [batch_size + 1],
+ points to the current key_tokens within the total sequence length.
+ max_seqlen_q (int): max/total sequence length of query.
+ max_seqlen_k (int): max/total sequence length of key.
+ pdropout (float): dropout probability; if greater than 0.0, dropout is applied; default is 0.0.
+ softmax_scale (float): scaling factor applied prior to softmax.
+ is_causal (bool): whether to apply causal attention masking; default is True.

"""
return VarlenAttention.apply_function(
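
A hedged usage sketch for varlen_attention (not taken from the diff): it packs two sequences of lengths 5 and 3 into 8 total tokens. It assumes the function is exposed as `ipex.llm.functional.varlen_attention` and that the documented arguments can be passed as keywords; the real signature may carry additional parameters with defaults.

.. code-block:: python

    # Sketch only: shapes and seqlen offsets follow the docstring above.
    import torch
    import intel_extension_for_pytorch as ipex  # assumed import path

    total_tokens, num_head, head_size = 8, 16, 128
    query = torch.randn(total_tokens, num_head, head_size)
    key = torch.randn(total_tokens, num_head, head_size)
    value = torch.randn(total_tokens, num_head, head_size)
    out = torch.empty_like(query)  # results are written into this buffer

    # [batch_size + 1] offsets into the packed token dimension (batch_size = 2 here).
    seqlen_q = torch.tensor([0, 5, 8], dtype=torch.int32)
    seqlen_k = torch.tensor([0, 5, 8], dtype=torch.int32)

    ipex.llm.functional.varlen_attention(
        query, key, value, out,
        seqlen_q, seqlen_k,
        max_seqlen_q=5, max_seqlen_k=5,
        pdropout=0.0,
        softmax_scale=head_size ** -0.5,
        is_causal=True,
    )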