|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import unittest |
| 8 | +from typing import Callable, Tuple |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | +from executorch.exir import EdgeCompileConfig, to_edge |
| 13 | +from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache |
| 14 | +from executorch.runtime import Runtime |
| 15 | +from torch.testing import assert_close |
| 16 | +from torchtune.modules.kv_cache import KVCache |
| 17 | + |
| 18 | + |
| 19 | +def generate_cache_inputs( |
| 20 | + seq_len: int, |
| 21 | + batch_size: int = 1, |
| 22 | + num_kv_heads: int = 64, |
| 23 | + head_dim: int = 8, |
| 24 | +) -> Tuple[torch.Tensor, ...]: |
| 25 | + """Helper to generate k_val and v_val for both et and tt caches.""" |
| 26 | + k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) |
| 27 | + v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) |
| 28 | + |
| 29 | + # For torchtune, the kv cache takes in transposed k and v. |
| 30 | + k_val_trans = k_val.transpose(1, 2) |
| 31 | + v_val_trans = v_val.transpose(1, 2) |
| 32 | + |
| 33 | + return (k_val, v_val, k_val_trans, v_val_trans) |
| 34 | + |
| 35 | + |
| 36 | +class KVCacheTest(unittest.TestCase): |
| 37 | + def setUp(self): |
| 38 | + self.batch_size = 1 |
| 39 | + self.max_seq_len = 10 |
| 40 | + self.num_kv_heads = 1 # For testing purposes, usually this is 64. |
| 41 | + self.head_dim = 8 |
| 42 | + self.dtype = torch.float |
| 43 | + |
| 44 | + self.tt_kv_cache = KVCache( |
| 45 | + batch_size=self.batch_size, |
| 46 | + max_seq_len=self.max_seq_len, |
| 47 | + num_kv_heads=self.num_kv_heads, |
| 48 | + head_dim=self.head_dim, |
| 49 | + dtype=self.dtype, |
| 50 | + ) |
| 51 | + self.et_kv_cache = InferenceKVCache( |
| 52 | + batch_size=self.batch_size, |
| 53 | + max_seq_len=self.max_seq_len, |
| 54 | + num_kv_heads=self.num_kv_heads, |
| 55 | + head_dim=self.head_dim, |
| 56 | + dtype=self.dtype, |
| 57 | + transpose_cache=False, |
| 58 | + ) |
| 59 | + |
| 60 | + def _test_kv_cache(self, et_cache_module: Callable): |
| 61 | + """ |
| 62 | + Given an executorch kv cache anywhere along the export chain, compare it's results |
| 63 | + against torchtune and run basic tests. |
| 64 | + """ |
| 65 | + prefill_seq_len = 3 |
| 66 | + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( |
| 67 | + prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim |
| 68 | + ) |
| 69 | + |
| 70 | + et_res = et_cache_module(k_val, v_val) |
| 71 | + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) |
| 72 | + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) |
| 73 | + |
| 74 | + # Check torchtune matches executorch. |
| 75 | + assert_close(et_res, tt_res_transposed) |
| 76 | + |
| 77 | + # Check the values are correct, all rows in the seq_len dim should be |
| 78 | + # filled with 1s up to and including the 3rd. |
| 79 | + et_k_cache = et_res[0] |
| 80 | + for i in range(prefill_seq_len): |
| 81 | + self.assertTrue(et_k_cache[0][i][0][0] == 1) |
| 82 | + self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0) |
| 83 | + |
| 84 | + """Case 2: Token-by-token (seq_len = 0)""" |
| 85 | + seq_len = 1 |
| 86 | + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( |
| 87 | + seq_len, self.batch_size, self.num_kv_heads, self.head_dim |
| 88 | + ) |
| 89 | + |
| 90 | + et_res = et_cache_module(k_val, v_val) |
| 91 | + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) |
| 92 | + |
| 93 | + # Check torchtune matches executorch. |
| 94 | + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) |
| 95 | + assert_close(tt_res_transposed, et_res) |
| 96 | + |
| 97 | + # All rows should be filled with 1s up to 3 + 1th row. |
| 98 | + et_k_cache = et_res[0] |
| 99 | + for i in range(prefill_seq_len + 1): |
| 100 | + self.assertTrue(et_k_cache[0][i][0][0] == 1) |
| 101 | + self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0) |
| 102 | + |
| 103 | + def export_kv_cache( |
| 104 | + self, |
| 105 | + kv_cache: torch.nn.Module, |
| 106 | + ) -> torch.export.ExportedProgram: |
| 107 | + # Wrapper since torch.export only exports forward(). |
| 108 | + class EtCacheWrapper(torch.nn.Module): |
| 109 | + def __init__(self, kv_cache: torch.nn.Module): |
| 110 | + super().__init__() |
| 111 | + self.kv_cache = kv_cache |
| 112 | + |
| 113 | + def forward(self, k_val: torch.Tensor, v_val: torch.Tensor): |
| 114 | + return self.kv_cache.update(k_val, v_val) |
| 115 | + |
| 116 | + dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len) |
| 117 | + exported_kv_cache = torch.export.export( |
| 118 | + EtCacheWrapper(self.et_kv_cache), |
| 119 | + ( |
| 120 | + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), |
| 121 | + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), |
| 122 | + ), # 3 as example prefill seq_len. |
| 123 | + dynamic_shapes={ |
| 124 | + "k_val": { |
| 125 | + 0: torch.export.Dim.STATIC, |
| 126 | + 1: dim, |
| 127 | + 2: torch.export.Dim.STATIC, |
| 128 | + 3: torch.export.Dim.STATIC, |
| 129 | + }, |
| 130 | + "v_val": { |
| 131 | + 0: torch.export.Dim.STATIC, |
| 132 | + 1: dim, |
| 133 | + 2: torch.export.Dim.STATIC, |
| 134 | + 3: torch.export.Dim.STATIC, |
| 135 | + }, |
| 136 | + }, |
| 137 | + ) |
| 138 | + return exported_kv_cache |
| 139 | + |
| 140 | + def test_kv_cache_eager(self): |
| 141 | + self._test_kv_cache(self.et_kv_cache.update) |
| 142 | + |
| 143 | + def test_kv_cache_export(self): |
| 144 | + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) |
| 145 | + self._test_kv_cache(exported_kv_cache.module()) |
| 146 | + |
| 147 | + def test_kv_cache_edge(self): |
| 148 | + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) |
| 149 | + edge_program = to_edge( |
| 150 | + exported_kv_cache, |
| 151 | + compile_config=EdgeCompileConfig( |
| 152 | + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], |
| 153 | + _check_ir_validity=False, |
| 154 | + ), |
| 155 | + ) |
| 156 | + self._test_kv_cache(edge_program._edge_programs["forward"].module()) |
| 157 | + |
| 158 | + def test_kv_cache_executorch(self): |
| 159 | + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) |
| 160 | + edge_program = to_edge( |
| 161 | + exported_kv_cache, |
| 162 | + compile_config=EdgeCompileConfig( |
| 163 | + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], |
| 164 | + _check_ir_validity=False, |
| 165 | + ), |
| 166 | + ) |
| 167 | + et_program = edge_program.to_executorch() |
| 168 | + runtime = Runtime.get() |
| 169 | + program = runtime.load_program(et_program.buffer) |
| 170 | + method = program.load_method("forward") |
| 171 | + |
| 172 | + # Since method.execute expects a tuple of args. |
| 173 | + def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor: |
| 174 | + return method.execute((k_val, v_val)) |
| 175 | + |
| 176 | + self._test_kv_cache(wrapped_callable) |
0 commit comments