
Commit ffc87a3

JacobSzwejbka authored and facebook-github-bot committed
Let Llama export using buck (#1431)
Summary:
Pull Request resolved: #1431

These changes should allow us to iterate on llama2 using internal build systems

Reviewed By: larryliu0820

Differential Revision: D52259757

fbshipit-source-id: 43c76fc045acc8984272beb0fdbfa8ff680339b8
1 parent 0ba14a1 commit ffc87a3
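
As context for the diff below: the commit inlines ModelArgs, RMSNorm, and repeat_kv into examples/models/llama2/model.py instead of importing them from the external `llama` package. The following is a minimal sketch (not part of this commit) of how the now-local definitions could be exercised; the import path is an assumption based on the repository layout and the `executorch.examples.*` package name that appears elsewhere in the diff.

    import torch

    # Hypothetical import path, assuming the repo is importable as the `executorch` package.
    from executorch.examples.models.llama2.model import ModelArgs, RMSNorm

    # Small demo-sized arguments instead of the 4096-dim defaults from the diff.
    args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=32)
    norm = RMSNorm(args.dim, eps=args.norm_eps)

    x = torch.randn(1, 8, args.dim)  # (batch, seq_len, dim)
    print(norm(x).shape)             # torch.Size([1, 8, 64])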

File tree

1 file changed: +99 -11 lines


examples/models/llama2/model.py

Lines changed: 99 additions & 11 deletions
@@ -10,22 +10,95 @@
 
 import json
 import math
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 
-from examples.models.model_base import EagerModelBase
-
-from llama.model import ModelArgs, repeat_kv, RMSNorm
 from torch import nn
 
+from ..model_base import EagerModelBase
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 4096
+    n_layers: int = 32
+    n_heads: int = 32
+    n_kv_heads: Optional[int] = None
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    ffn_dim_multiplier: Optional[float] = None
+    norm_eps: float = 1e-5
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    bs, slen, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )
+
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)
-    freqs = torch.outer(t, freqs).float()
+    t = torch.arange(end, device=freqs.device)  # pyre-ignore
+    freqs = torch.outer(t, freqs).float()  # pyre-ignore
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
     return freqs_cos, freqs_sin
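
The docstring on the inlined repeat_kv states it matches torch.repeat_interleave(x, dim=2, repeats=n_rep). Below is a short, self-contained check of that claim; the expand/reshape body is reproduced from the hunk above and the tensor shapes are arbitrary illustration values, not anything from this commit.

    import torch

    # Example shapes: (bs, seq_len, n_kv_heads, head_dim), with n_rep = 3.
    x = torch.randn(2, 5, 4, 16)
    n_rep = 3

    # Same expand/reshape as repeat_kv above.
    repeated = (
        x[:, :, :, None, :]
        .expand(2, 5, 4, n_rep, 16)
        .reshape(2, 5, 4 * n_rep, 16)
    )

    assert torch.equal(repeated, torch.repeat_interleave(x, repeats=n_rep, dim=2))
    print(repeated.shape)  # torch.Size([2, 5, 12, 16])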
@@ -155,8 +228,6 @@ def forward(self, x, freqs_cos, freqs_sin):
 
 
 class Transformer(nn.Module):
-    last_loss: Optional[torch.Tensor]
-
     def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
@@ -194,14 +265,31 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
 
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
+        import pkg_resources
+
         ckpt_dir = Path(__file__).absolute().parent
+
+        # Get the path to the resource file
+        params_path = (
+            Path(ckpt_dir) / kwargs["checkpoint"]
+            if "checkpoint" in kwargs
+            else pkg_resources.resource_filename(
+                "executorch.examples.portable.scripts", "demo_config.json"
+            )
+        )
+        checkpoint_path = (
+            Path(ckpt_dir) / kwargs["params"]
+            if "params" in kwargs
+            else pkg_resources.resource_filename(
+                "executorch.examples.portable.scripts", "demo_rand_params.pth"
+            )
+        )
+
         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"
-        checkpoint = torch.load(
-            Path(ckpt_dir) / kwargs["checkpoint"], map_location=device
-        )
-        with open(Path(ckpt_dir) / kwargs["params"], "r") as f:
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        with open(params_path, "r") as f:
             params = json.loads(f.read())
         max_seq_len = 128
         max_batch_size = 1
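
The constructor change above prefers paths passed via kwargs and otherwise falls back to files bundled with the `executorch.examples.portable.scripts` package through pkg_resources, which appears to be what lets the demo run from a Buck-built artifact without a locally downloaded checkpoint. Here is a minimal sketch of that fallback pattern with a hypothetical helper name; only the package and resource strings come from the diff.

    import pkg_resources

    def resolve_resource(explicit_path, package, resource):
        # Prefer an explicit path when the caller provides one; otherwise fall back
        # to a file shipped inside an installed (or Buck-packaged) Python package.
        if explicit_path is not None:
            return explicit_path
        return pkg_resources.resource_filename(package, resource)

    # Hypothetical call: with no explicit path, resolve to the packaged demo config.
    config_path = resolve_resource(
        None, "executorch.examples.portable.scripts", "demo_config.json"
    )
    print(config_path)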

0 commit comments
