
Commit aac90a0

Add tests that localize the prefill issue to the kv cache
1 parent 343aa0c commit aac90a0

1 file changed: 176 additions & 0 deletions
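
The new file cross-checks the ExecuTorch LLM KVCache (InferenceKVCache) against torchtune's KVCache at four stages: eager, torch.export, edge dialect, and the executorch runtime. The core comparison, in sketch form (names as in the file below; the et cache takes [b, s, h, d] inputs, the tt cache takes [b, h, s, d]):

    et_res = et_kv_cache.update(k_val, v_val)
    tt_res = tt_kv_cache.update(k_val_trans, v_val_trans)
    tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))
    assert_close(et_res, tt_res_transposed)
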
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Callable, Tuple

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from executorch.runtime import Runtime
from torch.testing import assert_close
from torchtune.modules.kv_cache import KVCache


def generate_cache_inputs(
    seq_len: int,
    batch_size: int = 1,
    num_kv_heads: int = 64,
    head_dim: int = 8,
) -> Tuple[torch.Tensor, ...]:
    """Helper to generate k_val and v_val for both et and tt caches."""
    k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)
    v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)

    # For torchtune, the kv cache takes in transposed k and v.
    k_val_trans = k_val.transpose(1, 2)
    v_val_trans = v_val.transpose(1, 2)

    return (k_val, v_val, k_val_trans, v_val_trans)
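
# Layout note (inferred from the transposes above): InferenceKVCache is
# constructed with transpose_cache=False, so its update() takes
# [batch, seq_len, num_kv_heads, head_dim], while torchtune's KVCache.update()
# takes [batch, num_kv_heads, seq_len, head_dim].
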
class KVCacheTest(unittest.TestCase):
    def setUp(self):
        self.batch_size = 1
        self.max_seq_len = 10
        self.num_kv_heads = 1  # For testing purposes; usually this is 64.
        self.head_dim = 8
        self.dtype = torch.float

        self.tt_kv_cache = KVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            dtype=self.dtype,
        )
        self.et_kv_cache = InferenceKVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            dtype=self.dtype,
            transpose_cache=False,
        )
    def _test_kv_cache(self, et_cache_module: Callable):
        """
        Given an executorch kv cache anywhere along the export chain, compare its
        results against torchtune and run basic tests.
        """
        # Case 1: Prefill (seq_len = 3).
        prefill_seq_len = 3
        k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
            prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim
        )

        et_res = et_cache_module(k_val, v_val)
        tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

        # Check torchtune matches executorch.
        assert_close(et_res, tt_res_transposed)

        # Check the values are correct: all rows in the seq_len dim should be
        # filled with 1s up to and including the 3rd.
        et_k_cache = et_res[0]
        for i in range(prefill_seq_len):
            self.assertTrue(et_k_cache[0][i][0][0] == 1)
        self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0)
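
        # Cache state after prefill: rows 0-2 along the seq_len dim hold 1s;
        # the remaining rows (up to max_seq_len = 10) are still 0.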
"""Case 2: Token-by-token (seq_len = 0)"""
85+
seq_len = 1
86+
k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
87+
seq_len, self.batch_size, self.num_kv_heads, self.head_dim
88+
)
89+
90+
et_res = et_cache_module(k_val, v_val)
91+
tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
92+
93+
# Check torchtune matches executorch.
94+
tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))
95+
assert_close(tt_res_transposed, et_res)
96+
97+
# All rows should be filled with 1s up to 3 + 1th row.
98+
et_k_cache = et_res[0]
99+
for i in range(prefill_seq_len + 1):
100+
self.assertTrue(et_k_cache[0][i][0][0] == 1)
101+
self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0)
102+
103+

    def export_kv_cache(
        self,
        kv_cache: torch.nn.Module,
    ) -> torch.export.ExportedProgram:
        # Wrapper since torch.export only exports forward().
        class EtCacheWrapper(torch.nn.Module):
            def __init__(self, kv_cache: torch.nn.Module):
                super().__init__()
                self.kv_cache = kv_cache

            def forward(self, k_val: torch.Tensor, v_val: torch.Tensor):
                return self.kv_cache.update(k_val, v_val)

        # Only the seq_len dimension (dim 1) is dynamic; batch, heads, and
        # head_dim stay static.
        dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len)
        exported_kv_cache = torch.export.export(
            EtCacheWrapper(self.et_kv_cache),
            (
                torch.randn(self.batch_size, 3, self.num_kv_heads, self.head_dim),
                torch.randn(self.batch_size, 3, self.num_kv_heads, self.head_dim),
            ),  # 3 as example prefill seq_len.
            dynamic_shapes={
                "k_val": {
                    0: torch.export.Dim.STATIC,
                    1: dim,
                    2: torch.export.Dim.STATIC,
                    3: torch.export.Dim.STATIC,
                },
                "v_val": {
                    0: torch.export.Dim.STATIC,
                    1: dim,
                    2: torch.export.Dim.STATIC,
                    3: torch.export.Dim.STATIC,
                },
            },
        )
        return exported_kv_cache
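
    # The exported program's .module() is callable with the same (k_val, v_val)
    # signature as the eager update(), so _test_kv_cache can be reused at every
    # stage below.
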
    def test_kv_cache_eager(self):
        self._test_kv_cache(self.et_kv_cache.update)

    def test_kv_cache_export(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        self._test_kv_cache(exported_kv_cache.module())

    def test_kv_cache_edge(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        # The traced cache update contains aten._assert_async ops (presumably
        # internal bounds checks); exempt them from the core-ATen check and
        # skip IR validation.
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        self._test_kv_cache(edge_program._edge_programs["forward"].module())

    def test_kv_cache_executorch(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        et_program = edge_program.to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")

        # method.execute expects a tuple of args.
        def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor):
            return method.execute((k_val, v_val))

        self._test_kv_cache(wrapped_callable)
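
For local iteration, the suite can also be driven directly with unittest. A minimal runner sketch (the module name test_kv_cache is hypothetical, since the diff above does not show the file's path):

    import unittest

    # Hypothetical import; substitute the actual module path of the new test file.
    from test_kv_cache import KVCacheTest

    # Collect and run all four stages: eager, export, edge, executorch.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(KVCacheTest)
    unittest.TextTestRunner(verbosity=2).run(suite)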
