@@ -499,34 +499,39 @@ smdistributed.modelparallel.torch.nn.FlashAttentionLayer
.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)
- This FlashAttentionLayer class supports
- `FlashAttention <https://github.com/HazyResearch/flash-attention>`_.
- It takes the ``qkv `` matrix as argument, computes attention scores and probabilities,
- and then does the matrix multiplication with value layer.
-
- Note that custom attention masks such as Attention with
- Linear Biases (ALiBi) are only supported when
- ``triton_flash_attention `` and ``use_alibi `` are set to ``True ``.
-
- Note also that Triton flash attention does not support dropout
+ This class supports
+ `FlashAttention <https://github.com/HazyResearch/flash-attention>`_
+ for PyTorch 2.0.
+ It takes the ``qkv `` matrix as an argument through its ``forward `` method,
+ computes attention scores and probabilities,
+ and then performs the matrix multiplication with the value layer.
+
+ Through this class, the smp library supports
+ custom attention masks such as Attention with
+ Linear Biases (ALiBi), and you can activate them by setting
+ ``triton_flash_attention `` and ``use_alibi `` to ``True ``.
+
+ Note that Triton flash attention does not support dropout
on the attention probabilities. It uses standard lower triangular
causal mask when causal mode is enabled. It also runs only
on P4d and P4de instances, with fp16 or bf16.
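
The following is a minimal usage sketch for reference. The constructor arguments come
from the signature above; the fused ``qkv `` tensor layout and the bare ``forward(qkv) ``
call shown here are assumptions, not the authoritative smdistributed API.

.. code:: python

   # Hypothetical usage sketch of FlashAttentionLayer (the qkv layout is assumed).
   import torch
   import smdistributed.modelparallel.torch as smp

   flash_attn = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.1,
       attention_head_size=64,        # required because scale_attention_scores=True
       scale_attention_scores=True,
       triton_flash_attention=True,   # the Triton kernel is required for ALiBi masks
       use_alibi=True,
   )

   batch, seq_len, num_heads, head_size = 2, 1024, 16, 64
   # Assumed fused layout: (batch, seq_len, 3, num_heads, head_size), fp16 on a P4d GPU.
   qkv = torch.randn(batch, seq_len, 3, num_heads, head_size,
                     dtype=torch.float16, device="cuda")
   context_layer = flash_attn(qkv)    # dispatches to FlashAttentionLayer.forward
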
This class computes the scale factor to apply when computing attention.
- By default, scale is ``None ``, and it's automatically calculated.
- When ``scale_attention_scores `` is ``True `` (which is default),
- ``attention_head_size `` must be passed. When ``scale_attn_by_layer_idx `` is True,
- then ``layer_idx `` must be passed. If both factors are used, they will
- be multiplied ``(1/(sqrt(attention_head_size) * (layer_idx+1))) ``.
- This scale calculation can be bypassed by passing a custom scaling
- factor if needed with ``scale `` parameter.
+ By default, ``scale `` is set to ``None ``, and it's automatically calculated.
+ When ``scale_attention_scores `` is ``True `` (which is default), you must pass a value
+ to ``attention_head_size ``. When ``scale_attn_by_layer_idx `` is ``True ``,
+ you must pass a value to ``layer_idx ``. If both factors are used, they are
+ multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1))) ``.
+ This scale calculation can be bypassed if you specify a custom scaling
+ factor to ``scale ``. In other words, if you specify a value for ``scale ``, the set of parameters
+ (``scale_attention_scores ``, ``attention_head_size ``, ``scale_attn_by_layer_idx ``, ``layer_idx ``)
+ is overridden and ignored.
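
As an illustration of that rule, the following helper (not part of the library) mirrors
how the effective scale is resolved from these arguments:

.. code:: python

   import math

   def effective_scale(scale=None, scale_attention_scores=True, attention_head_size=None,
                       scale_attn_by_layer_idx=False, layer_idx=None):
       """Illustrative only: mirrors the scale rule described above."""
       if scale is not None:
           # An explicit scale overrides all of the other scaling arguments.
           return scale
       factor = 1.0
       if scale_attention_scores:
           factor /= math.sqrt(attention_head_size)   # requires attention_head_size
       if scale_attn_by_layer_idx:
           factor /= (layer_idx + 1)                  # requires layer_idx
       return factor

   # Example: head size 64 and layer index 3 give 1 / (sqrt(64) * 4) = 1/32.
   print(effective_scale(attention_head_size=64,
                         scale_attn_by_layer_idx=True, layer_idx=3))   # 0.03125
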
**Parameters **
* ``attention_dropout_prob `` (float): (default: 0.1) specifies dropout probability
to apply to attention.
- * ``attention_head_size `` (int): Required when scale_attention_scores is True.
+ * ``attention_head_size `` (int): Required when ``scale_attention_scores `` is True.
When ``scale_attention_scores `` is passed, this contributes
``1/sqrt(attention_head_size) `` to the scale factor.
* ``scale_attention_scores `` (boolean): (default: True) determines whether
@@ -537,7 +542,7 @@ smdistributed.modelparallel.torch.nn.FlashAttentionLayer
* ``scale_attn_by_layer_idx `` (boolean): (default: False) determines whether
to multiply 1/(layer_idx + 1) to the scale factor.
* ``scale `` (float) (default: None): If passed, this scale factor will be
- applied bypassing the above arguments.
+ applied, bypassing all of the previous arguments.
* ``triton_flash_attention `` (bool): (default: False) If passed, Triton
implementation of flash attention will be used. This is necessary to support
Attention with Linear Biases (ALiBi) (see next arg). Note that this version