Commit 4433d94

add smp class for supporting flash attn
1 parent c6d412f commit 4433d94

File tree

1 file changed: +104 -0 lines changed

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 104 additions & 0 deletions
@@ -494,6 +494,110 @@ smdistributed.modelparallel.torch.DistributedOptimizer
``state_dict`` contains elements corresponding to only the current
partition, or to the entire model.

smdistributed.modelparallel.torch.nn.FlashAttentionLayer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. function:: class smdistributed.modelparallel.torch.nn.FlashAttentionLayer(
      attention_dropout_prob=0.1,
      attention_head_size=None,
      scale_attention_scores=True,
      scale_attn_by_layer_idx=False,
      layer_idx=None,
      scale=None,
      triton_flash_attention=False,
      use_alibi=False
   )

This class supports
`FlashAttention <https://github.com/HazyResearch/flash-attention>`_.
It takes the ``qkv`` matrix as an argument, computes attention scores and probabilities,
and then performs the matrix multiplication with the value layer.

Note that custom attention masks such as Attention with
Linear Biases (ALiBi) are only supported when
``triton_flash_attention`` and ``use_alibi`` are set to ``True``.

Note also that Triton flash attention does not support dropout
on the attention probabilities. It uses the standard lower triangular
causal mask when causal mode is enabled. It also runs only
on P4d and P4de instances, with fp16 or bf16.

This class computes the scale factor to apply when computing attention.
By default, ``scale`` is ``None`` and the scale factor is calculated automatically.
When ``scale_attention_scores`` is ``True`` (the default),
``attention_head_size`` must be passed. When ``scale_attn_by_layer_idx`` is ``True``,
``layer_idx`` must be passed. If both factors are used, they are
multiplied together: ``1/(sqrt(attention_head_size) * (layer_idx+1))``.
To bypass this calculation, pass a custom scaling factor through the
``scale`` parameter.
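
As a minimal illustration of how the two factors combine (hypothetical
values, not part of the API):

.. code:: python

   import math

   # Hypothetical values for illustration only.
   attention_head_size = 64
   layer_idx = 3

   # scale_attention_scores=True contributes 1/sqrt(attention_head_size);
   # scale_attn_by_layer_idx=True contributes 1/(layer_idx + 1).
   effective_scale = 1.0 / (math.sqrt(attention_head_size) * (layer_idx + 1))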

**Parameters**

* ``attention_dropout_prob`` (float): (default: 0.1) The dropout probability
  to apply to attention.
* ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
  When ``scale_attention_scores`` is ``True``, this contributes
  ``1/sqrt(attention_head_size)`` to the scale factor.
* ``scale_attention_scores`` (bool): (default: True) Determines whether
  to multiply the scale factor by ``1/sqrt(attention_head_size)``.
* ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
  The layer index to use when scaling attention by layer index.
  It contributes ``1/(layer_idx + 1)`` to the scale factor.
* ``scale_attn_by_layer_idx`` (bool): (default: False) Determines whether
  to multiply the scale factor by ``1/(layer_idx + 1)``.
* ``scale`` (float): (default: None) If passed, this scale factor is
  applied directly, bypassing the arguments above.
* ``triton_flash_attention`` (bool): (default: False) If ``True``, the Triton
  implementation of flash attention is used. This is necessary to support
  Attention with Linear Biases (ALiBi) (see the next argument). Note that this
  version of the kernel doesn't support dropout.
* ``use_alibi`` (bool): (default: False) If ``True``, enables Attention with
  Linear Biases (ALiBi) using the mask provided.
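
For reference, the following is a minimal construction sketch. The argument
values are hypothetical; only the constructor signature documented above is
assumed.

.. code:: python

   import smdistributed.modelparallel.torch as smp

   # attention_head_size is required here because scale_attention_scores
   # defaults to True (hypothetical head size of 64).
   flash_attention = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.1,
       attention_head_size=64,
       scale_attention_scores=True,
       scale_attn_by_layer_idx=False,
   )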

.. method:: forward(self, qkv, attn_mask=None, causal=False)

Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
which represents the output of the attention computation.

**Parameters**

* ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
* ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
  By default it is ``None``, and using this mask requires ``triton_flash_attention``
  and ``use_alibi`` to be set to ``True``. See how to generate the mask in the
  following code snippet.
* ``causal``: When ``True``, the standard lower triangular mask is used. The default is ``False``.
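
The following is a minimal sketch of calling ``forward`` in causal mode,
reusing the ``flash_attention`` layer constructed in the earlier sketch.
The tensor sizes are hypothetical; only the shapes follow the parameter
descriptions above.

.. code:: python

   import torch

   batch_size, seq_len, num_heads, head_size = 2, 512, 16, 64

   # qkv packed as (batch_size x seqlen x 3 x num_heads x head_size), in fp16 on GPU.
   qkv = torch.randn(
       batch_size, seq_len, 3, num_heads, head_size,
       dtype=torch.float16, device="cuda",
   )

   # Standard lower triangular (causal) masking; no explicit attn_mask needed.
   out = flash_attention(qkv, attn_mask=None, causal=True)
   # out has shape (batch_size x num_heads x seq_len x head_size).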

When using ALiBi, the layer needs an attention mask prepared as in the following example.

.. code:: python

   import torch


   def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
                                num_attention_heads, alibi_bias_max=8):

       device, dtype = attention_mask.device, attention_mask.dtype
       alibi_attention_mask = torch.zeros(
           1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
       )

       # Linear distance biases, scaled by a per-head slope.
       alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
           1, 1, 1, seq_length
       )
       m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
       m.mul_(alibi_bias_max / num_attention_heads)
       alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))

       alibi_attention_mask.add_(alibi_bias)
       alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
       if attention_mask is not None and attention_mask.bool().any():
           # Mask out padded positions in place.
           alibi_attention_mask.masked_fill_(
               attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
           )

       return alibi_attention_mask
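
Putting this together, the following sketch (with hypothetical shapes and
values) shows how the generated mask might be passed to a layer configured
with ``triton_flash_attention=True`` and ``use_alibi=True``.

.. code:: python

   import torch
   import smdistributed.modelparallel.torch as smp

   batch_size, seq_len, num_heads, head_size = 2, 512, 16, 64

   alibi_layer = smp.nn.FlashAttentionLayer(
       attention_head_size=head_size,
       triton_flash_attention=True,  # required for custom masks such as ALiBi
       use_alibi=True,
   )

   # A padding mask of shape (batch_size x 1 x 1 x seqlen); all zeros here
   # means no padded positions in this example.
   padding_mask = torch.zeros(
       batch_size, 1, 1, seq_len, dtype=torch.float16, device="cuda"
   )
   alibi_mask = generate_alibi_attn_mask(
       padding_mask, batch_size, seq_len, num_heads
   )

   qkv = torch.randn(
       batch_size, seq_len, 3, num_heads, head_size,
       dtype=torch.float16, device="cuda",
   )
   out = alibi_layer(qkv, attn_mask=alibi_mask, causal=True)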

smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
