Commit bd4489a

doc: add smp class for supporting flash attn (#4009)
1 parent c6d412f commit bd4489a

1 file changed: +96 -0 lines changed

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

@@ -494,6 +494,102 @@ smdistributed.modelparallel.torch.DistributedOptimizer
``state_dict`` contains elements corresponding to only the current
partition, or to the entire model.

smdistributed.modelparallel.torch.nn.FlashAttentionLayer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)

This class supports
`FlashAttention <https://github.com/HazyResearch/flash-attention>`_
for PyTorch 2.0.
It takes the ``qkv`` matrix as an argument through its ``forward`` method,
computes attention scores and probabilities,
and then performs the matrix multiplication with the value layer.

Through this class, the smp library supports
custom attention masks such as Attention with
Linear Biases (ALiBi), and you can activate them by setting
``triton_flash_attention`` and ``use_alibi`` to ``True``.

Note that the Triton flash attention does not support dropout
on the attention probabilities. It uses the standard lower triangular
causal mask when causal mode is enabled. It also runs only
on P4d and P4de instances, with fp16 or bf16.

This class computes the scale factor to apply when computing attention.
By default, ``scale`` is set to ``None``, and it's automatically calculated.
When ``scale_attention_scores`` is ``True`` (which is the default), you must pass a value
to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
you must pass a value to ``layer_idx``. If both factors are used, they are
multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
This scale calculation can be bypassed if you specify a custom scaling
factor to ``scale``. In other words, if you specify a value for ``scale``, the set of parameters
(``scale_attention_scores``, ``attention_head_size``, ``scale_attn_by_layer_idx``, ``layer_idx``)
is overridden and ignored.
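
To make the precedence concrete, the following is a small illustrative sketch of the
documented scaling rules in plain Python; ``effective_scale`` is a hypothetical helper
written for this example, not part of the smp API.

.. code:: python

   import math

   def effective_scale(scale=None, scale_attention_scores=True, attention_head_size=None,
                       scale_attn_by_layer_idx=False, layer_idx=None):
       """Hypothetical helper mirroring the documented scale-factor rules."""
       # A custom ``scale`` overrides and ignores the other scaling arguments.
       if scale is not None:
           return scale
       factor = 1.0
       if scale_attention_scores:
           factor /= math.sqrt(attention_head_size)   # contributes 1/sqrt(attention_head_size)
       if scale_attn_by_layer_idx:
           factor /= (layer_idx + 1)                  # contributes 1/(layer_idx + 1)
       return factor

   # Both factors enabled: 1/(sqrt(64) * (2 + 1)) ~= 0.0417
   print(effective_scale(attention_head_size=64, scale_attn_by_layer_idx=True, layer_idx=2))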

**Parameters**

* ``attention_dropout_prob`` (float): (default: 0.1) specifies the dropout probability
  to apply to the attention probabilities.
* ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
  When ``scale_attention_scores`` is ``True``, this contributes
  ``1/sqrt(attention_head_size)`` to the scale factor.
* ``scale_attention_scores`` (boolean): (default: True) determines whether
  to multiply the scale factor by ``1/sqrt(attention_head_size)``.
* ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
  The layer ID to use for scaling attention by layer ID.
  It contributes ``1/(layer_idx + 1)`` to the scale factor.
* ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
  to multiply the scale factor by ``1/(layer_idx + 1)``.
* ``scale`` (float): (default: None) If passed, this scale factor is applied
  directly, bypassing all of the preceding scaling arguments.
* ``triton_flash_attention`` (bool): (default: False) If ``True``, the Triton
  implementation of flash attention is used. This is necessary to support
  Attention with Linear Biases (ALiBi) (see the next argument). Note that this version
  of the kernel doesn't support dropout.
* ``use_alibi`` (bool): (default: False) If ``True``, enables Attention with
  Linear Biases (ALiBi) using the provided attention mask.
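
The following construction sketch illustrates how these arguments combine. It is an
illustration only: it assumes the library is available and initialized on a supported
instance, uses the conventional ``import smdistributed.modelparallel.torch as smp``
alias, and the numeric values are placeholders.

.. code:: python

   import smdistributed.modelparallel.torch as smp

   # Default scaling path: scale_attention_scores=True requires attention_head_size.
   attn = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.1,
       attention_head_size=64,
   )

   # Triton kernel with ALiBi: dropout is not supported, and a custom
   # ``scale`` bypasses the other scaling arguments.
   alibi_attn = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.0,
       scale=0.125,
       triton_flash_attention=True,
       use_alibi=True,
   )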

.. method:: forward(self, qkv, attn_mask=None, causal=False)

Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
which represents the output of the attention computation.

**Parameters**

* ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
* ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
  By default it is ``None``, and using this mask requires ``triton_flash_attention``
  and ``use_alibi`` to be set to ``True``. See how to generate the mask in the following code snippet.
* ``causal``: When ``True``, the standard lower triangular causal mask is used. The default is ``False``.

When using ALiBi, the layer needs an attention mask prepared as in the following code snippet.

.. code:: python

   import torch

   def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
                                num_attention_heads, alibi_bias_max=8):

       device, dtype = attention_mask.device, attention_mask.dtype
       alibi_attention_mask = torch.zeros(
           1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
       )

       # Linearly increasing bias over key positions, scaled per attention head.
       alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
           1, 1, 1, seq_length
       )
       m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
       m.mul_(alibi_bias_max / num_attention_heads)
       alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))

       alibi_attention_mask.add_(alibi_bias)
       alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
       if attention_mask is not None and attention_mask.bool().any():
           # Use the in-place masked_fill_ so the padded positions are actually masked out.
           alibi_attention_mask.masked_fill_(
               attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
           )

       return alibi_attention_mask
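
As a minimal end-to-end sketch (not part of the original example), the mask produced by
``generate_alibi_attn_mask`` above could be passed to the layer as follows. The shapes,
dtype, ``smp`` import alias, and the assumption that calling the layer invokes ``forward``
are illustrative.

.. code:: python

   import torch
   import smdistributed.modelparallel.torch as smp

   batch_size, seq_length, num_heads, head_size = 2, 1024, 16, 64

   flash_attn = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.0,
       triton_flash_attention=True,
       use_alibi=True,
   )

   # Packed qkv in the expected (batch_size x seqlen x 3 x num_heads x head_size) layout.
   qkv = torch.randn(
       batch_size, seq_length, 3, num_heads, head_size,
       dtype=torch.bfloat16, device="cuda",
   )

   # Padding mask of shape (batch_size x 1 x 1 x seqlen); nonzero marks padded positions.
   padding_mask = torch.zeros(
       batch_size, 1, 1, seq_length, dtype=torch.bfloat16, device="cuda"
   )

   alibi_mask = generate_alibi_attn_mask(
       padding_mask, batch_size, seq_length, num_heads
   )

   # Calling the layer invokes forward; output shape is
   # (batch_size x num_heads x seq_len x head_size).
   out = flash_attn(qkv, attn_mask=alibi_mask, causal=True)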

smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
