
Commit 7e8ec2d

polish
1 parent be8fa1c commit 7e8ec2d

1 file changed

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 24 additions & 19 deletions
@@ -499,34 +499,39 @@ smdistributed.modelparallel.torch.nn.FlashAttentionLayer

 .. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)

-   This FlashAttentionLayer class supports
-   `FlashAttention <https://github.com/HazyResearch/flash-attention>`_.
-   It takes the ``qkv`` matrix as argument, computes attention scores and probabilities,
-   and then does the matrix multiplication with value layer.
-
-   Note that custom attention masks such as Attention with
-   Linear Biases (ALiBi) are only supported when
-   ``triton_flash_attention`` and ``use_alibi`` are set to ``True``.
-
-   Note also that Triton flash attention does not support dropout
+   This class supports
+   `FlashAttention <https://github.com/HazyResearch/flash-attention>`_
+   for PyTorch 2.0.
+   It takes the ``qkv`` matrix as an argument through its ``forward`` class method,
+   computes attention scores and probabilities,
+   and then performs the matrix multiplication with the value layer.
+
+   Through this class, the smp library supports
+   custom attention masks such as Attention with
+   Linear Biases (ALiBi), which you can activate by setting
+   ``triton_flash_attention`` and ``use_alibi`` to ``True``.
+
+   Note that Triton flash attention does not support dropout
    on the attention probabilities. It uses standard lower triangular
    causal mask when causal mode is enabled. It also runs only
    on P4d and P4de instances, with fp16 or bf16.

    This class computes the scale factor to apply when computing attention.
-   By default, scale is ``None``, and it's automatically calculated.
-   When ``scale_attention_scores`` is ``True`` (which is default),
-   ``attention_head_size`` must be passed. When ``scale_attn_by_layer_idx`` is True,
-   then ``layer_idx`` must be passed. If both factors are used, they will
-   be multiplied ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
-   This scale calculation can be bypassed by passing a custom scaling
-   factor if needed with ``scale`` parameter.
+   By default, ``scale`` is set to ``None``, and it's automatically calculated.
+   When ``scale_attention_scores`` is ``True`` (which is the default), you must pass a value
+   to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
+   you must pass a value to ``layer_idx``. If both factors are used, they are
+   multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
+   This scale calculation can be bypassed if you specify a custom scaling
+   factor to ``scale``. In other words, if you specify a value to ``scale``, the set of parameters
+   (``scale_attention_scores``, ``attention_head_size``, ``scale_attn_by_layer_idx``, ``layer_idx``)
+   is overridden and ignored.

    **Parameters**

    * ``attention_dropout_prob`` (float): (default: 0.1) specifies dropout probability
      to apply to attention.
-   * ``attention_head_size`` (int): Required when scale_attention_scores is True.
+   * ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
      When ``scale_attention_scores`` is passed, this contributes
      ``1/sqrt(attention_head_size)`` to the scale factor.
    * ``scale_attention_scores`` (boolean): (default: True) determines whether
@@ -537,7 +542,7 @@ smdistributed.modelparallel.torch.nn.FlashAttentionLayer
    * ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
      to multiply 1/(layer_idx + 1) to the scale factor.
    * ``scale`` (float) (default: None): If passed, this scale factor will be
-     applied bypassing the above arguments.
+     applied, bypassing all of the previous arguments.
    * ``triton_flash_attention`` (bool): (default: False) If passed, Triton
      implementation of flash attention will be used. This is necessary to support
      Attention with Linear Biases (ALiBi) (see next arg). Note that this version
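
The revised description above documents the constructor signature and notes that the ``qkv`` matrix is passed through ``forward``. The following is a minimal, hypothetical usage sketch: the constructor arguments mirror the documented signature, but the ``smp`` import alias, the packed ``qkv`` layout, and the exact ``forward`` call are assumptions not stated on this page::

   # Hypothetical sketch: constructor arguments follow the documented signature;
   # the import alias, qkv layout, and forward() call are assumptions.
   import torch
   import smdistributed.modelparallel.torch as smp

   layer = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.1,
       attention_head_size=64,       # required because scale_attention_scores=True (the default)
       scale_attention_scores=True,
       scale_attn_by_layer_idx=False,
       triton_flash_attention=True,  # the Triton kernel is required for ALiBi masks
       use_alibi=True,               # only supported together with triton_flash_attention=True
   )

   # Assumed packed layout (batch, sequence, 3, num_heads, head_size), fp16 on a P4d/P4de GPU,
   # matching the fp16/bf16 and instance-type constraints noted above.
   qkv = torch.randn(2, 1024, 3, 16, 64, dtype=torch.float16, device="cuda")
   out = layer(qkv)  # computes attention probabilities and multiplies by the value layer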

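The scale rules in the revised paragraph can also be read as a small precedence check. The sketch below is an illustrative restatement of those rules only, not the library's implementation; the helper name ``effective_scale`` is hypothetical, while the parameter names and the formula ``1/(sqrt(attention_head_size) * (layer_idx+1))`` come from the documentation above::

   import math

   # Illustrative only: restates the documented scale rules; not library code.
   def effective_scale(scale=None, scale_attention_scores=True, attention_head_size=None,
                       scale_attn_by_layer_idx=False, layer_idx=None):
       if scale is not None:
           # An explicit scale overrides the other four scaling parameters.
           return scale
       factor = 1.0
       if scale_attention_scores:
           factor /= math.sqrt(attention_head_size)  # contributes 1/sqrt(attention_head_size)
       if scale_attn_by_layer_idx:
           factor /= (layer_idx + 1)                 # contributes 1/(layer_idx + 1)
       return factor

   # Both factors enabled: 1 / (sqrt(64) * (3 + 1)) = 0.03125
   print(effective_scale(attention_head_size=64, scale_attn_by_layer_idx=True, layer_idx=3))

   # An explicit scale bypasses everything else.
   print(effective_scale(scale=0.125, attention_head_size=64))  # -> 0.125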