Commit 0850b9b

Merge branch 'master' into chore/jumpstart-deprecation-messages

2 parents 18fffc8 + fdc0ac1

File tree: 3 files changed, +97 -6 lines changed

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 96 additions & 0 deletions
@@ -494,6 +494,102 @@ smdistributed.modelparallel.torch.DistributedOptimizer
``state_dict`` contains elements corresponding to only the current
partition, or to the entire model.

smdistributed.modelparallel.torch.nn.FlashAttentionLayer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)

This class supports
`FlashAttention <https://github.com/HazyResearch/flash-attention>`_
for PyTorch 2.0.
It takes the ``qkv`` matrix as an argument through its ``forward`` class method,
computes attention scores and probabilities,
and then performs the matrix multiplication with the value layers.

Through this class, the smp library supports
custom attention masks such as Attention with
Linear Biases (ALiBi), which you can activate by setting
``triton_flash_attention`` and ``use_alibi`` to ``True``.

Note that the Triton flash attention kernel does not support dropout
on the attention probabilities. It uses a standard lower triangular
causal mask when causal mode is enabled. It also runs only
on P4d and P4de instances, with fp16 or bf16.

This class computes the scale factor to apply when computing attention.
By default, ``scale`` is set to ``None``, and the scale is calculated automatically.
When ``scale_attention_scores`` is ``True`` (which is the default), you must pass a value
to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
you must pass a value to ``layer_idx``. If both factors are used, they are
multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
This scale calculation can be bypassed by specifying a custom scaling
factor through ``scale``. In other words, if you specify a value for ``scale``, the set of parameters
(``scale_attention_scores``, ``attention_head_size``, ``scale_attn_by_layer_idx``, ``layer_idx``)
is overridden and ignored.
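
As an illustration only (not part of the library), the following minimal sketch
reproduces the scale calculation described above in plain Python. The function name
``compute_scale`` is hypothetical; the argument names mirror the constructor parameters.

.. code:: python

   import math

   def compute_scale(attention_head_size=None, scale_attention_scores=True,
                     scale_attn_by_layer_idx=False, layer_idx=None, scale=None):
       # An explicit ``scale`` overrides and ignores the other scaling arguments.
       if scale is not None:
           return scale
       factor = 1.0
       if scale_attention_scores:
           # Contributes 1/sqrt(attention_head_size) to the scale factor.
           factor /= math.sqrt(attention_head_size)
       if scale_attn_by_layer_idx:
           # Contributes 1/(layer_idx + 1) to the scale factor.
           factor /= (layer_idx + 1)
       return factor

   # Both factors enabled: 1 / (sqrt(64) * (2 + 1)) ≈ 0.042
   compute_scale(attention_head_size=64, scale_attn_by_layer_idx=True, layer_idx=2)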

**Parameters**

* ``attention_dropout_prob`` (float): (default: 0.1) The dropout probability
  to apply to the attention probabilities.
* ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
  When ``scale_attention_scores`` is enabled, this contributes
  ``1/sqrt(attention_head_size)`` to the scale factor.
* ``scale_attention_scores`` (boolean): (default: True) Determines whether
  to multiply the scale factor by ``1/sqrt(attention_head_size)``.
* ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
  The layer index to use for scaling attention by layer.
  It contributes ``1/(layer_idx + 1)`` to the scale factor.
* ``scale_attn_by_layer_idx`` (boolean): (default: False) Determines whether
  to multiply the scale factor by ``1/(layer_idx + 1)``.
* ``scale`` (float): (default: None) If passed, this scale factor is
  applied, bypassing all of the preceding scaling arguments.
* ``triton_flash_attention`` (bool): (default: False) If ``True``, the Triton
  implementation of flash attention is used. This is necessary to support
  Attention with Linear Biases (ALiBi) (see the next argument). Note that this
  version of the kernel doesn't support dropout.
* ``use_alibi`` (bool): (default: False) If ``True``, enables Attention with
  Linear Biases (ALiBi) using the provided attention mask.

.. method:: forward(self, qkv, attn_mask=None, causal=False)

Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
which represents the output of the attention computation.

**Parameters**

* ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
* ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
  By default it is ``None``, and using this mask requires ``triton_flash_attention``
  and ``use_alibi`` to be set to ``True``. See how to generate the mask in the following code snippet.
* ``causal``: When ``True``, the standard lower triangular causal mask is used. The default is ``False``.

When using ALiBi, prepare an attention mask as in the following example.

.. code:: python

   import torch

   def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
                                num_attention_heads, alibi_bias_max=8):

       device, dtype = attention_mask.device, attention_mask.dtype
       alibi_attention_mask = torch.zeros(
           1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
       )

       # Relative position term: (1 - seq_length), ..., -1, 0 for each key position.
       alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
           1, 1, 1, seq_length
       )
       # Per-head slopes: geometrically decreasing powers of two.
       m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
       m.mul_(alibi_bias_max / num_attention_heads)
       alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))

       alibi_attention_mask.add_(alibi_bias)
       alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
       # Mask out the positions flagged by ``attention_mask`` (in place).
       if attention_mask is not None and attention_mask.bool().any():
           alibi_attention_mask.masked_fill_(
               attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
           )

       return alibi_attention_mask
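
For orientation, here is a minimal usage sketch that ties the pieces together. It is an
assumption-based example rather than part of the documentation: the ``smp`` import alias,
the tensor sizes, the padding-mask convention, and calling the layer instance directly
(as a ``torch.nn.Module``) are illustrative choices; only the constructor and ``forward``
signatures come from the descriptions above.

.. code:: python

   import torch
   import smdistributed.modelparallel.torch as smp

   batch_size, seq_length, num_heads, head_size = 2, 1024, 16, 64

   # Fused query/key/value tensor: (batch_size x seqlen x 3 x num_heads x head_size).
   qkv = torch.randn(
       batch_size, seq_length, 3, num_heads, head_size,
       dtype=torch.bfloat16, device="cuda",
   )
   # Nonzero entries mark positions to mask out (assumed padding-mask convention).
   padding_mask = torch.zeros(
       batch_size, 1, 1, seq_length, dtype=torch.bfloat16, device="cuda"
   )

   alibi_mask = generate_alibi_attn_mask(
       padding_mask, batch_size, seq_length, num_heads
   )

   flash_attn = smp.nn.FlashAttentionLayer(
       attention_dropout_prob=0.0,      # the Triton kernel does not support dropout
       attention_head_size=head_size,
       triton_flash_attention=True,     # required for ALiBi
       use_alibi=True,
   )
   # Output shape: (batch_size x num_heads x seq_len x head_size).
   out = flash_attn(qkv, attn_mask=alibi_mask, causal=False)

As noted above, this path runs only on P4d and P4de instances with fp16 or bf16 tensors.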

smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 doc8==0.10.1
-Pygments==2.11.2
+Pygments==2.15.0

tests/scripts/run-notebook-test.sh

Lines changed: 0 additions & 5 deletions
@@ -126,12 +126,7 @@ echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN"
 ./amazon-sagemaker-examples/advanced_functionality/kmeans_bring_your_own_model/kmeans_bring_your_own_model.ipynb \
 ./amazon-sagemaker-examples/advanced_functionality/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb \
 ./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_highlevel/kmeans_mnist.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_lowlevel/kmeans_mnist_lowlevel.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/mxnet_gluon_sentiment/mxnet_sentiment_analysis_with_gluon.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/mxnet_onnx_export/mxnet_onnx_export.ipynb \
 ./amazon-sagemaker-examples/sagemaker-python-sdk/scikit_learn_randomforest/Sklearn_on_SageMaker_end2end.ipynb \
 ./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_moving_from_framework_mode_to_script_mode/tensorflow_moving_from_framework_mode_to_script_mode.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_script_mode_pipe_mode/tensorflow_script_mode_pipe_mode.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_serving_using_elastic_inference_with_your_own_model/tensorflow_serving_pretrained_model_elastic_inference.ipynb \

 (DeleteLifeCycleConfig "$LIFECYCLE_CONFIG_NAME")
