
Commit be8fa1c

fix class
1 parent be4fdfb commit be8fa1c

File tree: 1 file changed (+1, -9 lines)


doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 1 addition & 9 deletions
@@ -497,15 +497,7 @@ smdistributed.modelparallel.torch.DistributedOptimizer
 497  497   smdistributed.modelparallel.torch.nn.FlashAttentionLayer
 498  498   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 499  499
 500       - .. function:: class smdistributed.modelparallel.torch.nn.FlashAttentionLayer(
 501       -      attention_dropout_prob=0.1,
 502       -      attention_head_size=None,
 503       -      scale_attention_scores=True,
 504       -      scale_attn_by_layer_idx=False,
 505       -      layer_idx=None,
 506       -      scale=None,
 507       -      triton_flash_attention=False,
 508       -      use_alibi=False)
      500  + .. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)
 509  501
 510  502   This FlashAttentionLayer class supports
 511  503   `FlashAttention <https://github.com/HazyResearch/flash-attention>`_.
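For context on what the documented parameters control, here is a minimal plain-NumPy sketch of scaled dot-product attention that mirrors a few of the signature's keyword arguments (`attention_head_size`, `scale_attention_scores`, `scale_attn_by_layer_idx`, `layer_idx`, `scale`). This is an illustrative assumption of the usual semantics of such parameters, not the smdistributed implementation, which runs fused flash-attention kernels on GPU; the function name and the exact meaning of `scale` here are hypothetical.

```python
# Illustrative sketch only: naive NumPy attention approximating the semantics
# suggested by the documented FlashAttentionLayer keyword arguments.
# NOT the smdistributed.modelparallel implementation.
import numpy as np

def naive_attention(q, k, v, attention_head_size=None,
                    scale_attention_scores=True,
                    scale_attn_by_layer_idx=False, layer_idx=None,
                    scale=None):
    """q, k, v: arrays of shape (seq_len, head_size)."""
    head_size = attention_head_size or q.shape[-1]
    scores = q @ k.T
    if scale is not None:
        # Assumed semantics: an explicit scale overrides the default scaling.
        scores = scores * scale
    elif scale_attention_scores:
        # Standard 1/sqrt(d) attention-score scaling.
        scores = scores / np.sqrt(head_size)
    if scale_attn_by_layer_idx and layer_idx is not None:
        # Assumed semantics: down-scale scores by (layer index + 1).
        scores = scores / (layer_idx + 1)
    # Row-wise softmax over key positions.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

# Usage: output has one row of head_size features per query position.
rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((4, 8))
v = rng.standard_normal((4, 8))
out = naive_attention(q, k, v)
print(out.shape)  # (4, 8)
```

The `attention_dropout_prob`, `triton_flash_attention`, and `use_alibi` arguments are omitted from the sketch because they affect training-time dropout, kernel selection, and positional biasing rather than the core score computation shown here.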
