Commit 8a46c77

Transform model to be able to use Attention Sink
This PR adds the functions needed to transform the model so that it can use Attention Sink. Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/) [ghstack-poisoned]
1 parent 7baa27d commit 8a46c77
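For orientation, a toy sketch of the cache policy this transformation enables, simplified from the StreamingLLM paper (arXiv:2309.17453) referenced in attention_sink.py. The real KVCacheWithAttentionSink evicts in batches of eviction_batch_size rather than one position at a time, so this is an illustration of the idea, not the commit's logic:

# Toy illustration only: which token positions an attention-sink cache retains.
def retained_positions(num_tokens: int, sink_size: int = 4, window_size: int = 2044) -> list:
    positions = list(range(num_tokens))
    if num_tokens <= sink_size + window_size:
        return positions
    # Keep the first `sink_size` "sink" tokens plus the most recent `window_size` tokens.
    return positions[:sink_size] + positions[-window_size:]

# After 3000 tokens, positions 0-3 and 956-2999 remain: 2048 entries in total.
assert len(retained_positions(3000)) == 2048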

File tree: 3 files changed, 134 additions and 1 deletion.

examples/models/llama/export_llama_lib.py

Lines changed: 7 additions & 0 deletions
@@ -432,6 +432,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
 
+    parser.add_argument(
+        "--use_attention_sink",
+        default="4,2044,1024",
+        type=str,
+        help="Use attention sink to enable fluent multi-round conversations. Format: '<sink_size>,<window_size>,<eviction_batch_size>'.",
+    )
+
     parser.add_argument(
         "--output_prune_map",
         default=None,
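For illustration, a standalone sketch of how the packed default decomposes. The observation that sink_size + window_size equals a 2048-token context length is an assumption drawn from the defaults, not something this parser checks:

# Decompose the packed "<sink_size>,<window_size>,<eviction_batch_size>" default.
sink_size, window_size, eviction_batch_size = (int(v) for v in "4,2044,1024".split(","))
assert (sink_size, window_size, eviction_batch_size) == (4, 2044, 1024)
# With these defaults the cache holds sink_size + window_size == 2048 positions.
assert sink_size + window_size == 2048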

examples/models/llama/model.py

Lines changed: 14 additions & 0 deletions
@@ -200,6 +200,20 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
+
+        if hasattr(self.args, "use_attention_sink"):
+            from .source_transformation.attention_sink import (
+                enable_attention_sink,
+            )
+            attention_sink_params = self.args.use_attention_sink.split(",")
+            assert len(attention_sink_params) == 3
+
+            self.model_ = enable_attention_sink(
+                module=self.model_,
+                params=model_args,
+                sink_size=int(attention_sink_params[0]),
+                window_size=int(attention_sink_params[1]),
+                eviction_batch_size=int(attention_sink_params[2]))
 
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
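The same transformation can presumably be applied to an eager model outside this constructor. A minimal sketch, where apply_attention_sink is a hypothetical helper for illustration only, `model` is assumed to be a llama_transformer.Transformer built with use_kv_cache=True, and `model_args` its ModelArgs:

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
    enable_attention_sink,
)


def apply_attention_sink(
    model: torch.nn.Module,
    model_args: ModelArgs,
    flag_value: str = "4,2044,1024",
) -> torch.nn.Module:
    # Mirrors the wiring above: split the packed flag value and hand the
    # pieces to enable_attention_sink.
    sink_size, window_size, eviction_batch_size = (int(v) for v in flag_value.split(","))
    return enable_attention_sink(
        module=model,
        params=model_args,
        sink_size=sink_size,
        window_size=window_size,
        eviction_batch_size=eviction_batch_size,
    )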

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 113 additions & 1 deletion
@@ -7,13 +7,22 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+import types
+from typing import Optional
+
 import torch
 
-from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import (
+    Attention,
+    KVCache,
+    ModelArgs,
+    Rope,
+)
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
 )
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 
 class RopeWithAttentionSink(Rope):
@@ -167,3 +176,106 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
         )
         self.position_shift -= num_to_evict  # pyre-ignore [8]
         return self.position_shift
+
+
+def attention_sink_forward(
+    self,
+    x: torch.Tensor,
+    input_pos: Optional[torch.Tensor] = None,
+):
+    assert self.use_kv_cache
+    assert input_pos is not None
+
+    bsz, seqlen, _ = x.shape
+
+    # QKV
+    q, k, v = self.wq(x), self.wk(x), self.wv(x)
+    # We need view_copy elimination
+    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+    # Prepare for space in KV cache and get position shift
+    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
+
+    shifted_position = input_pos + position_shift
+
+    # RoPE relative positional embeddings with shifted position in KV cache
+    q, k = self.rope.forward(q, k, shifted_position)
+
+    output = self.SDPA(shifted_position, q, k, v, bsz, seqlen, self.mask)
+    return self.wo(output)
+
+
+def _replace_rope(
+    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, Rope)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        return rope_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_kv_cache(
+    module: torch.nn.Module,
+    rope_with_attention_sink: RopeWithAttentionSink,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, KVCache)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        kv_cache_with_attention_sink = KVCacheWithAttentionSink(
+            n_heads=child.n_heads,
+            head_dim=child.head_dim,
+            transpose_cache=child.transpose_cache,
+            enable_dynamic_shape=child.enable_dynamic_shape,
+            rope=rope_with_attention_sink,
+            max_batch_size=child.max_batch_size,
+            window_size=window_size,
+            sink_size=sink_size,
+            eviction_batch_size=eviction_batch_size,
+            dtype=child.k_cache.dtype,
+        )
+        return kv_cache_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_attention_forward(module: torch.nn.Module):
+    for name, child_module in module._modules.items():
+        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
+            _replace_attention_forward(child_module)  # pyre-ignore [6]
+
+        if isinstance(child_module, Attention):
+            module._modules[name].forward = types.MethodType(  # pyre-ignore
+                attention_sink_forward, module._modules[name]
+            )
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int = 4,
+    window_size: int = 2044,
+    eviction_batch_size: int = 1,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to run inference with Attention Sink.
+    There are mainly three steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace KVCache with KVCacheWithAttentionSink
+    - Replace Attention's forward with attention_sink_forward
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(params=params)
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_kv_cache(
+        module, rope_with_attention_sink, sink_size, window_size, eviction_batch_size
+    )
+    _replace_attention_forward(module)
+    return module
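The three _replace_* helpers lean on two standard tricks: swapping submodules with torchao's _replace_with_custom_fn_if_matches_filter, and rebinding an instance's forward with types.MethodType. A self-contained toy of the second trick, where ToyAttention and patched_forward are stand-ins rather than the real Attention or attention_sink_forward:

import types

import torch


class ToyAttention(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same signature as the original forward, different behavior.
    return x + 1


toy = ToyAttention()
# Rebind forward on this one instance, as _replace_attention_forward does for
# every Attention submodule it finds.
toy.forward = types.MethodType(patched_forward, toy)
assert torch.equal(toy(torch.zeros(2)), torch.ones(2))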
