
Commit aa67cd9

Transform model to be able to use Attention Sink
Pull Request resolved: #6700

This PR adds the necessary functions for transforming the model so that it can use Attention Sink.

ghstack-source-id: 256108077
@exported-using-ghexport
Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/)

Co-authored-by: Lunwen He <[email protected]>
1 parent: e0eec82

File tree: 3 files changed (+143, -1 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 7 additions & 0 deletions
@@ -448,6 +448,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
 
+    parser.add_argument(
+        "--use_attention_sink",
+        default=None,
+        type=str,
+        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
+    )
+
     parser.add_argument(
         "--output_prune_map",
         default=None,
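To see what the new flag accepts, the snippet below is a minimal, self-contained sketch (not part of the commit) of how the comma-separated value can be parsed with argparse; the example value comes from the flag's help string.

# Illustrative only: a standalone parser mirroring the argument added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_attention_sink",
    default=None,
    type=str,
    help="'<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
)

args = parser.parse_args(["--use_attention_sink", "4,2044,1024"])
sink_size, window_size, eviction_batch_size = (
    int(v) for v in args.use_attention_sink.split(",")
)
print(sink_size, window_size, eviction_batch_size)  # 4 2044 1024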

examples/models/llama/model.py

Lines changed: 19 additions & 0 deletions
@@ -201,6 +201,25 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
+        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+            from .source_transformation.attention_sink import enable_attention_sink
+
+            attention_sink_params = self.args.use_attention_sink.split(",")
+            assert len(attention_sink_params) == 3
+            sink_size = int(attention_sink_params[0])
+            window_size = int(attention_sink_params[1])
+            eviction_batch_size = int(attention_sink_params[2])
+
+            assert self.args.max_seq_length == sink_size + window_size
+
+            self.model_ = enable_attention_sink(
+                module=self.model_,
+                params=model_args,
+                sink_size=sink_size,
+                window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
+            )
+
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
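Note the sizing constraint enforced above: the KV cache is laid out as sink_size sink tokens plus a window_size sliding window, so the exported max_seq_length must equal their sum. A quick worked check using the example value from the flag's help string (the 2048 context length here is an assumption for illustration, not stated in the commit):

# Hypothetical check mirroring the assertion in model.py above.
max_seq_length = 2048  # assumed export setting for this example
sink_size, window_size, eviction_batch_size = (
    int(v) for v in "4,2044,1024".split(",")
)
assert max_seq_length == sink_size + window_size  # 4 + 2044 == 2048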

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 117 additions & 1 deletion
@@ -7,15 +7,22 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+import types
 from typing import Optional
 
 import torch
 
-from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import (
+    Attention,
+    KVCache,
+    ModelArgs,
+    Rope,
+)
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
 )
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 
 class RopeWithAttentionSink(Rope):
@@ -206,3 +213,112 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
             )
         self.position_shift -= num_to_evict  # pyre-ignore [8]
         return self.position_shift
+
+
+def attention_sink_forward(
+    self,
+    x: torch.Tensor,
+    freqs_cos: torch.Tensor,
+    freqs_sin: torch.Tensor,
+    input_pos: Optional[torch.Tensor] = None,
+):
+    assert self.use_kv_cache
+    assert input_pos is not None
+
+    bsz, seqlen, _ = x.shape
+
+    # QKV
+    q, k, v = self.wq(x), self.wk(x), self.wv(x)
+    # We need view_copy elimination
+    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+    # Prepare for space in KV cache and get position shift
+    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
+
+    # RoPE relative positional embeddings with shifted position in KV cache
+    q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+
+    output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
+    return self.wo(output)
+
+
+def _replace_rope(
+    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, Rope)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        return rope_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_attention(
+    module: torch.nn.Module,
+    rope_with_attention_sink: RopeWithAttentionSink,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+):
+    for _, child_module in module._modules.items():
+        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
+            _replace_attention(
+                module=child_module,  # pyre-ignore [6]
+                rope_with_attention_sink=rope_with_attention_sink,
+                sink_size=sink_size,
+                window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
+            )
+
+        if isinstance(child_module, Attention):
+            kv_cache = child_module.kv_cache
+            kv_cache_with_attention_sink = KVCacheWithAttentionSink(
+                n_heads=kv_cache.n_heads,
+                head_dim=kv_cache.head_dim,
+                transpose_cache=kv_cache.transpose_cache,
+                enable_dynamic_shape=kv_cache.enable_dynamic_shape,
+                rope=rope_with_attention_sink,
+                max_batch_size=kv_cache.max_batch_size,
+                window_size=window_size,
+                sink_size=sink_size,
+                eviction_batch_size=eviction_batch_size,
+                dtype=kv_cache.k_cache.dtype,
+            )
+            child_module.kv_cache = kv_cache_with_attention_sink
+            child_module.SDPA.kv_cache = kv_cache_with_attention_sink
+            child_module.forward = types.MethodType(  # pyre-ignore
+                attention_sink_forward, child_module
+            )
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to run inference with Attention Sink.
+    There are two main steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace Attention's KVCache with KVCacheWithAttentionSink and its forward with attention_sink_forward
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(
+        params=params,
+        window_size=window_size,
+        sink_size=sink_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_attention(
+        module=module,
+        rope_with_attention_sink=rope_with_attention_sink,
+        sink_size=sink_size,
+        window_size=window_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    return module
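The replacement machinery above hinges on two ideas: recursively walking module._modules, and rebinding forward on matched submodules with types.MethodType. The toy sketch below is not from the commit (Inner, Outer, and patched_forward are made-up names) and only shows the same pattern on a self-contained model, which may make the Attention patching easier to follow.

import types

import torch


class Inner(torch.nn.Module):
    # Stand-in for Attention: the module whose forward gets swapped out.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


class Outer(torch.nn.Module):
    # Stand-in for the transformer that contains the module to patch.
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Plays the role of attention_sink_forward: a free function that becomes a
    # bound method, so `self` is the matched Inner instance.
    return x + 100


def _replace_inner(module: torch.nn.Module):
    # Same traversal shape as _replace_attention: iterate child modules,
    # recurse into containers, and rebind forward on every match.
    for _, child in module._modules.items():
        if len(list(child.children())) > 0:
            _replace_inner(child)
        if isinstance(child, Inner):
            child.forward = types.MethodType(patched_forward, child)


model = Outer()
_replace_inner(model)
print(model(torch.zeros(1)))  # tensor([100.]) once the patched forward is in place

In the commit, enable_attention_sink applies this pattern to every Attention block (additionally swapping in KVCacheWithAttentionSink), while Rope modules are replaced wholesale via torchao's _replace_with_custom_fn_if_matches_filter.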
