Skip to content

Decouple custom ops in llama_transformer.py Part 2/N (#3007) #3061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/models/llama2/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ runtime.python_library(
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama2/custom_ops:llama_custom_ops_aot_lib",
],
)

Expand Down Expand Up @@ -85,6 +84,7 @@ runtime.python_library(
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/examples/models:model_base",
"//executorch/examples/models:models",
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
"//executorch/examples/portable:utils",
"//executorch/exir:lib",
"//executorch/sdk/etrecord:etrecord",
Expand Down
61 changes: 58 additions & 3 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
XnnpackDynamicallyQuantizedPartitioner,
)

from executorch.examples.models.llama2.llama_transformer import Transformer
from executorch.examples.models.llama2.llama_transformer import (
KVCache,
SDPA,
Transformer,
)
from executorch.exir.backend.backend_details import CompileSpec

from executorch.sdk.etrecord import generate_etrecord
Expand Down Expand Up @@ -88,6 +92,58 @@ def materialze_broadcast_of_rope_freq_cis(
return module


class SDPACustom(torch.nn.Module):
def __init__(
self,
kv_cache: KVCache,
mask,
dim: int,
):
super().__init__()
self.kv_cache = kv_cache
self.mask = mask
self.dim = dim

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz,
seqlen,
):
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
self.kv_cache.k_cache,
self.kv_cache.v_cache,
input_pos[-1].item(),
seqlen,
)
return output.view(bsz, seqlen, self.dim)


def _replace_sdpa_with_custom_op(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, SDPA):
setattr(
module,
name,
SDPACustom(child.kv_cache, child.mask, child.dim),
)
else:
_replace_sdpa_with_custom_op(child)


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # noqa

_replace_sdpa_with_custom_op(module)
return module


def quantize(
model: torch.nn.Module,
qmode: str,
Expand Down Expand Up @@ -483,8 +539,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
pass
# TODO: Next diff transforms.append()
transforms.append(replace_sdpa_with_custom_op)

return (
load_llama_model(
Expand Down
53 changes: 0 additions & 53 deletions examples/models/llama2/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,12 @@ def __init__(
self,
kv_cache: KVCache,
mask,
use_sdpa_with_kv_cache_op: bool,
dim: int,
n_rep: int,
):
super().__init__()
self.kv_cache = kv_cache
self.mask = mask
self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
self.dim = dim
self.n_rep = n_rep

Expand All @@ -217,56 +215,6 @@ def forward(
v: torch.Tensor,
bsz,
seqlen,
) -> torch.Tensor:
if not self.use_sdpa_with_kv_cache_op:
return self._forward_default(
input_pos,
q,
k,
v,
bsz,
seqlen,
)
else:
return self._forward_custom(
input_pos,
q,
k,
v,
bsz,
seqlen,
)

def _forward_custom(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz,
seqlen,
):
from .custom_ops import sdpa_with_kv_cache # noqa

output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
self.kv_cache.k_cache,
self.kv_cache.v_cache,
input_pos[-1].item(),
seqlen,
)
return output.view(bsz, seqlen, self.dim)

def _forward_default(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz,
seqlen,
) -> torch.Tensor:
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
Expand Down Expand Up @@ -325,7 +273,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
self.SDPA = SDPA(
self.kv_cache,
self.mask,
args.use_sdpa_with_kv_cache_op,
self.dim,
self.n_rep,
)
Expand Down