
Commit 9cd57d6

[llama-mm] Enable kv cache for MultiHeadAttention (#6798)
Summary: Change `MultiHeadAttention` in `extension/llm/modules` to support KV cache. Only enabled in eager mode for now, not yet in export.
Test Plan: Unit tests.
ghstack-source-id: ded3830
Pull Request resolved: #6793
Co-authored-by: Mengwei Liu <[email protected]>
1 parent 4b7a60f commit 9cd57d6

File tree: 4 files changed, +183 −25 lines

extension/llm/modules/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -8,8 +8,13 @@
     replace_tile_positional_embedding,
     TilePositionalEmbedding,
 )
+from .attention import MultiHeadAttention, replace_mha_with_inference_mha
+from .kv_cache import KVCache
 
 __all__ = [
     "TilePositionalEmbedding",
     "replace_tile_positional_embedding",
+    "MultiHeadAttention",
+    "replace_mha_with_inference_mha",
+    "KVCache",
 ]
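With these additions to `__all__`, the new modules can be imported from the package root. A minimal sketch, assuming the package layout shown in this diff:

# Hedged sketch: import paths follow the __init__.py exports above.
from executorch.extension.llm.modules import (
    KVCache,
    MultiHeadAttention,
    replace_mha_with_inference_mha,
)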

extension/llm/modules/mha.py renamed to extension/llm/modules/attention.py

Lines changed: 12 additions & 12 deletions
@@ -9,6 +9,7 @@
 
 import torch
 import torchtune.modules.attention as TorchTuneAttention
+from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -148,7 +149,6 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
-            q_per_kv=self.num_heads // self.num_kv_heads,
             attn_dropout=self.attn_dropout if self.training else 0.0,
             is_causal=self.is_causal,
             attention_fn=self._attention_call,
@@ -177,12 +177,13 @@ def setup_cache(
                 "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
             )
         else:
-            self.kv_cache = KVCache(
+            self.kv_cache = InferenceKVCache(
                 batch_size=batch_size,
                 max_seq_len=max_seq_len,
-                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
                 head_dim=self.head_dim,
                 dtype=dtype,
+                transpose_cache=False,
             )
             self._sdpa.kv_cache = self.kv_cache
             self.cache_enabled = True
@@ -307,7 +308,6 @@ def __init__(
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
-        q_per_kv: int,
         attn_dropout: float,
         is_causal: bool,
         attention_fn,
@@ -317,7 +317,7 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
-        self.q_per_kv = q_per_kv
+        self.q_per_kv = self.num_heads // self.num_kv_heads
         self.attn_dropout = attn_dropout
         self.is_causal = is_causal
         self._attention_fn = attention_fn
@@ -330,25 +330,25 @@ def forward(
         v: torch.Tensor,  # [b, s, n_kv, h_d]
         bsz: int,
         seq_len: int,
-        mask: torch.Tensor = None,
+        mask: Optional[_MaskType] = None,
     ) -> torch.Tensor:
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
 
         # k: [bsz, seq_len, n_kv, 1, h_d]
         # v: [bsz, seq_len, n_kv, 1, h_d]
-        k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
+        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
+        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
 
         # Expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
         if self.num_heads != self.num_kv_heads:
-            k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
-            v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
 
         # [bsz, s, n_h, h_d]
-        k = k.reshape(bsz, seq_len, -1, self.head_dim)
-        v = v.reshape(bsz, seq_len, -1, self.head_dim)
+        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
+        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)
 
         # [bsz, n_h, s, h_d]
         q = q.transpose(1, 2)
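The `forward` changes above implement grouped-query head expansion: `k` and `v` arrive with `num_kv_heads` heads and are broadcast up to `num_heads` to match `q`, with the sequence dimension now left as `-1`, presumably so the same reshapes still work when the KV cache changes the effective key/value length. A standalone, shape-only sketch of that expansion (sizes are hypothetical, chosen only for illustration):

import torch

# Hypothetical sizes, for illustration only.
bsz, seq_len, num_kv_heads, num_heads, head_dim = 2, 10, 2, 8, 16
q_per_kv = num_heads // num_kv_heads  # 4 query heads share each kv head

k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)  # [b, s, n_kv, h_d]

# view -> expand -> reshape, mirroring the diff; -1 keeps the (possibly cached) seq dim.
k = k.view(bsz, -1, num_kv_heads, 1, head_dim)           # [b, s, n_kv, 1, h_d]
k = k.expand(bsz, -1, num_kv_heads, q_per_kv, head_dim)  # [b, s, n_kv, q_per_kv, h_d]
k = k.reshape(bsz, -1, num_heads, head_dim)              # [b, s, n_h, h_d]

print(k.shape)  # torch.Size([2, 10, 8, 16])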

extension/llm/modules/kv_cache.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from torchtune.modules.kv_cache import KVCache as TuneKVCache
+
+
+class KVCache(TuneKVCache):
+    """
+    An export-friendly KVCache implementation adopted from torchtune KVCache:
+    https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py
+    This also takes both transposed and un-transposed KVCache shapes.
+    Standalone ``nn.Module`` containing a kv-cache to cache past key and values during inference.
+
+    Args:
+        batch_size (int): batch size model will be run with
+        max_seq_len (int): maximum sequence length model will be run with
+        num_kv_heads (int): number of key/value heads.
+        head_dim (int): per-attention head embedding dimension
+        dtype (torch.dtype): dtype for the caches
+        transpose_cache (bool): whether we transpose(1, 2) for kv cache.
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        max_seq_len: int,
+        num_kv_heads: int,
+        head_dim: int,
+        dtype: torch.dtype,
+        transpose_cache: bool = True,
+    ) -> None:
+        super().__init__(
+            batch_size=batch_size,
+            max_seq_len=max_seq_len,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            dtype=dtype,
+        )
+        self.transpose_cache = transpose_cache
+        self.max_seq_len = max_seq_len
+        if self.transpose_cache:
+            cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
+        else:
+            cache_shape = (batch_size, max_seq_len, num_kv_heads, head_dim)
+
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
+        )
+        self.register_buffer(
+            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
+        )
+        self.batch_size = batch_size
+
+    def update(
+        self, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.
+
+        Note:
+            When updating the KV cache, it is assumed that subsequent updates should update key-value
+            positions in consecutive sequence positions. If you wish to update cache values which have
+            already been filled, use ``.reset()``, which will reset the cache to the zero-th position.
+
+        Example:
+            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
+            >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
+            >>> cache.update(keys, values)
+            >>> # now positions 0 through 7 are filled
+            >>> cache.size
+            >>> 8
+            >>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
+            >>> cache.update(keys, values)
+            >>> # this will fill at position 8
+            >>> cache.size
+            >>> 9
+
+        Args:
+            k_val (torch.Tensor): Current key tensor with shape [B, H, S, D]
+            v_val (torch.Tensor): Current value tensor with shape [B, H, S, D]
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
+
+        Raises:
+            AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
+            ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
+                used during cache setup.
+        """
+        if self.transpose_cache:
+            bsz, _, seq_len, _ = k_val.shape
+        else:
+            bsz, seq_len, _, _ = k_val.shape
+        if bsz > self.k_cache.shape[0]:
+            raise ValueError(
+                f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}"
+                f", but found new key tensors with batch size {k_val.shape[0]}!"
+            )
+
+        assert (
+            self.cache_pos[0] + seq_len
+        ) <= self.max_seq_len, f"self.cache_pos[0]: {self.cache_pos[0]} + seq_len: {seq_len} > self.max_seq_len: {self.max_seq_len}"
+        k_out = self.k_cache
+        v_out = self.v_cache
+
+        if self.transpose_cache:
+            k_out[:, :, self.cache_pos[:seq_len]] = k_val
+            v_out[:, :, self.cache_pos[:seq_len]] = v_val
+        else:
+            k_out[:, self.cache_pos[:seq_len]] = k_val
+            v_out[:, self.cache_pos[:seq_len]] = v_val
+
+        # forward cache_pos seq_len positions along
+        # cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
+        # an update of seq_len = 5 tokens brings it to
+        # (5, 6, 7, 8, 9, ...)
+        # this allows us to track the current position in the cache
+        # after the last update in a compile-friendly way without any dynamism
+        # e.g. relying on an int size tracker, or re-creating cache_pos every time
+        self.cache_pos.add_(seq_len)
+
+        return k_out, v_out
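The `transpose_cache=False` path added here is the one the attention module above selects: the cache buffers are laid out as [batch, seq, kv_heads, head_dim], so updates index the sequence dimension directly. A small usage sketch of that path, under the assumption that the class behaves exactly as in the diff (sizes are illustrative only):

import torch

from executorch.extension.llm.modules.kv_cache import KVCache

# Un-transposed layout: buffers are [batch, seq, kv_heads, head_dim].
cache = KVCache(
    batch_size=1,
    max_seq_len=16,
    num_kv_heads=4,
    head_dim=32,
    dtype=torch.float32,
    transpose_cache=False,
)

# Prefill 10 positions; k/v come in the same [B, S, H, D] layout.
k, v = torch.randn(1, 10, 4, 32), torch.randn(1, 10, 4, 32)
k_out, v_out = cache.update(k, v)  # fills positions 0..9, cache_pos advances to 10

# Decode step: a 1-token update lands at position 10.
k1, v1 = torch.randn(1, 1, 4, 32), torch.randn(1, 1, 4, 32)
k_out, v_out = cache.update(k1, v1)

# cache.reset() (inherited from the torchtune KVCache) rewinds cache_pos for reuse.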

extension/llm/modules/test/test_mha.py renamed to extension/llm/modules/test/test_attention.py

Lines changed: 36 additions & 13 deletions
@@ -9,7 +9,7 @@
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
 
-from executorch.extension.llm.modules.mha import (
+from executorch.extension.llm.modules.attention import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
@@ -82,10 +82,12 @@ def setUp(self):
         # Common inputs.
         seq_len = 10
         self.x = torch.randn(1, seq_len, self.embed_dim)
+        self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
         seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
         self.dynamic_shapes = (
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
         )
 
     def test_attention_eager(self):
@@ -94,25 +96,46 @@ def test_attention_eager(self):
 
         self.assertTrue(torch.allclose(et_res, tt_res))
 
-        # TODO: KV cache.
-        # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
-        # self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
 
-        # et_res = self.et_mha(self.x, self.x) # Self attention.
-        # tt_res = self.tt_mha(self.x, self.x) # Self attention.
+        et_res = self.et_mha(self.x, self.x)  # Self attention.
+        tt_res = self.tt_mha(self.x, self.x)  # Self attention.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.et_mha.reset_cache()
+        self.tt_mha.reset_cache()
 
-        # self.assertTrue(torch.allclose(et_res, tt_res))
+        et_res = self.et_mha(
+            self.x, self.x, input_pos=self.input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, input_pos=self.input_pos
+        )  # Self attention with input pos.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        # test kv cache read. Input pos can be [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        et_res = self.et_mha(
+            self.x, self.x, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        self.assertTrue(torch.allclose(et_res, tt_res))
 
     def test_attention_export(self):
         # Self attention.
         et_mha_ep = torch.export.export(
             self.et_mha,
             (self.x, self.x),
-            kwargs=None,
+            kwargs={"input_pos": self.input_pos},
             dynamic_shapes=self.dynamic_shapes,
         )
-        et_res = et_mha_ep.module()(self.x, self.x)
-        tt_res = self.tt_mha(self.x, self.x)
+        et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
         self.assertTrue(torch.allclose(et_res, tt_res))
 
         # TODO: KV cache.
@@ -126,7 +149,7 @@ def test_attention_executorch(self):
         et_mha_ep = torch.export.export(
             self.et_mha,
             (self.x, self.x),
-            kwargs=None,
+            kwargs={"input_pos": self.input_pos},
            dynamic_shapes=self.dynamic_shapes,
         )
         et_program = to_edge(
@@ -136,8 +159,8 @@ def test_attention_executorch(self):
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
-        et_res = method.execute((self.x, self.x))
-        tt_res = self.tt_mha(self.x, self.x)
+        et_res = method.execute((self.x, self.x, self.input_pos))
+        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
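For reference, the eager KV-cache flow these tests exercise can be written outside the unittest harness roughly as follows. The constructor arguments are hypothetical (the fixture's actual sizes and projection layers are not part of this diff, and the signature is assumed to mirror torchtune's MultiHeadAttention); only the `setup_cache` and `input_pos` usage is taken from the tests above:

import torch
from torch import nn

from executorch.extension.llm.modules.attention import MultiHeadAttention

# Hypothetical configuration; the real values come from the test's setUp(), not shown here.
embed_dim, num_heads, num_kv_heads, head_dim = 64, 8, 2, 8

mha = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
    max_seq_len=20,
)

# Eager-mode KV cache, mirroring test_attention_eager.
mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
x = torch.randn(1, 10, embed_dim)
out = mha(x, x, input_pos=torch.arange(10).unsqueeze(0))      # prefill positions 0..9
out = mha(x, x, input_pos=torch.arange(10, 20).unsqueeze(0))  # cached read at positions 10..19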
0 commit comments