Commit 863ff13

Recent TT updates
1 parent 8796a5f commit 863ff13

1 file changed: +3 -4 lines changed
examples/models/llama2/source_transformation/torchtune/modules/mha.py

Lines changed: 3 additions & 4 deletions
@@ -70,9 +70,8 @@ class MultiHeadAttention(nn.Module):
         max_seq_len (int): maximum sequence length supported by the model.
             This is needed to compute the RoPE Cache. Default: 4096.
         is_causal (bool): sets the default mask to causal when no mask is provided
-        attn_dropout (float): dropout value passed onto the
-            scaled_dot_product_attention function. This argument is ignored if the
-            self.training is False. Default value is 0.0.
+        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
+            This argument is ignored if self.training is False. Default value is 0.0.

     Raises:
         ValueError: If ``num_heads % num_kv_heads != 0``

@@ -147,7 +146,7 @@ def __init__(
             num_heads=self.num_heads,
             head_dim=self.head_dim,
             q_per_kv=self.q_per_kv,
-            attn_dropout=self.attn_dropout,
+            attn_dropout=self.attn_dropout if self.training else 0.0,
             is_causal=self.is_causal,
             attention_fn=self._attention_call,
             kv_cache=self.kv_cache,
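
The second hunk gates attn_dropout on self.training before it is handed to the attention call, matching the updated docstring. A minimal sketch of the behavior this relies on, assuming torch.nn.functional.scaled_dot_product_attention as the backend (the function the docstring names); the helper name and tensor shapes below are illustrative, not the module's actual code:

# Illustrative sketch, not the file's real code: the functional
# scaled_dot_product_attention has no notion of train/eval mode and applies
# dropout whenever dropout_p > 0, so the caller must zero it at inference.
import torch
import torch.nn.functional as F

def sdpa_with_gated_dropout(q, k, v, attn_dropout: float, training: bool):
    # Mirrors the diff: only pass a nonzero dropout probability while training.
    dropout_p = attn_dropout if training else 0.0
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

# Hypothetical shapes: (batch, num_heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 16, 64)
out = sdpa_with_gated_dropout(q, k, v, attn_dropout=0.1, training=False)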
