
Commit 9a03603

Martin Yuan authored and facebook-github-bot committed
Add llama model to examples (#559)
Summary: Pull Request resolved: #559

Test Plan: Imported from OSS

Reviewed By: guangy10

Differential Revision: D49734019

Pulled By: iseeyuan

fbshipit-source-id: 293b08e8ae7d0a3823ae1485f2bf635f02123951
1 parent c85b683 commit 9a03603

File tree: 7 files changed, +266 -0 lines changed

.ci/scripts/test.sh
examples/models/__init__.py
examples/models/llama2/README.md
examples/models/llama2/__init__.py
examples/models/llama2/demo_config.json
examples/models/llama2/demo_rand_params.pth
examples/models/llama2/model.py


.ci/scripts/test.sh

Lines changed: 6 additions & 0 deletions
```diff
@@ -53,6 +53,12 @@ build_cmake_executor_runner() {
 }

 test_model() {
+  if [[ "${MODEL_NAME}" == "llama2" ]]; then
+    cd examples/third-party/llama
+    pip install -e .
+    cd ../../..
+  fi
+
   "${PYTHON_EXECUTABLE}" -m examples.export.export_example --model_name="${MODEL_NAME}"

   # Run test model
```

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@
     "emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
     "emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
+    "llama2": ("llama2", "Llama2Model"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),
     "mv3": ("mobilenet_v3", "MV3Model"),
```

examples/models/llama2/README.md

Lines changed: 24 additions & 0 deletions
# Summary
This example demonstrates how to export a Llama 2 model to ExecuTorch.
For details on Llama 2, please refer to [the llama GitHub page](https://github.com/facebookresearch/llama).
Pretrained parameters are not included in this repo. Users can download them from [the llama download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).

# Notes
1. This example shows the feasibility of exporting a Llama 2 model to ExecuTorch. There is no guarantee of performance.
2. It targets a model size reasonable for edge devices. Depending on the model size, the memory usage of exporting the model can be high. TODO: improve memory usage in the EXIR emitter.
3. The provided checkpoint, `demo_rand_params.pth`, is a dummy checkpoint with random parameters. It does not produce meaningful results; it exists only for demonstration and fast iteration.

# Limitations
This example reuses the original Python code, with modifications to make it compatible with current ExecuTorch:
1. Since ExecuTorch does not support the complex Tensor data type, customized functions implement rotary embedding with real numbers; see the sketch after this list. TODO: support the complex Tensor data type in ExecuTorch.
2. No KV cache. The cache implementation in the original Llama 2 repo is not supported by ExecuTorch, because the ExecuTorch runtime assumes model data attributes are static. TODO: add support for mutable buffers in ExecuTorch.
3. No CUDA. ExecuTorch focuses on edge use cases, where CUDA is not available on most devices.
4. No dependency on fairscale. ColumnParallelLinear, ParallelEmbedding, and training are neither needed nor supported in ExecuTorch.

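As an illustration of limitation 1, here is a minimal, self-contained sketch (illustrative, not part of this commit) showing that the real-valued (cos, sin) rotation used in `model.py` agrees with the complex-number rotary embedding it replaces:

```python
import torch

def rotate_real(x, cos, sin):
    # Treat consecutive feature pairs as (real, imag) parts and rotate them
    # with explicit cos/sin, mirroring apply_rotary_emb in model.py.
    x_r, x_i = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    out_r = x_r * cos - x_i * sin
    out_i = x_r * sin + x_i * cos
    return torch.stack([out_r, out_i], dim=-1).flatten(-2)

dim, seqlen, theta = 8, 4, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seqlen).float(), freqs)
x = torch.randn(seqlen, dim)

# Reference: the original complex formulation, x * e^(i * angle).
x_c = torch.view_as_complex(x.reshape(seqlen, -1, 2))
ref = torch.view_as_real(x_c * torch.polar(torch.ones_like(angles), angles)).flatten(-2)

assert torch.allclose(rotate_real(x, angles.cos(), angles.sin()), ref, atol=1e-6)
```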
# Instructions:
1. Follow the [tutorial](https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md) to set up ExecuTorch.
2. `cd examples/third-party/llama`
3. `pip install -e .`
4. Go back to the `executorch` root and run `python3 -m examples.export.export_example --model_name="llama2"`. The exported program, `llama2.pte`, will be saved in the current directory. A programmatic equivalent is sketched after this list.
5. Use the `executor_runner` (build instructions in step 1) to load and run `llama2.pte`: `executor_runner --model_path llama2.pte`.
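For step 4, a sketch of driving the same export from Python instead of the shell. It assumes `examples.export.export_example` parses `sys.argv` when run as `__main__`; treat this as illustrative only:

```python
import runpy
import sys

# Mimic: python3 -m examples.export.export_example --model_name="llama2"
sys.argv = ["export_example", "--model_name", "llama2"]
runpy.run_module("examples.export.export_example", run_name="__main__")
```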

examples/models/llama2/__init__.py

Lines changed: 11 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama2Model

__all__ = [
    "Llama2Model",  # __all__ entries must be strings, not the class object
]
```
examples/models/llama2/demo_config.json

Lines changed: 1 addition & 0 deletions

{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 5, "norm_eps": 1e-05, "vocab_size": 512}

examples/models/llama2/demo_rand_params.pth

1.53 MB (binary file not shown)
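To make the demo config concrete, a small sketch (illustrative; it assumes the third-party `llama` package from the README's steps 2 and 3 is installed) that loads these hyperparameters into `ModelArgs` the same way `model.py` below does, then counts the demo model's parameters:

```python
import json

from examples.models.llama2.model import Transformer
from llama.model import ModelArgs

with open("examples/models/llama2/demo_config.json") as f:
    params = json.load(f)

# Mirror model.py: fix the sequence/batch limits, take the rest from the JSON.
args = ModelArgs(max_seq_len=128, max_batch_size=1, **params)
model = Transformer(args)
print(sum(p.numel() for p in model.parameters()))  # roughly 0.3M parameters
```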

examples/models/llama2/model.py

Lines changed: 223 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Please refer to README.md in the same folder for more information.


import json
import math
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from examples.models.model_base import EagerModelBase

from llama.model import ModelArgs, repeat_kv, RMSNorm
from torch import nn


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        assert hasattr(self, "mask")
        scores = (
            scores + self.mask[:, :, :seqlen, :seqlen]
        )  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)
        return output


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        freqs_cos, freqs_sin = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)

        h = self.norm(h)

        logits = self.output(h)
        return logits


class Llama2Model(EagerModelBase):
    def __init__(self):
        ckpt_dir = Path(__file__).absolute().parent
        # The example uses a dummy small model with random weights, for demo purposes only.
        # Follow the instructions in https://github.com/facebookresearch/llama to download the model.
        device = "cpu"
        checkpoint = torch.load(
            Path(ckpt_dir) / "demo_rand_params.pth", map_location=device
        )
        with open(Path(ckpt_dir) / "demo_config.json", "r") as f:
            params = json.loads(f.read())
        max_seq_len = 128
        max_batch_size = 1
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        self.model_ = Transformer(model_args)
        self.model_.load_state_dict(checkpoint, strict=False)

    def get_eager_model(self):
        return self.model_

    @staticmethod
    def get_example_inputs():
        return (torch.tensor([[1]]),)
```
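Finally, a quick eager-mode smoke test (illustrative, not part of the commit) that exercises only the entry points defined above; with the demo config, `vocab_size` is 512:

```python
import torch

from examples.models.llama2 import Llama2Model

wrapper = Llama2Model()                       # loads demo_rand_params.pth
model = wrapper.get_eager_model()
(tokens,) = Llama2Model.get_example_inputs()  # a single token id, shape (1, 1)

with torch.no_grad():
    logits = model(tokens)
print(logits.shape)  # torch.Size([1, 1, 512]) for the demo config
```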
