@@ -494,6 +494,102 @@ smdistributed.modelparallel.torch.DistributedOptimizer
``state_dict`` contains elements corresponding to only the current
partition, or to the entire model.

+ smdistributed.modelparallel.torch.nn.FlashAttentionLayer
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ .. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)
+
+ This class supports
+ `FlashAttention <https://github.com/HazyResearch/flash-attention>`_
+ for PyTorch 2.0.
+ It takes the ``qkv`` matrix as an argument through its ``forward`` class method,
+ computes attention scores and probabilities,
+ and then performs the matrix multiplication with the value layer.
+
+ Through this class, the smp library supports
+ custom attention masks such as Attention with
+ Linear Biases (ALiBi), and you can activate them by setting
+ ``triton_flash_attention`` and ``use_alibi`` to ``True``.
+
+ Note that the Triton flash attention does not support dropout
+ on the attention probabilities. It uses the standard lower triangular
+ causal mask when causal mode is enabled. It also runs only
+ on P4d and P4de instances, with fp16 or bf16.
+
+ This class computes the scale factor to apply when computing attention.
+ By default, ``scale`` is set to ``None``, and it's automatically calculated.
+ When ``scale_attention_scores`` is ``True`` (which is the default), you must pass a value
+ to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
+ you must pass a value to ``layer_idx``. If both factors are used, they are
+ multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
+ This scale calculation can be bypassed if you specify a custom scaling
+ factor to ``scale``. In other words, if you specify a value for ``scale``, the set of parameters
+ (``scale_attention_scores``, ``attention_head_size``, ``scale_attn_by_layer_idx``, ``layer_idx``)
+ is ignored.
+
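+ For reference, here is a minimal sketch of how the effective scale factor
+ would be computed under these rules. The function ``effective_scale`` is
+ illustrative only and is not part of the library.
+
+ .. code:: python
+
+     import math
+
+     def effective_scale(scale=None, scale_attention_scores=True,
+                         attention_head_size=None,
+                         scale_attn_by_layer_idx=False, layer_idx=None):
+         # A custom scale overrides all of the other scaling arguments.
+         if scale is not None:
+             return scale
+         factor = 1.0
+         if scale_attention_scores:
+             # Contributes 1/sqrt(attention_head_size) to the scale factor.
+             factor /= math.sqrt(attention_head_size)
+         if scale_attn_by_layer_idx:
+             # Contributes 1/(layer_idx + 1) to the scale factor.
+             factor /= layer_idx + 1
+         return factor
+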
530
+ **Parameters **
531
+
532
+ * ``attention_dropout_prob `` (float): (default: 0.1) specifies dropout probability
533
+ to apply to attention.
534
+ * ``attention_head_size `` (int): Required when ``scale_attention_scores `` is True.
535
+ When ``scale_attention_scores `` is passed, this contributes
536
+ ``1/sqrt(attention_head_size) `` to the scale factor.
537
+ * ``scale_attention_scores `` (boolean): (default: True) determines whether
538
+ to multiply 1/sqrt(attention_head_size) to the scale factor.
539
+ * ``layer_idx `` (int): Required when ``scale_attn_by_layer_idx `` is ``True ``.
540
+ The layer id to use for scaling attention by layer id.
541
+ It contributes 1/(layer_idx + 1) to the scaling factor.
542
+ * ``scale_attn_by_layer_idx `` (boolean): (default: False) determines whether
543
+ to multiply 1/(layer_idx + 1) to the scale factor.
544
+ * ``scale `` (float) (default: None): If passed, this scale factor will be
545
+ applied bypassing the all of the previous arguments.
546
+ * ``triton_flash_attention `` (bool): (default: False) If passed, Triton
547
+ implementation of flash attention will be used. This is necessary to supports
548
+ Attention with Linear Biases (ALiBi) (see next arg). Note that this version
549
+ of the kernel doesn’t support dropout.
550
+ * ``use_alibi `` (bool): (default: False) If passed, it enables Attention with
551
+ Linear Biases (ALiBi) using the mask provided.
552
+
553
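+ As an illustration only, the layer could be constructed as shown below. The
+ argument values are assumptions rather than recommendations, and the ``smp``
+ alias follows the library's usual import convention.
+
+ .. code:: python
+
+     import smdistributed.modelparallel.torch as smp
+
+     # Default scaling: attention scores are multiplied by
+     # 1/sqrt(attention_head_size), so attention_head_size must be provided.
+     flash_attention_layer = smp.nn.FlashAttentionLayer(
+         attention_dropout_prob=0.1,
+         attention_head_size=64,
+         scale_attention_scores=True,
+     )
+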
+ .. method:: forward(self, qkv, attn_mask=None, causal=False)
+
+ Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
+ which represents the output of the attention computation.
+
+ **Parameters**
+
+ * ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
+ * ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
+   By default it is ``None``, and using this mask requires ``triton_flash_attention``
+   and ``use_alibi`` to be set to ``True``. See how to generate the mask in the following code snippet.
+ * ``causal``: When set to ``True``, the standard lower triangular causal mask is used. The default is ``False``.
+
+ When using ALiBi, the layer needs an attention mask prepared as in the following example.
+
+ .. code:: python
+
+     import torch
+
+     def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
+                                  num_attention_heads, alibi_bias_max=8):
+
+         device, dtype = attention_mask.device, attention_mask.dtype
+         alibi_attention_mask = torch.zeros(
+             1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
+         )
+
+         alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
+             1, 1, 1, seq_length
+         )
+         m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
+         m.mul_(alibi_bias_max / num_attention_heads)
+         alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))
+
+         alibi_attention_mask.add_(alibi_bias)
+         alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
+         if attention_mask is not None and attention_mask.bool().any():
+             # Use the in-place masked_fill_ so padded positions actually receive -inf.
+             alibi_attention_mask.masked_fill_(
+                 attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
+             )
+
+         return alibi_attention_mask
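+
+ The following sketch shows one way the generated mask might be passed to
+ ``forward``, reusing the ``generate_alibi_attn_mask`` function above. The tensor
+ shapes follow the parameter descriptions in this section; the variable names,
+ sizes, and device placement are illustrative assumptions.
+
+ .. code:: python
+
+     import torch
+     import smdistributed.modelparallel.torch as smp
+
+     batch_size, seq_length, num_heads, head_size = 2, 1024, 16, 64
+
+     # Triton flash attention is required for ALiBi; its kernel does not support dropout.
+     layer = smp.nn.FlashAttentionLayer(
+         attention_dropout_prob=0.0,
+         attention_head_size=head_size,
+         triton_flash_attention=True,
+         use_alibi=True,
+     )
+
+     # qkv packed as (batch_size x seqlen x 3 x num_heads x head_size), in fp16.
+     qkv = torch.randn(
+         batch_size, seq_length, 3, num_heads, head_size,
+         dtype=torch.float16, device="cuda",
+     )
+
+     # Padding mask of shape (batch_size x 1 x 1 x seqlen); nonzero entries mark padded tokens.
+     padding_mask = torch.zeros(
+         batch_size, 1, 1, seq_length, dtype=torch.float16, device="cuda"
+     )
+     alibi_mask = generate_alibi_attn_mask(
+         padding_mask, batch_size, seq_length, num_heads
+     )
+
+     # Output shape: (batch_size x num_heads x seq_len x head_size).
+     out = layer(qkv, attn_mask=alibi_mask, causal=True)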

smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^