@@ -505,64 +505,64 @@ smdistributed.modelparallel.torch.nn.FlashAttentionLayer
layer_idx=None,
scale=None,
triton_flash_attention=False,
use_alibi=False)

This FlashAttentionLayer class supports
`FlashAttention <https://github.com/HazyResearch/flash-attention>`_.
It takes the ``qkv`` matrix as an argument, computes the attention scores and probabilities,
and then performs the matrix multiplication with the value layer.

Note that custom attention masks such as Attention with
Linear Biases (ALiBi) are only supported when
``triton_flash_attention`` and ``use_alibi`` are set to ``True``.

Note also that Triton flash attention does not support dropout
on the attention probabilities. It uses a standard lower triangular
causal mask when causal mode is enabled. It also runs only
on P4d and P4de instances, with fp16 or bf16.

This class computes the scale factor to apply when computing attention.
By default, ``scale`` is ``None``, and it is calculated automatically.
When ``scale_attention_scores`` is ``True`` (which is the default),
``attention_head_size`` must be passed. When ``scale_attn_by_layer_idx`` is ``True``,
``layer_idx`` must be passed. If both factors are used, they are
multiplied together as ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
This scale calculation can be bypassed by passing a custom scaling
factor if needed with the ``scale`` parameter.
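
The scale logic described above can be illustrated with a short sketch.
This is an illustrative reimplementation for clarity, not the library's
internal code; the function name and defaults are assumptions.

.. code:: python

   import math

   def compute_attention_scale(attention_head_size=None, layer_idx=None,
                               scale_attention_scores=True,
                               scale_attn_by_layer_idx=False, scale=None):
       """Illustrative sketch of the scale-factor rules described above."""
       if scale is not None:
           # A custom ``scale`` bypasses the automatic calculation entirely.
           return scale
       factor = 1.0
       if scale_attention_scores:
           # Contributes 1/sqrt(attention_head_size) to the factor.
           factor /= math.sqrt(attention_head_size)
       if scale_attn_by_layer_idx:
           # Contributes 1/(layer_idx + 1) to the factor.
           factor /= (layer_idx + 1)
       return factor

   # Example: head size 64 and layer index 3 give 1 / (sqrt(64) * 4) = 0.03125.
   print(compute_attention_scale(attention_head_size=64, layer_idx=3,
                                 scale_attn_by_layer_idx=True))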

**Parameters**

* ``attention_dropout_prob`` (float): (default: 0.1) specifies the dropout probability
  to apply to attention.
* ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
  When ``scale_attention_scores`` is passed, this contributes
  ``1/sqrt(attention_head_size)`` to the scale factor.
* ``scale_attention_scores`` (boolean): (default: True) determines whether
  to multiply ``1/sqrt(attention_head_size)`` into the scale factor.
* ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
  The layer id to use for scaling attention by layer id.
  It contributes ``1/(layer_idx + 1)`` to the scaling factor.
* ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
  to multiply ``1/(layer_idx + 1)`` into the scale factor.
* ``scale`` (float): (default: None) If passed, this scale factor is
  applied, bypassing the arguments above.
* ``triton_flash_attention`` (bool): (default: False) If passed, the Triton
  implementation of flash attention is used. This is necessary to support
  Attention with Linear Biases (ALiBi) (see the next argument). Note that this
  version of the kernel does not support dropout.
* ``use_alibi`` (bool): (default: False) If passed, it enables Attention with
  Linear Biases (ALiBi) using the mask provided.
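
As an illustration of the constructor arguments above, a hypothetical
instantiation might look like the following. The ``smp`` import alias and the
chosen values are assumptions for the example, and it presumes an environment
where the SageMaker model parallelism library is installed and initialized.

.. code:: python

   import smdistributed.modelparallel.torch as smp

   # Hypothetical construction: scale by 1/sqrt(head_size) and by 1/(layer_idx + 1).
   flash_attn = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.1,
       attention_head_size=64,        # required because scale_attention_scores=True
       scale_attention_scores=True,
       scale_attn_by_layer_idx=True,
       layer_idx=3,                   # required because scale_attn_by_layer_idx=True
       triton_flash_attention=False,  # set True together with use_alibi for ALiBi masks
       use_alibi=False,
   )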

.. method:: forward(self, qkv, attn_mask=None, causal=False)

Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
which represents the output of the attention computation.

**Parameters**

* ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
* ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
  By default it is ``None``, and using this mask requires ``triton_flash_attention``
  and ``use_alibi`` to be set. See how to generate the mask in the following code snippet.
* ``causal``: When passed, it uses the standard lower triangular mask. The default is ``False``.
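
In addition to the ALiBi mask snippet referenced above (shown later in this
document), a minimal causal-mode usage sketch might look like the following.
It reuses the hypothetical ``flash_attn`` layer constructed earlier, and the
shapes, dtype, and device are illustrative assumptions (recall that the kernel
runs on P4d and P4de instances with fp16 or bf16).

.. code:: python

   import torch

   # qkv must be laid out as (batch_size x seqlen x 3 x num_heads x head_size).
   batch_size, seqlen, num_heads, head_size = 2, 1024, 16, 64
   qkv = torch.randn(batch_size, seqlen, 3, num_heads, head_size,
                     dtype=torch.bfloat16, device="cuda")

   # Causal mode applies the standard lower triangular mask; no attn_mask is passed.
   out = flash_attn(qkv, causal=True)
   # out has shape (batch_size x num_heads x seqlen x head_size).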

@@ -594,11 +594,6 @@ smdistributed.modelparallel.torch.nn.FlashAttentionLayer

return alibi_attention_mask

smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^