
Commit d651dcd

Add mask_val option to StaticAttentionMask
Differential Revision: D70255619
Pull Request resolved: #8736
1 parent: 1eb2f94

File tree

1 file changed (+3, -2)


examples/models/llama/static_attention.py

Lines changed: 3 additions & 2 deletions
@@ -125,18 +125,19 @@ def update(
 
 
 class StaticAttentionMask:
-    def __init__(self, input_len, cache_len, style):
+    def __init__(self, input_len, cache_len, style, mask_val=float("-inf")):
         self.input_len = input_len
         self.cache_len = cache_len
         assert style in ("shift_pointer", "smart_mask")
         self.style = style
+        self.mask_val = mask_val
         self.unmasked_len = 0
         self.tensor = torch.zeros(1, input_len, input_len + cache_len)
         self.reset()
 
     def reset(self):
         self.unmasked_len = 0
-        self.tensor[:, :, : self.cache_len] = float("-inf")
+        self.tensor[:, :, : self.cache_len] = self.mask_val
 
     def unmask(self, new_unmasked_len):
         if new_unmasked_len <= 0:

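After this change the constructor signature is StaticAttentionMask(input_len, cache_len, style, mask_val=float("-inf")), and reset() fills the cached positions with self.mask_val instead of a hard-coded -inf. The sketch below is a hypothetical usage example, not part of the commit: the import path (inferred from the file location), the sizes, and the choice of -1e9 as a finite fill value are assumptions for illustration.

# Hypothetical usage sketch; the import path is assumed from the file location
# examples/models/llama/static_attention.py and may differ in practice.
from executorch.examples.models.llama.static_attention import StaticAttentionMask

# Illustrative sizes; real values depend on the model and cache configuration.
input_len = 32
cache_len = 96

# Override the default mask_val=float("-inf") with a large finite negative
# value (e.g. when a downstream backend is assumed not to handle -inf well).
mask = StaticAttentionMask(input_len, cache_len, "smart_mask", mask_val=-1e9)

# __init__ calls reset(), so the cached region of the mask tensor is filled
# with mask_val while the input region stays at 0.
assert (mask.tensor[:, :, :cache_len] == -1e9).all()
assert (mask.tensor[:, :, cache_len:] == 0).all()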