@@ -494,6 +494,110 @@ smdistributed.modelparallel.torch.DistributedOptimizer
``state_dict`` contains elements corresponding to only the current
partition, or to the entire model.

+ smdistributed.modelparallel.torch.nn.FlashAttentionLayer
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ .. function:: class smdistributed.modelparallel.torch.nn.FlashAttentionLayer(
+       attention_dropout_prob=0.1,
+       attention_head_size=None,
+       scale_attention_scores=True,
+       scale_attn_by_layer_idx=False,
+       layer_idx=None,
+       scale=None,
+       triton_flash_attention=False,
+       use_alibi=False
+    )
+
+ This FlashAttentionLayer class supports
+ `FlashAttention <https://github.com/HazyResearch/flash-attention>`_.
+ It takes the ``qkv`` matrix as an argument, computes attention scores and probabilities,
+ and then performs the matrix multiplication with the value layer.
+
+ Note that custom attention masks, such as Attention with
+ Linear Biases (ALiBi), are supported only when
+ ``triton_flash_attention`` and ``use_alibi`` are set to ``True``.
+
+ Note also that Triton flash attention does not support dropout
+ on the attention probabilities. It uses the standard lower triangular
+ causal mask when causal mode is enabled. It also runs only
+ on P4d and P4de instances, with fp16 or bf16.
+
+ This class computes the scale factor to apply when computing attention scores.
+ By default, ``scale`` is ``None``, and the scale factor is calculated automatically.
+ When ``scale_attention_scores`` is ``True`` (the default),
+ ``attention_head_size`` must be passed. When ``scale_attn_by_layer_idx`` is ``True``,
+ ``layer_idx`` must be passed. If both factors are used, they are
+ multiplied together: ``1/(sqrt(attention_head_size) * (layer_idx+1))``.
+ This scale calculation can be bypassed by passing a custom scaling
+ factor through the ``scale`` parameter.
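+
+ The following sketch is illustrative only and not part of the library;
+ ``effective_scale`` is a hypothetical helper that mirrors the scale selection
+ described above.
+
+ .. code:: python
+
+     import math
+
+     # Hypothetical helper: reproduces the scale factor rules described above.
+     def effective_scale(attention_head_size=None, layer_idx=None, scale=None,
+                         scale_attention_scores=True, scale_attn_by_layer_idx=False):
+         if scale is not None:
+             # A user-provided scale bypasses the automatic calculation.
+             return scale
+         factor = 1.0
+         if scale_attention_scores:
+             factor /= math.sqrt(attention_head_size)
+         if scale_attn_by_layer_idx:
+             factor /= (layer_idx + 1)
+         return factor
+
+     # Example: attention_head_size=64 and layer_idx=3 gives
+     # 1 / (sqrt(64) * (3 + 1)) = 1 / 32 = 0.03125.
+     print(effective_scale(attention_head_size=64,
+                           scale_attn_by_layer_idx=True, layer_idx=3))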
+
+ **Parameters**
+
+ * ``attention_dropout_prob`` (float): (default: 0.1) specifies the dropout probability
+   to apply to attention.
+ * ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
+   When ``scale_attention_scores`` is passed, this contributes
+   ``1/sqrt(attention_head_size)`` to the scale factor.
+ * ``scale_attention_scores`` (boolean): (default: True) determines whether
+   to multiply 1/sqrt(attention_head_size) into the scale factor.
+ * ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
+   The layer ID to use for scaling attention by layer ID.
+   It contributes 1/(layer_idx + 1) to the scale factor.
+ * ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
+   to multiply 1/(layer_idx + 1) into the scale factor.
+ * ``scale`` (float): (default: None) If passed, this scale factor is
+   applied, bypassing the above arguments.
+ * ``triton_flash_attention`` (bool): (default: False) If passed, the Triton
+   implementation of flash attention is used. This is necessary to support
+   Attention with Linear Biases (ALiBi) (see the next argument). Note that this
+   version of the kernel doesn't support dropout.
+ * ``use_alibi`` (bool): (default: False) If passed, it enables Attention with
+   Linear Biases (ALiBi) using the mask provided.
+
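+ As a minimal construction sketch (assuming the standard
+ ``import smdistributed.modelparallel.torch as smp`` alias), the arguments that
+ must be passed together look like this:
+
+ .. code:: python
+
+     import smdistributed.modelparallel.torch as smp
+
+     attn = smp.nn.FlashAttentionLayer(
+         attention_dropout_prob=0.1,
+         attention_head_size=64,        # required because scale_attention_scores defaults to True
+         scale_attn_by_layer_idx=True,
+         layer_idx=3,                   # required because scale_attn_by_layer_idx is True
+     )
+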
+ .. method:: forward(self, qkv, attn_mask=None, causal=False)
+
+ Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
+ which represents the output of the attention computation.
+
+ **Parameters**
+
+ * ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
+ * ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
+   By default it is ``None``, and usage of this mask requires ``triton_flash_attention``
+   and ``use_alibi`` to be set. See how to generate the mask in the code snippet that follows.
+ * ``causal``: When passed, it uses the standard lower triangular mask. The default is ``False``.
+
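+ The following is a minimal sketch of a causal (non-ALiBi) forward call; the
+ tensor sizes are illustrative and the ``smp`` import alias is an assumption:
+
+ .. code:: python
+
+     import torch
+     import smdistributed.modelparallel.torch as smp
+
+     batch_size, seq_len, num_heads, head_size = 2, 1024, 16, 64
+
+     attn = smp.nn.FlashAttentionLayer(attention_head_size=head_size)
+
+     # qkv is packed as (batch_size, seqlen, 3, num_heads, head_size).
+     qkv = torch.randn(
+         batch_size, seq_len, 3, num_heads, head_size,
+         dtype=torch.float16, device="cuda",
+     )
+
+     # Output shape: (batch_size, num_heads, seq_len, head_size).
+     out = attn(qkv, causal=True)
+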
+ When using ALiBi, the layer needs an attention mask prepared as in the
+ following example.
+
+ .. code:: python
+
+     import torch
+
+     def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
+                                  num_attention_heads, alibi_bias_max=8):
+
+         device, dtype = attention_mask.device, attention_mask.dtype
+         alibi_attention_mask = torch.zeros(
+             1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
+         )
+
+         # Distance-based bias (more negative farther back), scaled per head.
+         alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
+             1, 1, 1, seq_length
+         )
+         m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
+         m.mul_(alibi_bias_max / num_attention_heads)
+         alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))
+
+         alibi_attention_mask.add_(alibi_bias)
+         alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
+         if attention_mask is not None and attention_mask.bool().any():
+             # In-place fill so that padded positions are actually masked out.
+             alibi_attention_mask.masked_fill_(
+                 attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
+             )
+
+         return alibi_attention_mask
+
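+ As an illustrative sketch (tensor sizes, the ``smp`` alias, and the combination
+ with ``causal`` are assumptions), the mask produced by the
+ ``generate_alibi_attn_mask`` helper above can then be passed to the layer:
+
+ .. code:: python
+
+     import torch
+     import smdistributed.modelparallel.torch as smp
+
+     batch_size, seq_length, num_heads, head_size = 2, 1024, 16, 64
+
+     # ALiBi masks require the Triton kernel and use_alibi=True.
+     attn = smp.nn.FlashAttentionLayer(
+         attention_head_size=head_size,
+         triton_flash_attention=True,
+         use_alibi=True,
+     )
+
+     # Nonzero entries mark padded positions; all zeros here means no padding.
+     padding_mask = torch.zeros(batch_size, seq_length, dtype=torch.float16, device="cuda")
+     alibi_mask = generate_alibi_attn_mask(
+         padding_mask, batch_size, seq_length, num_heads
+     )
+
+     qkv = torch.randn(
+         batch_size, seq_length, 3, num_heads, head_size,
+         dtype=torch.float16, device="cuda",
+     )
+     out = attn(qkv, attn_mask=alibi_mask, causal=True)
+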
smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^