Add llama model to examples (#473) #559

Closed
wants to merge 1 commit into from
6 changes: 6 additions & 0 deletions .ci/scripts/test.sh
@@ -53,6 +53,12 @@ build_cmake_executor_runner() {
}

test_model() {
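# The llama2 example imports the `llama` package from the examples/third-party/llama
# submodule, so install it before exporting.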
if [[ "${MODEL_NAME}" == "llama2" ]]; then
cd examples/third-party/llama
pip install -e .
cd ../../..
fi

"${PYTHON_EXECUTABLE}" -m examples.export.export_example --model_name="${MODEL_NAME}"

# Run test model
1 change: 1 addition & 0 deletions examples/models/__init__.py
@@ -16,6 +16,7 @@
"emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
"emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
"llama2": ("llama2", "Llama2Model"),
"mobilebert": ("mobilebert", "MobileBertModelExample"),
"mv2": ("mobilenet_v2", "MV2Model"),
"mv3": ("mobilenet_v3", "MV3Model"),
24 changes: 24 additions & 0 deletions examples/models/llama2/README.md
@@ -0,0 +1,24 @@
# Summary
This example demonstrates how to export a Llama 2 model to ExecuTorch.
For details about Llama 2, please refer to [the Llama GitHub page](https://github.com/facebookresearch/llama).
Pretrained parameters are not included in this repo; they can be downloaded through [the Llama download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).

# Notes
1. This example shows the feasibility of exporting a Llama 2 model to ExecuTorch; it makes no performance guarantees.
2. It targets a model size reasonable for edge devices. Depending on the model size, memory usage during export can be high. TODO: improve memory usage in the EXIR emitter.
3. The provided checkpoint, `demo_rand_params.pth`, is a dummy checkpoint with random parameters. It does not produce meaningful results and is intended only for demonstration and fast iteration.

# Limitations
This example reuses the original Llama 2 Python code, with modifications to make it compatible with current ExecuTorch:
1. Since ExecuTorch does not support the complex Tensor data type, customized functions compute the rotary embedding with real numbers (see the sketch after this list). TODO: support the complex Tensor data type in ExecuTorch.
2. No KV cache. The cache implementation in the original Llama 2 repo is not supported, because the ExecuTorch runtime assumes model data attributes are static. TODO: add support for mutable buffers in ExecuTorch.
3. No CUDA. ExecuTorch focuses on edge use cases, where CUDA is not available on most devices.
4. No dependency on fairscale. ColumnParallelLinear, ParallelEmbedding, and training are neither needed nor supported in ExecuTorch.
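
As a rough illustration of item 1, the sketch below (not part of the example code; it assumes only stock PyTorch) checks that the real-valued rotation used by `apply_rotary_emb` in `model.py` matches complex multiplication by e^{iθ}:

```python
import torch

# One (real, imag) pair and a rotation angle theta.
x = torch.randn(2)
theta = torch.tensor(0.3)

# Reference: complex multiplication by e^{i*theta}.
ref = torch.view_as_real(
    torch.view_as_complex(x.view(1, 2)) * torch.polar(torch.ones(1), theta.view(1))
).flatten()

# Real-valued form used in apply_rotary_emb: (r, i) -> (r*cos - i*sin, r*sin + i*cos).
out = torch.stack(
    [
        x[0] * torch.cos(theta) - x[1] * torch.sin(theta),
        x[0] * torch.sin(theta) + x[1] * torch.cos(theta),
    ]
)

assert torch.allclose(ref, out, atol=1e-6)
```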


# Instructions
1. Follow the [tutorial](https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md) to set up ExecuTorch.
2. `cd examples/third-party/llama`
3. `pip install -e .`
4. Go back to the `executorch` root and run `python3 -m examples.export.export_example --model_name="llama2"`. The exported program, `llama2.pte`, will be saved in the current directory.
5. Use the `executor_runner` (build instructions in step 1) to load and run `llama2.pte`: `executor_runner --model_path llama2.pte`.
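
As a quick sanity check before exporting, the eager model can also be instantiated directly. This is only a sketch; it assumes the steps above have been completed and that it is run from the `executorch` repo root:

```python
import torch

from examples.models.llama2 import Llama2Model

wrapper = Llama2Model()  # loads demo_rand_params.pth and demo_config.json
model = wrapper.get_eager_model()
example_inputs = wrapper.get_example_inputs()  # a single token id of shape (1, 1)

with torch.no_grad():
    logits = model(*example_inputs)  # (batch, seq_len, vocab_size) == (1, 1, 512)
print(logits.shape)
```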
11 changes: 11 additions & 0 deletions examples/models/llama2/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama2Model

__all__ = [
"Llama2Model",
]
1 change: 1 addition & 0 deletions examples/models/llama2/demo_config.json
@@ -0,0 +1 @@
{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 5, "norm_eps": 1e-05, "vocab_size": 512}
Binary file added examples/models/llama2/demo_rand_params.pth
223 changes: 223 additions & 0 deletions examples/models/llama2/model.py
@@ -0,0 +1,223 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Please refer to README.md in the same folder for more information.


import json
import math
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from examples.models.model_base import EagerModelBase

from llama.model import ModelArgs, repeat_kv, RMSNorm
from torch import nn


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
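# Precompute per-position cos/sin tables for the rotary embedding as two real
# tensors, since ExecuTorch does not support complex dtypes (see README.md).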
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(shape)


def apply_rotary_emb(
xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
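# Treat the last dimension of xq/xk as interleaved (real, imag) pairs and rotate
# each pair by its per-position angle, using real arithmetic only.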

xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
model_parallel_size = 1
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask)

def forward(
self,
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
bsz, seqlen, _ = x.shape

# QKV
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# RoPE relative positional embeddings
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)

# make heads into a batch dimension
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, "mask")
scores = (
scores + self.mask[:, :, :seqlen, :seqlen]
) # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)
return output


class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
super().__init__()
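# SwiGLU sizing as in Llama 2: take 2/3 of the incoming hidden_dim (4 * dim),
# then round up to a multiple of multiple_of.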
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)

def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

def forward(self, x, freqs_cos, freqs_sin):
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out


class Transformer(nn.Module):
last_loss: Optional[torch.Tensor]

def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers

self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

freqs_cos, freqs_sin = precompute_freqs_cis(
self.params.dim // self.params.n_heads, self.params.max_seq_len
)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)

def forward(self, tokens: torch.Tensor) -> torch.Tensor:
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]

for layer in self.layers:
h = layer(h, freqs_cos, freqs_sin)

h = self.norm(h)

logits = self.output(h)
return logits


class Llama2Model(EagerModelBase):
def __init__(self):
ckpt_dir = Path(__file__).absolute().parent
# The example is using a dummy small model with random weights for demo purpose only.
# Follow the instruction in https://github.com/facebookresearch/llama to download the model
device = "cpu"
checkpoint = torch.load(
Path(ckpt_dir) / "demo_rand_params.pth", map_location=device
)
with open(Path(ckpt_dir) / "demo_config.json", "r") as f:
params = json.loads(f.read())
max_seq_len = 128
max_batch_size = 1
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
self.model_ = Transformer(model_args)
self.model_.load_state_dict(
checkpoint, strict=False
)  # strict=False: tolerate keys that do not match between the checkpoint and this model

def get_eager_model(self):
return self.model_

@staticmethod
def get_example_inputs():
return (torch.tensor([[1]]),)