Commit 474c9cf

Merge pull request #2236 from NightMachinery/patch-1
eva.py: fixed bug in applying attention mask
2 parents: 7160af4 + 4cca568

File tree: 1 file changed (+3 −1 lines)
timm/models/eva.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -134,10 +134,12 @@ def forward(
         else:
             q = q * self.scale
             attn = (q @ k.transpose(-2, -1))
-            attn = attn.softmax(dim=-1)
+
             if attn_mask is not None:
                 attn_mask = attn_mask.to(torch.bool)
                 attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = attn.softmax(dim=-1)
+
             attn = self.attn_drop(attn)
             x = attn @ v
```

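The fix moves the softmax so it runs after the attention mask is applied. A minimal standalone sketch of why the order matters, using hypothetical 1×3 logits rather than the real eva.py tensors:

```python
# Standalone sketch of the ordering bug fixed in this commit (hypothetical
# 1x3 logits, not the actual eva.py module). Masked positions must be set
# to -inf *before* softmax so they normalize to exactly zero probability.
import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])    # raw attention logits
mask = torch.tensor([[True, True, False]])  # False marks a padded key

# Pre-fix order: softmax first, then mask -> -inf survives into the weights
buggy = scores.softmax(dim=-1).masked_fill(~mask, float("-inf"))

# Post-fix order: mask first, then softmax -> the masked key gets 0 weight
fixed = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1)

print(buggy)  # third weight is -inf, which would poison attn @ v
print(fixed)  # third weight is exactly 0.0 and the row still sums to 1
```

With the pre-fix order, the `-inf` written into already-normalized weights propagates inf/NaN values through `attn @ v`; with the post-fix order, masked keys simply receive zero attention weight and contribute nothing.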