documentation: add smp class for supporting flash attn #4009

Merged 6 commits on Jul 21, 2023
@@ -494,6 +494,102 @@ smdistributed.modelparallel.torch.DistributedOptimizer
``state_dict`` contains elements corresponding to only the current
partition, or to the entire model.

smdistributed.modelparallel.torch.nn.FlashAttentionLayer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)

This class supports
`FlashAttention <https://github.com/HazyResearch/flash-attention>`_
for PyTorch 2.0.
It takes the ``qkv`` matrix as an argument through its ``forward`` method,
computes attention scores and probabilities,
and then performs the matrix multiplication with the value layer.

Through this class, the smp library supports
custom attention masks such as Attention with
Linear Biases (ALiBi), and you can activate them by setting
``triton_flash_attention`` and ``use_alibi`` to ``True``.

Note that the Triton flash attention kernel does not support dropout
on the attention probabilities. It uses a standard lower triangular
causal mask when causal mode is enabled. It also runs only
on P4d and P4de instances, with fp16 or bf16.

This class computes the scale factor to apply when computing attention.
By default, ``scale`` is set to ``None``, and the scale factor is calculated automatically.
When ``scale_attention_scores`` is ``True`` (which is the default), you must pass a value
to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
you must pass a value to ``layer_idx``. If both factors are used, they are
combined as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
You can bypass this scale calculation by specifying a custom scaling
factor through ``scale``. In other words, if you specify a value for ``scale``, the set of parameters
(``scale_attention_scores``, ``attention_head_size``, ``scale_attn_by_layer_idx``, ``layer_idx``)
is ignored.
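
The following is a minimal sketch, not the library's internal code, of how
the default scale factor follows from the arguments described above.

.. code:: python

    import math

    def default_scale(attention_head_size, layer_idx,
                      scale_attention_scores=True, scale_attn_by_layer_idx=False):
        """Illustrative reconstruction of the documented scale calculation."""
        scale = 1.0
        if scale_attention_scores:
            scale /= math.sqrt(attention_head_size)
        if scale_attn_by_layer_idx:
            scale /= (layer_idx + 1)
        return scale

    # For example, with attention_head_size=64 and layer_idx=2 (the third layer),
    # both scalings combined give 1 / (sqrt(64) * 3) = 1 / 24.
    print(default_scale(64, 2, scale_attn_by_layer_idx=True))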

**Parameters**

* ``attention_dropout_prob`` (float): (default: 0.1) specifies the dropout probability
  to apply to the attention probabilities.
* ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
  When ``scale_attention_scores`` is ``True``, this contributes
  ``1/sqrt(attention_head_size)`` to the scale factor.
* ``scale_attention_scores`` (boolean): (default: True) determines whether
  to multiply the scale factor by ``1/sqrt(attention_head_size)``.
* ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
  The layer index to use for scaling attention by layer index.
  It contributes ``1/(layer_idx + 1)`` to the scale factor.
* ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
  to multiply the scale factor by ``1/(layer_idx + 1)``.
* ``scale`` (float): (default: None) If passed, this scale factor is
  applied directly, bypassing all of the preceding scaling arguments.
* ``triton_flash_attention`` (bool): (default: False) If ``True``, the Triton
  implementation of flash attention is used. This is required to support
  Attention with Linear Biases (ALiBi) (see the next argument). Note that this
  version of the kernel does not support dropout.
* ``use_alibi`` (bool): (default: False) If ``True``, enables Attention with
  Linear Biases (ALiBi) using the mask provided.
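
A minimal instantiation sketch follows; the values shown are placeholders.
It scales attention scores by both the head size and the layer index.

.. code:: python

    from smdistributed.modelparallel.torch.nn import FlashAttentionLayer

    # Hypothetical sizes for illustration only.
    flash_attn_layer = FlashAttentionLayer(
        attention_dropout_prob=0.1,
        attention_head_size=64,   # required because scale_attention_scores is True
        scale_attention_scores=True,
        scale_attn_by_layer_idx=True,
        layer_idx=2,              # required because scale_attn_by_layer_idx is True
    )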

.. method:: forward(self, qkv, attn_mask=None, causal=False)

Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
which represents the output of the attention computation.

**Parameters**

* ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
* ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
  By default it is ``None``. Using this mask requires ``triton_flash_attention``
  and ``use_alibi`` to be set to ``True``. See how to generate the mask in the following code snippet.
* ``causal``: When ``True``, the standard lower triangular causal mask is used. The default is ``False``.

When using ALiBi, the layer needs an attention mask prepared as in the following example.

.. code:: python

    import torch

    def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
                                 num_attention_heads, alibi_bias_max=8):
        """Build an ALiBi bias mask broadcastable over (batch, heads, 1, seq_length)."""
        device, dtype = attention_mask.device, attention_mask.dtype
        alibi_attention_mask = torch.zeros(
            1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
        )

        # Linear distance bias: 1 - seq_length, ..., -1, 0 for each key position.
        alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
            1, 1, 1, seq_length
        )
        # Per-head slopes that decay geometrically with the head index.
        m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
        m.mul_(alibi_bias_max / num_attention_heads)
        alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))

        alibi_attention_mask.add_(alibi_bias)
        alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
        if attention_mask is not None and attention_mask.bool().any():
            # Mask out padded positions; assign the result because masked_fill
            # is not an in-place operation.
            alibi_attention_mask = alibi_attention_mask.masked_fill(
                attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
            )

        return alibi_attention_mask
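
The following end-to-end sketch shows one way the pieces above could fit
together. The tensor sizes, device, and dtype are placeholders, and it assumes
the layer is called like a standard ``torch.nn.Module`` so that calling it
invokes ``forward``.

.. code:: python

    import torch
    from smdistributed.modelparallel.torch.nn import FlashAttentionLayer

    batch_size, seq_length, num_heads, head_size = 2, 1024, 16, 64

    # Triton flash attention (required for ALiBi) runs with fp16 or bf16
    # and does not support dropout.
    flash_attn_layer = FlashAttentionLayer(
        attention_dropout_prob=0.0,
        attention_head_size=head_size,
        triton_flash_attention=True,
        use_alibi=True,
    )

    # Packed qkv tensor of shape (batch_size x seqlen x 3 x num_heads x head_size).
    qkv = torch.randn(
        batch_size, seq_length, 3, num_heads, head_size,
        dtype=torch.float16, device="cuda",
    )

    # Padding mask where nonzero entries mark positions to mask out;
    # all zeros here means no padding.
    padding_mask = torch.zeros(
        batch_size, seq_length, dtype=torch.float16, device="cuda"
    )
    alibi_mask = generate_alibi_attn_mask(
        padding_mask, batch_size, seq_length, num_heads
    )

    # Output has shape (batch_size x num_heads x seq_len x head_size).
    attn_output = flash_attn_layer(qkv, attn_mask=alibi_mask, causal=True)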

smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^