
[llama-mm] Enable kv cache for MultiHeadAttention #6798


Merged: 1 commit, merged on Nov 12, 2024
5 changes: 5 additions & 0 deletions extension/llm/modules/__init__.py
@@ -8,8 +8,13 @@
replace_tile_positional_embedding,
TilePositionalEmbedding,
)
from .attention import MultiHeadAttention, replace_mha_with_inference_mha
from .kv_cache import KVCache

__all__ = [
"TilePositionalEmbedding",
"replace_tile_positional_embedding",
"MultiHeadAttention",
"replace_mha_with_inference_mha",
"KVCache",
]
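
For downstream users, these re-exports mean the new modules can be imported straight from the package. A minimal import sketch, not part of this diff (it assumes executorch with this change and torchtune are installed):

# Usage sketch: the names below are re-exported by the __init__.py shown above.
from executorch.extension.llm.modules import (
    KVCache,
    MultiHeadAttention,
    replace_mha_with_inference_mha,
)
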
extension/llm/modules/attention.py
@@ -9,6 +9,7 @@

import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache
@@ -148,7 +149,6 @@ def __init__(
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
head_dim=self.head_dim,
q_per_kv=self.num_heads // self.num_kv_heads,
attn_dropout=self.attn_dropout if self.training else 0.0,
is_causal=self.is_causal,
attention_fn=self._attention_call,
@@ -177,12 +177,13 @@ def setup_cache(
"Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
)
else:
self.kv_cache = KVCache(
self.kv_cache = InferenceKVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
dtype=dtype,
transpose_cache=False,
)
self._sdpa.kv_cache = self.kv_cache
self.cache_enabled = True
@@ -307,7 +308,6 @@ def __init__(
num_kv_heads: int,
num_heads: int,
head_dim: int,
q_per_kv: int,
attn_dropout: float,
is_causal: bool,
attention_fn,
@@ -317,7 +317,7 @@
self.num_kv_heads = num_kv_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.q_per_kv = q_per_kv
self.q_per_kv = self.num_heads // self.num_kv_heads
self.attn_dropout = attn_dropout
self.is_causal = is_causal
self._attention_fn = attention_fn
@@ -330,25 +330,25 @@ def forward(
v: torch.Tensor, # [b, s, n_kv, h_d]
bsz: int,
seq_len: int,
mask: torch.Tensor = None,
mask: Optional[_MaskType] = None,
) -> torch.Tensor:
# View + expand + reshape bring num_kv_heads to num_heads for k and v
# to match q.

# k: [bsz, seq_len, n_kv, 1, h_d]
# v: [bsz, seq_len, n_kv, 1, h_d]
k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)

# Expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
if self.num_heads != self.num_kv_heads:
k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)

# [bsz, s, n_h, h_d]
k = k.reshape(bsz, seq_len, -1, self.head_dim)
v = v.reshape(bsz, seq_len, -1, self.head_dim)
k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
v = v.reshape(bsz, -1, self.num_heads, self.head_dim)

# [bsz, n_h, s, h_d]
q = q.transpose(1, 2)
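
For reference, the view → expand → reshape sequence in the forward above is the usual grouped-query-attention broadcast: each key/value head is repeated q_per_kv times so that k and v line up with the query heads, and the diff switches the sequence dimension to -1, presumably to keep it symbolic for dynamic-shape export. A standalone toy sketch (sizes are illustrative, not from this PR):

import torch

# Toy sizes: 2 KV heads, 4 query heads per KV head -> 8 query heads total.
bsz, seq_len, n_kv, q_per_kv, h_d = 2, 10, 2, 4, 8
k = torch.randn(bsz, seq_len, n_kv, h_d)

k = k.view(bsz, -1, n_kv, 1, h_d)             # [b, s, n_kv, 1, h_d]
k = k.expand(bsz, -1, n_kv, q_per_kv, h_d)    # broadcast across the group dim (no copy yet)
k = k.reshape(bsz, -1, n_kv * q_per_kv, h_d)  # [b, s, n_h, h_d]; reshape materializes the copy
assert k.shape == (2, 10, 8, 8)
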
130 changes: 130 additions & 0 deletions extension/llm/modules/kv_cache.py
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torchtune.modules.kv_cache import KVCache as TuneKVCache


class KVCache(TuneKVCache):
"""
An export-friendly KVCache implementation adapted from the torchtune KVCache:
https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py
Unlike the torchtune version, it supports both transposed and un-transposed cache layouts.
Standalone ``nn.Module`` containing a kv-cache to cache past keys and values during inference.

Args:
batch_size (int): batch size model will be run with
max_seq_len (int): maximum sequence length model will be run with
num_kv_heads (int): number of key/value heads.
head_dim (int): per-attention head embedding dimension
dtype (torch.dtype): dtype for the caches
transpose_cache (bool): whether the cache is stored in the transposed [B, H, S, D] layout (default) or in [B, S, H, D] when set to False.
"""

def __init__(
self,
batch_size: int,
max_seq_len: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
transpose_cache: bool = True,
) -> None:
super().__init__(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
)
self.transpose_cache = transpose_cache
self.max_seq_len = max_seq_len
if self.transpose_cache:
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
else:
cache_shape = (batch_size, max_seq_len, num_kv_heads, head_dim)

self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
self.register_buffer(
"cache_pos", torch.arange(0, self.max_seq_len), persistent=False
)
self.batch_size = batch_size

def update(
self, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.

Note:
When updating the KV cache, it is assumed that subsequent updates should update key-value
positions in consecutive sequence positions. If you wish to update cache values which have
already been filled, use ``.reset()``, which will reset the cache to the zero-th position.

Example:
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
>>> cache.update(keys, values)
>>> # now positions 0 through 7 are filled
>>> cache.size
8
>>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
>>> cache.update(keys, values)
>>> # this will fill at position 8
>>> cache.size
9

Args:
k_val (torch.Tensor): Current key tensor with shape [B, H, S, D] ([B, S, H, D] when ``transpose_cache=False``)
v_val (torch.Tensor): Current value tensor with shape [B, H, S, D] ([B, S, H, D] when ``transpose_cache=False``)

Returns:
Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.

Raises:
AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
used during cache setup.
"""
if self.transpose_cache:
bsz, _, seq_len, _ = k_val.shape
else:
bsz, seq_len, _, _ = k_val.shape
if bsz > self.k_cache.shape[0]:
raise ValueError(
f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}"
f", but found new key tensors with batch size {k_val.shape[0]}!"
)

assert (
self.cache_pos[0] + seq_len
) <= self.max_seq_len, f"self.cache_pos[0]: {self.cache_pos[0]} + seq_len: {seq_len} > self.max_seq_len: {self.max_seq_len}"
k_out = self.k_cache
v_out = self.v_cache

if self.transpose_cache:
k_out[:, :, self.cache_pos[:seq_len]] = k_val
v_out[:, :, self.cache_pos[:seq_len]] = v_val
else:
k_out[:, self.cache_pos[:seq_len]] = k_val
v_out[:, self.cache_pos[:seq_len]] = v_val

# forward cache_pos seq_len positions along
# cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
# an update of seq_len = 5 tokens brings it to
# (5, 6, 7, 8, 9, ...)
# this allows us to track the current position in the cache
# after the last update in a compile-friendly way without any dynamism
# e.g. relying on an int size tracker, or re-creating cache_pos every time
self.cache_pos.add_(seq_len)

return k_out, v_out
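
A minimal eager-mode sketch of the un-transposed layout this file adds (assuming executorch with this change and torchtune are installed; sizes are illustrative): with transpose_cache=False the cache buffers and the update() inputs use the [B, S, H, D] layout instead of the default [B, H, S, D].

import torch
from executorch.extension.llm.modules.kv_cache import KVCache

# Illustrative sizes only.
cache = KVCache(
    batch_size=1,
    max_seq_len=16,
    num_kv_heads=4,
    head_dim=32,
    dtype=torch.float32,
    transpose_cache=False,
)

k = torch.ones(1, 8, 4, 32)  # [B, S, H, D] because transpose_cache=False
v = torch.ones(1, 8, 4, 32)
k_out, v_out = cache.update(k, v)  # fills positions 0..7; cache_pos advances to 8
print(k_out.shape)  # torch.Size([1, 16, 4, 32])
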
@@ -9,7 +9,7 @@
import torch
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.llm.modules.mha import (
from executorch.extension.llm.modules.attention import (
MultiHeadAttention as ETMultiHeadAttention,
)
from executorch.runtime import Runtime
@@ -82,10 +82,12 @@ def setUp(self):
# Common inputs.
seq_len = 10
self.x = torch.randn(1, seq_len, self.embed_dim)
self.input_pos = torch.arange(seq_len).unsqueeze(0) # shape [1, seq_len]
seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
self.dynamic_shapes = (
{0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
{0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
{0: torch.export.Dim.STATIC, 1: seq_len_dim},
)

def test_attention_eager(self):
@@ -94,25 +96,46 @@ def test_attention_eager(self):

self.assertTrue(torch.allclose(et_res, tt_res))

# TODO: KV cache.
# self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
# self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)

# et_res = self.et_mha(self.x, self.x) # Self attention.
# tt_res = self.tt_mha(self.x, self.x) # Self attention.
et_res = self.et_mha(self.x, self.x) # Self attention.
tt_res = self.tt_mha(self.x, self.x) # Self attention.

self.assertTrue(torch.allclose(et_res, tt_res))
self.et_mha.reset_cache()
self.tt_mha.reset_cache()

# self.assertTrue(torch.allclose(et_res, tt_res))
et_res = self.et_mha(
self.x, self.x, input_pos=self.input_pos
) # Self attention with input pos.
tt_res = self.tt_mha(
self.x, self.x, input_pos=self.input_pos
) # Self attention with input pos.

self.assertTrue(torch.allclose(et_res, tt_res))

# test kv cache read. Input pos can be [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
et_res = self.et_mha(
self.x, self.x, input_pos=next_input_pos
) # Self attention with input pos.
tt_res = self.tt_mha(
self.x, self.x, input_pos=next_input_pos
) # Self attention with input pos.
self.assertTrue(torch.allclose(et_res, tt_res))

def test_attention_export(self):
# Self attention.
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
kwargs=None,
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
)
et_res = et_mha_ep.module()(self.x, self.x)
tt_res = self.tt_mha(self.x, self.x)
et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
self.assertTrue(torch.allclose(et_res, tt_res))

# TODO: KV cache.
@@ -126,7 +149,7 @@ def test_attention_executorch(self):
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
kwargs=None,
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
)
et_program = to_edge(
Expand All @@ -136,8 +159,8 @@ def test_attention_executorch(self):
runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
et_res = method.execute((self.x, self.x))
tt_res = self.tt_mha(self.x, self.x)
et_res = method.execute((self.x, self.x, self.input_pos))
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
