Commit d3c5244

fix module level API docstring (#2869)
* correct module level api docstring
* flake8 format correction
* fix broken links
1 parent 71d6e31 commit d3c5244

4 files changed: +414 -237 lines changed

docs/tutorials/llm.rst

Lines changed: 2 additions & 2 deletions
@@ -30,14 +30,14 @@ Verified for distributed inference mode via DeepSpeed
 
 *Note*: The above verified models (including other models in the same model family, like "codellama/CodeLlama-7b-hf" from LLAMA family) are well supported with all optimizations like indirect access KV cache, fused ROPE, and prepacked TPP Linear (fp32/bf16). We are working in progress to better support the models in the tables with various data types. In addition, more models will be optimized in the future.
 
-Please check `LLM best known practice <../../examples/cpu/inference/python/llm>`_ for instructions to install/setup environment and example scripts.
+Please check `LLM best known practice <https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0%2Bcpu/examples/cpu/inference/python/llm>`_ for instructions to install/setup environment and example scripts.
 
 Module Level Optimization API for customized LLM (Prototype)
 ------------------------------------------------------------
 
 In the past year, LLM has been flourishing with many open-sourced models contributed to the community, while researchers are building their own LLMs from transformer blocks with variants in implementation details. To help LLM researchers and developers improve their productivity, Intel® Extension for PyTorch* provides module level optimizations for commonly used LLM modules and functionalities, which are operators or certain operator combinations in nature.
 
-Please check `LLM module level optimization practice <../../examples/cpu/inference/python/llm-modeling>`_ to better understand how to use `module level APIs <api_doc.html#llm-module-level-optimizations>`_ to optimize your LLM and achieve better performance.
+Please check `LLM module level optimization practice <https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0%2Bcpu/examples/cpu/inference/python/llm-modeling>`_ to better understand how to use `module level APIs <api_doc.html#llm-module-level-optimizations-prototype>`_ to optimize your LLM and achieve better performance.
 
 Demos
 -----
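As a quick illustration of the module level APIs mentioned above, here is a minimal sketch of a custom block delegating to one of the fused helpers. It assumes `intel_extension_for_pytorch` is installed and imported as `ipex`, and that the helpers defined in `intel_extension_for_pytorch/llm/functional/fusions.py` (diffed below) are reachable as `ipex.llm.functional.*`; it is an illustration, not part of this commit.

```python
import torch
import intel_extension_for_pytorch as ipex  # assumed import; exposes ipex.llm.functional


class CustomRMSNorm(torch.nn.Module):
    """Toy norm layer that delegates to the fused module level API."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Replaces a hand-written mean-square normalization with the fused kernel.
        return ipex.llm.functional.rms_norm(hidden_states, self.weight, self.eps)
```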

intel_extension_for_pytorch/llm/functional/fusions.py

Lines changed: 94 additions & 60 deletions
@@ -20,25 +20,31 @@ def rotary_embedding(
 ):
     r"""
     Applies RotaryEmbedding (see https://huggingface.co/papers/2104.09864)
-    on the `query ` or `key` before their multi-head attention computation.
+    on the `query ` or `key` before their multi-head attention computation.
+
     Args:
-        - query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
-          [batch size, sequence length, num_head/num_kv_head, head_dim]
-          or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
-        - sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
-        - rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
-        - head_dim (int) : head dim from the input shape.
-        - rotary_half (bool) : if False. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
-          so the offset is 1.
-          if True, e.g., for llama, cos/sin is applied to the neighboring rotary_dim elements,
-          so the offset is rotary_dim/2.
-        - position_ids (torch.Tensor): Default is None and optional if sin/cos is provided. the according position_ids
-          for the input. The shape should be [batch size, sequence length].
+        query, key (torch.Tensor) : inputs to be applied with position embeddings,
+            taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
+            or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+        sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor
+            generated to be applied on query/key.
+        rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
+        head_dim (int) : head dim from the input shape.
+        rotary_half (bool) : if False. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
+            so the offset is 1.
+
+            if True, e.g., for llama, cos/sin is applied to the neighboring rotary_dim elements,
+            so the offset is rotary_dim/2.
+
+        position_ids (torch.Tensor): Default is None and optional if sin/cos is provided.
+            The according position_ids for the input. The shape should be [batch size, sequence length].
+
     Return
-        - query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
-          or [num_tokens, num_head/num_kv_head, head_dim].
+        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+        or [num_tokens, num_head/num_kv_head, head_dim].
 
     """
+
     return RotaryEmbedding.apply_function(
         query, key, sin, cos, rotary_dim, rotary_half, position_ids
     )
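A minimal usage sketch for the docstring above, assuming the public `rotary_embedding` signature mirrors the `RotaryEmbedding.apply_function` call it wraps (query, key, sin, cos, rotary_dim, rotary_half, position_ids) and that `ipex.llm.functional` is importable; the sin/cos construction follows a llama-style half rotation and is illustrative only.

```python
import torch
import intel_extension_for_pytorch as ipex  # assumed import

batch, seq, n_head, head_dim = 1, 32, 16, 128
rotary_dim = head_dim                       # llama-style: rotate the full head
query = torch.randn(batch, seq, n_head, head_dim)
key = torch.randn(batch, seq, n_head, head_dim)

# Build sin/cos of shape [num_tokens, rotary_dim] as described in the Args section.
pos = torch.arange(seq, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(pos, inv_freq)          # [seq, rotary_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)     # [seq, rotary_dim]
sin, cos = emb.sin(), emb.cos()

# rotary_half=True: cos/sin pair element i with element i + rotary_dim // 2.
query, key = ipex.llm.functional.rotary_embedding(query, key, sin, cos, rotary_dim, True)
```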
@@ -48,12 +54,14 @@ def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float):
     r"""
     Applies RMSnorm on the input (hidden states).
     (see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76)
+
     Args:
-        - hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
-        - weight (torch.Tensor): the weight to apply RMSnorm.
-        - eps (float) : the variance_epsilon to apply RMSnorm.
+        hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
+        weight (torch.Tensor): the weight to apply RMSnorm.
+        eps (float) : the variance_epsilon to apply RMSnorm.
 
     """
+
     return RMSNorm.apply_function(hidden_states, weight, eps)
 
 
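A short usage sketch for `rms_norm`, assuming `ipex.llm.functional` is importable; the last line only restates what the linked LlamaRMSNorm reference computes.

```python
import torch
import intel_extension_for_pytorch as ipex  # assumed import

hidden_states = torch.randn(1, 32, 4096)    # [batch, seq, hidden_size]
weight = torch.ones(4096)                   # learnable scale, as in LlamaRMSNorm

out = ipex.llm.functional.rms_norm(hidden_states, weight, 1e-6)

# Reference semantics (per the linked modeling_llama.py implementation):
ref = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
```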

@@ -67,12 +75,14 @@ def fast_layer_norm(
     r"""
     Applies PyTorch Layernorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)
     on the input (hidden states).
+
     Args:
-        - hidden_states(torch.Tensor) : the input tensor to apply normalization.
-        - normalized_shape (int or list) or torch.Size) input shape from an expected input of size.
-        - weight (torch.Tensor): the weight to apply normalization.
-        - bias (torch.Tensor): an additive bias for normalization.
-        - eps (float): a value added to the denominator for numerical stability.
+        hidden_states(torch.Tensor) : the input tensor to apply normalization.
+        normalized_shape (int or list) or torch.Size) input shape from an
+            expected input of size.
+        weight (torch.Tensor): the weight to apply normalization.
+        bias (torch.Tensor): an additive bias for normalization.
+        eps (float): a value added to the denominator for numerical stability.
 
     """
 
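A usage sketch for `fast_layer_norm`, with the argument order taken from the Args list above (hidden_states, normalized_shape, weight, bias, eps); the `ipex.llm.functional` import path is assumed.

```python
import torch
import intel_extension_for_pytorch as ipex  # assumed import

hidden_size = 4096
hidden_states = torch.randn(1, 32, hidden_size)
weight = torch.ones(hidden_size)
bias = torch.zeros(hidden_size)

# Expected to match torch.nn.functional.layer_norm(hidden_states, [hidden_size], weight, bias, 1e-5)
out = ipex.llm.functional.fast_layer_norm(hidden_states, [hidden_size], weight, bias, 1e-5)
```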

@@ -103,33 +113,49 @@ def indirect_access_kv_cache_attention(
     buffers(key and value use different buffers) to store all key/value hidden states and beam index information.
     It can use beam index history to decide which beam should be used by a timestamp and this information will
     generate an offset to access the kv_cache buffer.
+
     Data Format:
-        - The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size],
-          the hidden state of key/value which is the shape of [beam*batch, head_num, head_size] is stored token by token.
-          All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
-
-    forward
-        - query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-        - key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-        - value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-        - scale_attn (float):scale used by the attention layer. should be the sqrt(head_size).
-        - layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
-          key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-          value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-          beam-idx: history beam idx, shape:(max_seq, beam*batch);
-          seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
-        - head_mask (torch.Tensor): Head mask tensor which is not supported by kernel yet.
-        - attention_mask(torch.Tensor): Attention mask information.
-        - text_max_length (int) : the max length of kv cache to be used for generation (allocate the pre-cache buffer).
+
+    The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size],
+    the hidden state of key/value which is the shape of [beam*batch, head_num, head_size] is stored token by token.
+    All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
+
+    Args:
+        query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        scale_attn (float):scale used by the attention layer. should be the sqrt(head_size).
+        layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
+
+            - key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - beam-idx: history beam idx, shape:(max_seq, beam*batch);
+
+            - seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
+
+        head_mask (torch.Tensor): Head mask tensor which is not supported by kernel yet.
+        attention_mask(torch.Tensor): Attention mask information.
+        text_max_length (int) : the max length of kv cache to be used for generation
+            (allocate the pre-cache buffer).
 
     Return:
-        - attn_output: weighted value which is the output of scale dot product. shape (beam*batch, seq_len, head_num, head_size).
-        - attn_weights: The output tensor of the first matmul in scale dot product which is not supported by kernel now.
-        - new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
+        attn_output: weighted value which is the output of scale dot product.
+            shape (beam*batch, seq_len, head_num, head_size).
+
+        attn_weights: the output tensor of the first matmul in scale dot product
+            which is not supported by kernel now.
+
+        new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
 
     Notes:
-        - How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
-          see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
+        How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
+        see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
+
+        .. highlight:: python
+        .. code-block:: python
+
         def _reorder_cache(
             self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
         ) -> Tuple[Tuple[torch.Tensor]]:
@@ -141,6 +167,7 @@ def _reorder_cache(
             return past_key_values
 
     """
+
     return IndirectAccessKVCacheAttention.apply_function(
         query,
         key,
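A hedged usage sketch for `indirect_access_kv_cache_attention`: the keyword names are taken from the Args list above, the import path is assumed, and passing `layer_past=None` so the op allocates its own cache on the first decoding step is an assumption rather than something this diff states.

```python
import math
import torch
import intel_extension_for_pytorch as ipex  # assumed import

beam_batch, seq_len, head_num, head_dim = 4, 16, 32, 128

query = torch.randn(beam_batch, seq_len, head_num, head_dim)
key = torch.randn(beam_batch, seq_len, head_num, head_dim)
value = torch.randn(beam_batch, seq_len, head_num, head_dim)

attn_output, attn_weights, new_layer_past = ipex.llm.functional.indirect_access_kv_cache_attention(
    query,
    key,
    value,
    scale_attn=math.sqrt(head_dim),   # docstring: should be the sqrt(head_size)
    layer_past=None,                  # assumption: no pre-existing cache yet
    head_mask=None,
    attention_mask=None,
    text_max_length=2048,             # pre-allocates the [max_seq, beam*batch, ...] buffers
)
# new_layer_past (seq_info, key_cache, value_cache, beam-idx) is fed back on the next step.
```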
@@ -174,23 +201,30 @@ def varlen_attention(
 ):
     r"""
     Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value
-    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
-    and accept the variant (different) sequence length among the query, key and value.
+    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
+    and accept the variant (different) sequence length among the query, key and value.
+
+    This module does not have args for `module init`.
+
+    `forward()`
 
     Args:
-        module init: this module does not have args for module init
-        forward:
-        - query (torch.Tensor): shape [query_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - key (torch.Tensor): shape [key_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - value (torch.Tensor): shape [value_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - out (torch.Tensor): buffer to get the results, the shape is the same as query.
-        - seqlen_q (torch.Tensor): shape [batch_size + 1], points the current query_tokens among total sequence length.
-        - seqlen_k (torch.Tensor): shape [batch_size + 1], points the current key_tokens among total sequence length.
-        - max_seqlen_q (int): max/total sequence length of query.
-        - max_seqlen_k (int): max/total sequence length of key.
-        - pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
-        - softmax_scale (float): scaling factor applied is prior to softmax.
-        - is_causal (bool): whether to apply causal attention masking, default is True.
+        query (torch.Tensor): shape [query_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        key (torch.Tensor): shape [key_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        value (torch.Tensor): shape [value_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        out (torch.Tensor): buffer to get the results, the shape is the same as query.
+        seqlen_q (torch.Tensor): shape [batch_size + 1],
+            points the current query_tokens among total sequence length.
+        seqlen_k (torch.Tensor): shape [batch_size + 1],
+            points the current key_tokens among total sequence length.
+        max_seqlen_q (int): max/total sequence length of query.
+        max_seqlen_k (int): max/total sequence length of key.
+        pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
+        softmax_scale (float): scaling factor applied is prior to softmax.
+        is_causal (bool): whether to apply causal attention masking, default is True.
 
     """
     return VarlenAttention.apply_function(