examples/models/llama2/source_transformation/torchtune/modules
1 file changed, +3 −4 lines changed

@@ -70,9 +70,8 @@ class MultiHeadAttention(nn.Module):
     max_seq_len (int): maximum sequence length supported by the model.
         This is needed to compute the RoPE Cache. Default: 4096.
     is_causal (bool): sets the default mask to causal when no mask is provided
-    attn_dropout (float): dropout value passed onto the
-        scaled_dot_product_attention function. This argument is ignored if the
-        self.training is False. Default value is 0.0.
+    attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
+        This argument is ignored if self.training is False. Default value is 0.0.

 Raises:
     ValueError: If ``num_heads % num_kv_heads != 0``
@@ -147,7 +146,7 @@ def __init__(
         num_heads=self.num_heads,
         head_dim=self.head_dim,
         q_per_kv=self.q_per_kv,
-        attn_dropout=self.attn_dropout,
+        attn_dropout=self.attn_dropout if self.training else 0.0,
         is_causal=self.is_causal,
         attention_fn=self._attention_call,
         kv_cache=self.kv_cache,
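
For reference, a minimal runnable sketch of the behavior both hunks describe: the dropout value passed to scaled_dot_product_attention should only take effect while the module is in training mode. The TinyAttention wrapper and tensor shapes below are hypothetical illustrations, not part of the patched module. Note one difference: the sketch checks self.training on every forward call, whereas the patch above reads it once inside __init__, fixing the value at construction time, which suits an eval-only export path.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttention(nn.Module):
    """Hypothetical single-head attention, used only to illustrate
    the attn_dropout handling shown in the diff above."""

    def __init__(self, attn_dropout: float = 0.0):
        super().__init__()
        self.attn_dropout = attn_dropout

    def forward(self, q, k, v):
        # Same conditional expression as the patch, evaluated here
        # per forward call: dropout applies only in training mode.
        return F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_dropout if self.training else 0.0,
        )

q = k = v = torch.randn(1, 4, 8)  # (batch, seq_len, head_dim)
attn = TinyAttention(attn_dropout=0.1)

attn.train()
out_train = attn(q, k, v)  # dropout_p=0.1 is in effect

attn.eval()
out_eval = attn(q, k, v)   # dropout_p=0.0; output is deterministic
assert torch.allclose(attn(q, k, v), out_eval)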