 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionMask,
     StaticKVCache,
 )
 
@@ -92,48 +93,54 @@ def test_with_cache(self):
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
-        mask = torch.zeros(1, chunk_len, cache_len + chunk_len)
-        mask[:, :, :cache_len] = float("-inf")
-        mask[:, :, cache_len:] = torch.triu(
-            torch.full((1, chunk_len, chunk_len), float("-inf")),
-            diagonal=1,
-        )
-        k_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for i in range(config.n_kv_heads)
-        }
-        v_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for i in range(config.n_kv_heads)
-        }
-        ys = []
-        for i in range(n_chunks):
-            y_i, attn_update = static_attn(
-                x[:, i * chunk_len : (i + 1) * chunk_len, :],
-                freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                freqs_sin[i * chunk_len : (i + 1) * chunk_len],
-                mask=mask,
-                in_cache_state=(k_caches, v_caches),
-                out_cache_state=({}, {}),
-            )
-            ys.append(y_i)
-            mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0
-            k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-            for cache_id, update in k_cache_updates.items():
-                k_caches[cache_id] = StaticKVCache.apply_update(
-                    k_caches[cache_id], update
-                )
-            for cache_id, update in v_cache_updates.items():
-                v_caches[cache_id] = StaticKVCache.apply_update(
-                    v_caches[cache_id], update
-                )
-
-        y = torch.cat(ys, dim=1)
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+
+        def test_with_style(style):
+            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
+            mask.tensor[:, :, cache_len:] = torch.triu(
+                torch.full((1, chunk_len, chunk_len), float("-inf")),
+                diagonal=1,
+            )
+            k_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for i in range(config.n_kv_heads)
+            }
+            v_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for i in range(config.n_kv_heads)
+            }
+            ys = []
+            for i in range(n_chunks):
+                y_i, attn_update = static_attn(
+                    x[:, i * chunk_len : (i + 1) * chunk_len, :],
+                    freqs_cos[i * chunk_len : (i + 1) * chunk_len],
+                    freqs_sin[i * chunk_len : (i + 1) * chunk_len],
+                    mask=mask.tensor,
+                    in_cache_state=(k_caches, v_caches),
+                    out_cache_state=({}, {}),
+                )
+                ys.append(y_i)
+                mask.unmask(chunk_len)
+                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
+
+                if i < n_chunks - 1:
+                    for cache_id, update in k_cache_updates.items():
+                        k_caches[cache_id] = StaticKVCache.apply_update(
+                            k_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+                    for cache_id, update in v_cache_updates.items():
+                        v_caches[cache_id] = StaticKVCache.apply_update(
+                            v_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+
+            y = torch.cat(ys, dim=1)
+            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+
+        test_with_style("shift_pointer")
+        test_with_style("smart_mask")
 
     def test_within_transformer(self):
         config = ModelArgs(
@@ -162,48 +169,57 @@ def test_within_transformer(self):
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
-        mask = torch.zeros(1, chunk_len, cache_len + chunk_len)
-        mask[:, :, :cache_len] = float("-inf")
-        mask[:, :, cache_len:] = torch.triu(
-            torch.full((1, chunk_len, chunk_len), float("-inf")),
-            diagonal=1,
-        )
-        k_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for layer_id in range(config.n_layers)
-            for i in range(config.n_kv_heads)
-        }
-        v_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for layer_id in range(config.n_layers)
-            for i in range(config.n_kv_heads)
-        }
-        ys = []
-        for i in range(n_chunks):
-            y_i, attn_update = static_transformer(
-                x[:, i * chunk_len : (i + 1) * chunk_len],
-                attn_options=ForwardOptions(
-                    mask=mask,
-                    freqs_cos_override=freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                    freqs_sin_override=freqs_sin[i * chunk_len : (i + 1) * chunk_len],
-                    in_cache_state=(k_caches, v_caches),
-                    out_cache_state=({}, {}),
-                ),
-            )
-            ys.append(y_i)
-            mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0
-            k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-            for cache_id, update in k_cache_updates.items():
-                k_caches[cache_id] = StaticKVCache.apply_update(
-                    k_caches[cache_id], update
-                )
-            for cache_id, update in v_cache_updates.items():
-                v_caches[cache_id] = StaticKVCache.apply_update(
-                    v_caches[cache_id], update
-                )
-
-        self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+
+        def test_with_style(style):
+            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
+            mask.tensor[:, :, cache_len:] = torch.triu(
+                torch.full((1, chunk_len, chunk_len), float("-inf")),
+                diagonal=1,
+            )
+            k_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for layer_id in range(config.n_layers)
+                for i in range(config.n_kv_heads)
+            }
+            v_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for layer_id in range(config.n_layers)
+                for i in range(config.n_kv_heads)
+            }
+            ys = []
+            for i in range(n_chunks):
+                y_i, attn_update = static_transformer(
+                    x[:, i * chunk_len : (i + 1) * chunk_len],
+                    attn_options=ForwardOptions(
+                        mask=mask.tensor,
+                        freqs_cos_override=freqs_cos[
+                            i * chunk_len : (i + 1) * chunk_len
+                        ],
+                        freqs_sin_override=freqs_sin[
+                            i * chunk_len : (i + 1) * chunk_len
+                        ],
+                        in_cache_state=(k_caches, v_caches),
+                        out_cache_state=({}, {}),
+                    ),
+                )
+                ys.append(y_i)
+                mask.unmask(chunk_len)
+                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
+                if i < n_chunks - 1:
+                    for cache_id, update in k_cache_updates.items():
+                        k_caches[cache_id] = StaticKVCache.apply_update(
+                            k_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+                    for cache_id, update in v_cache_updates.items():
+                        v_caches[cache_id] = StaticKVCache.apply_update(
+                            v_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+
+            self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+
+        test_with_style("shift_pointer")
+        test_with_style("smart_mask")
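
For orientation, the per-chunk loop that both rewritten tests now share boils down to the sketch below. It is illustrative only and not code from this change: chunked_prefill and run_chunk are hypothetical names, and the only APIs assumed are the StaticAttentionMask and StaticKVCache calls that already appear in the diff above.

# Not part of this change: a minimal sketch of the chunked-prefill pattern the
# updated tests exercise. Only the StaticAttentionMask / StaticKVCache calls
# shown in the diff are assumed; `run_chunk` is a hypothetical stand-in for the
# static_attn / static_transformer call.
import torch

from executorch.examples.models.llama.static_attention import (
    StaticAttentionMask,
    StaticKVCache,
)


def chunked_prefill(run_chunk, x, chunk_len, cache_len, n_kv_heads, head_dim, style):
    n_chunks = x.size(1) // chunk_len

    # Cache positions start fully masked; the current chunk gets a causal mask.
    mask = StaticAttentionMask(chunk_len, cache_len, style=style)
    mask.tensor[:, :, cache_len:] = torch.triu(
        torch.full((1, chunk_len, chunk_len), float("-inf")), diagonal=1
    )

    layer_id = 0  # single attention layer in this sketch
    k_caches = {
        StaticKVCache.calculate_cache_key(layer_id, h): torch.zeros(1, cache_len, head_dim)
        for h in range(n_kv_heads)
    }
    v_caches = {
        StaticKVCache.calculate_cache_key(layer_id, h): torch.zeros(1, cache_len, head_dim)
        for h in range(n_kv_heads)
    }

    ys = []
    for i in range(n_chunks):
        chunk = x[:, i * chunk_len : (i + 1) * chunk_len]
        y_i, (k_updates, v_updates) = run_chunk(chunk, mask.tensor, k_caches, v_caches)
        ys.append(y_i)

        # Reveal another chunk_len of cache positions, then fold this chunk's
        # K/V into the caches; pos and style control where the update lands.
        # The tests skip this for the final chunk since nothing reads it later.
        mask.unmask(chunk_len)
        if i < n_chunks - 1:
            for cache_id, update in k_updates.items():
                k_caches[cache_id] = StaticKVCache.apply_update(
                    k_caches[cache_id], update, pos=chunk_len * i, style=style
                )
            for cache_id, update in v_updates.items():
                v_caches[cache_id] = StaticKVCache.apply_update(
                    v_caches[cache_id], update, pos=chunk_len * i, style=style
                )

    return torch.cat(ys, dim=1)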