Commit 875261d

Author: Martin Yuan (committed)
Add llama model to examples
ghstack-source-id: a0d4d7d
Pull Request resolved: #473
1 parent c386d4c commit 875261d

File tree

3 files changed: +248 -0 lines changed

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
     "emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
     "emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
+    "llama": ("llama", "LlamaModel"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),
     "mv3": ("mobilenet_v3", "MV3Model"),

examples/models/llama/__init__.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import LlamaModel

__all__ = [
    "LlamaModel",
]

examples/models/llama/model.py

Lines changed: 236 additions & 0 deletions

@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright © Meta Platforms, Inc. All Rights Reserved.

import json
import logging
import math
import os
from pathlib import Path
from typing import Any, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from llama.model import ModelArgs, RMSNorm, repeat_kv

from examples.models.model_base import EagerModelBase


# Since ExecuTorch does not support the complex tensor data type,
# the following functions implement rotary embedding with real numbers.
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        assert hasattr(self, "mask")
        scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)
        return output


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        freqs_cos, freqs_sin = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        # for layer in self.layers:
        #     h = layer(h, freqs_cos, freqs_sin)
        h = self.layers[0](h, freqs_cos, freqs_sin)  # myuan: hack one layer for debug

        h = self.norm(h)

        logits = self.output(h)
        return logits


# Eager-mode smoke test, kept for reference:
# cur_path = Path().absolute()
# # Follow the instructions in https://github.com/facebookresearch/llama to download the model
# ckpt_dir = cur_path.parent.parent / "third-party/llama/llama-2-7b"
# device = "cpu"
# checkpoint = torch.load(Path(ckpt_dir) / "consolidated.00.pth", map_location=device)
# with open(Path(ckpt_dir) / "params.json", "r") as f:
#     params = json.loads(f.read())
# params["vocab_size"] = 32000
# max_seq_len = 128
# max_batch_size = 1
# model_args: ModelArgs = ModelArgs(
#     max_seq_len=max_seq_len,
#     max_batch_size=max_batch_size,
#     **params,
# )
# model = Transformer(model_args)
# model.load_state_dict(checkpoint, strict=False)
# x = torch.tensor([[1]])
# y = model.forward(x)
# print(y)


class LlamaModel(EagerModelBase):
    def __init__(self):
        cur_path = Path(__file__).absolute().parent
        # Follow the instructions in https://github.com/facebookresearch/llama to download the model
        ckpt_dir = cur_path.parent.parent / "third-party/llama/llama-2-7b"
        device = "cpu"
        checkpoint = torch.load(Path(ckpt_dir) / "consolidated.00.pth", map_location=device)
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        params["vocab_size"] = 32000
        max_seq_len = 128
        max_batch_size = 1
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        self.model_ = Transformer(model_args)
        self.model_.load_state_dict(checkpoint, strict=False)

    def get_eager_model(self):
        return self.model_

    @staticmethod
    def get_example_inputs():
        return (torch.tensor([[1]]),)
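
As a quick sanity check on the real-valued rotary embedding helpers above, the following sketch compares them against the complex-number formulation used in the reference Llama code. It is not part of the commit; it assumes the third-party llama package is importable (as model.py itself requires) and uses made-up tensor sizes.

import torch

from examples.models.llama.model import apply_rotary_emb, precompute_freqs_cis

# Illustrative sizes only; any even head_dim works.
bsz, seqlen, n_heads, head_dim = 1, 8, 4, 16
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

# Real-valued path, as implemented in model.py.
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seqlen)
xq_real, _ = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

# Complex-valued reference path (the formulation ExecuTorch cannot export).
theta = 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seqlen).float(), freqs)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64, (seqlen, head_dim // 2)
xq_c = torch.view_as_complex(xq.float().reshape(bsz, seqlen, n_heads, -1, 2))
xq_ref = torch.view_as_real(xq_c * freqs_cis.view(1, seqlen, 1, -1)).flatten(3)

print(torch.allclose(xq_real, xq_ref, atol=1e-5))  # expected: True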
