Commit e6ed40f

Author: Martin Yuan (committed)
Add llama model to examples
ghstack-source-id: 4ee0956 Pull Request resolved: #473
1 parent c386d4c commit e6ed40f

File tree

3 files changed: +258 −0 lines changed


examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
     "emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
     "emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
+    "llama": ("llama", "LlamaModel"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),
     "mv3": ("mobilenet_v3", "MV3Model"),

examples/models/llama/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import LlamaModel

__all__ = [
    "LlamaModel",
]

examples/models/llama/model.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This example demonstrates how to export a Llama 2 model in ExecuTorch.
# Llama 2 is licensed under the LLAMA 2 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
#
# Instructions:
# 1. Follow https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md
#    to set up ExecuTorch.
# 2. cd examples/third-party/llama
# 3. pip install -e .
# 4. Follow the instructions in https://github.com/facebookresearch/llama to download the model.
# 5. Go back to the executorch/ root and run
#    python3 -m examples.export.export_example --model_name="llama"
#

import logging
import os
from pathlib import Path
import json
import torch
from torch import nn
from typing import Any, Optional, Tuple
import math
import torch.nn.functional as F

from llama.model import ModelArgs, RMSNorm, repeat_kv

from examples.models.model_base import EagerModelBase

# Since ExecuTorch does not support the complex Tensor data type,
# use the following functions to compute rotary embeddings with real numbers.
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # Reshape (seqlen, head_dim // 2) so it broadcasts over (bsz, seqlen, n_heads, head_dim // 2).
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)
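
# --- Hypothetical sanity check (not part of this diff) -------------------------
# The functions above reimplement rotary embeddings with real-valued cos/sin
# tables because ExecuTorch cannot export complex tensors. The sketch below
# compares them against the complex-number formulation used in the reference
# llama repo; apply_rotary_emb_complex is an illustrative helper, not code from
# this commit.
def _check_real_rope_matches_complex():
    def apply_rotary_emb_complex(xq, xk, freqs_cis):
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.view(1, freqs_cis.shape[0], 1, freqs_cis.shape[1])
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)

    bsz, seqlen, n_heads, head_dim = 2, 8, 4, 16
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)
    freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seqlen)
    freqs_cis = torch.polar(torch.ones_like(freqs_cos), torch.atan2(freqs_sin, freqs_cos))
    real_q, real_k = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
    ref_q, ref_k = apply_rotary_emb_complex(xq, xk, freqs_cis)
    assert torch.allclose(real_q, ref_q, atol=1e-5)
    assert torch.allclose(real_k, ref_k, atol=1e-5)
# --------------------------------------------------------------------------------
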
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        assert hasattr(self, 'mask')
        scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)
        return output
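
# --- Hypothetical smoke test (not part of this diff) ---------------------------
# Run the Attention block on a tiny configuration. SimpleNamespace stands in for
# llama's ModelArgs here, since Attention only reads dim, n_heads, n_kv_heads and
# max_seq_len; the shapes follow the comments in forward() above.
def _attention_smoke_test():
    from types import SimpleNamespace

    tiny_args = SimpleNamespace(dim=32, n_heads=4, n_kv_heads=2, max_seq_len=16)
    attn = Attention(tiny_args)
    x = torch.randn(1, 8, 32)                    # (bsz, seqlen, dim)
    cos, sin = precompute_freqs_cis(32 // 4, 8)  # head_dim=8, seqlen=8
    out = attn(x, cos, sin)
    assert out.shape == (1, 8, 32)               # same shape as the input
# --------------------------------------------------------------------------------
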
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
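
# --- Worked example (not part of this diff) -------------------------------------
# FeedForward shrinks the nominal 4*dim hidden size to 2/3 and rounds it up to a
# multiple of `multiple_of` (SwiGLU uses three matmuls instead of two). For a 7B
# style config with dim=4096 and multiple_of=256:
#   int(2 * (4 * 4096) / 3) = 10922  ->  256 * ceil(10922 / 256) = 11008,
# which matches the familiar 11008 intermediate size of Llama-2-7B.
# --------------------------------------------------------------------------------
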
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        # h = self.layers[0](h, freqs_cos, freqs_sin)  # myuan: hack one layer for debug

        h = self.norm(h)

        logits = self.output(h)
        return logits
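
# --- Hypothetical end-to-end shape check (not part of this diff) ----------------
# Build a tiny Transformer from a SimpleNamespace standing in for ModelArgs and
# confirm that token ids of shape (bsz, seqlen) come out as logits of shape
# (bsz, seqlen, vocab_size).
def _transformer_shape_check():
    from types import SimpleNamespace

    tiny_params = SimpleNamespace(
        dim=32, n_layers=2, n_heads=4, n_kv_heads=4, vocab_size=100,
        multiple_of=8, norm_eps=1e-5, max_seq_len=16,
    )
    tiny_model = Transformer(tiny_params)
    tokens = torch.randint(0, 100, (1, 8))
    assert tiny_model(tokens).shape == (1, 8, 100)
# --------------------------------------------------------------------------------
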
# cur_path = Path().absolute()
# # Follow the instructions in https://github.com/facebookresearch/llama to download the model
# ckpt_dir = cur_path.parent.parent / "third-party/llama/llama-2-7b"
# device = 'cpu'
# checkpoint = torch.load(Path(ckpt_dir) / "consolidated.00.pth", map_location=device)
# with open(Path(ckpt_dir) / "params.json", "r") as f:
#     params = json.loads(f.read())
# params['vocab_size'] = 32000
# max_seq_len = 128
# max_batch_size = 1
# model_args: ModelArgs = ModelArgs(
#     max_seq_len=max_seq_len,
#     max_batch_size=max_batch_size,
#     **params,
# )
# model = Transformer(model_args)
# model.load_state_dict(checkpoint, strict=False)
# x = torch.tensor([[1]])
# y = model.forward(x)
# print(y)


class LlamaModel(EagerModelBase):
    def __init__(self):
        cur_path = Path(__file__).absolute().parent
        # Follow the instructions in https://github.com/facebookresearch/llama to download the model
        ckpt_dir = cur_path.parent.parent / "third-party/llama/llama-2-7b"
        device = 'cpu'
        checkpoint = torch.load(Path(ckpt_dir) / "consolidated.00.pth", map_location=device)
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        params['vocab_size'] = 32000
        max_seq_len = 128
        max_batch_size = 1
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        self.model_ = Transformer(model_args)
        # strict=False: the checkpoint may contain keys this module does not define
        self.model_.load_state_dict(checkpoint, strict=False)

    # @staticmethod
    def get_eager_model(self):
        return self.model_

    @staticmethod
    def get_example_inputs():
        return (torch.tensor([[1]]),)
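
For context, a minimal sketch of how the new "llama" entry is typically consumed once it is registered in examples/models/__init__.py: resolve the (module, class) pair, instantiate the EagerModelBase wrapper, and hand the eager module plus example inputs to the export flow driven by python3 -m examples.export.export_example --model_name="llama". The dynamic-import helper below is an illustration of that lookup under those assumptions, not the exact code in the examples package.

import importlib

import torch

MODEL_NAME_TO_MODEL = {"llama": ("llama", "LlamaModel")}  # the entry added in this commit

module_name, class_name = MODEL_NAME_TO_MODEL["llama"]
model_class = getattr(importlib.import_module(f"examples.models.{module_name}"), class_name)

wrapper = model_class()                        # loads third-party/llama/llama-2-7b weights
eager_model = wrapper.get_eager_model()        # the Transformer nn.Module defined above
example_inputs = wrapper.get_example_inputs()  # (torch.tensor([[1]]),)
with torch.no_grad():
    logits = eager_model(*example_inputs)      # eager sanity check before export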
