|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This examples demonstrates how to Export a llama 2 model in ExecuTorch. |
| 5 | +# Llama 2 is licensed under the LLAMA 2 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved. |
| 6 | +# |
| 7 | +# Instructions: |
| 8 | +# 1. Follow https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md |
| 9 | +# to set up ExecuTorch. |
| 10 | +# 2. cd examples/third-party/llama |
| 11 | +# 3. pip install -e . |
| 12 | +# 4. Follow the instruction in https://github.com/facebookresearch/llama to download the model |
| 13 | +# 5. Go back to executorch/ root, run |
| 14 | +# python3 -m examples.export.export_example --model_name="llama" |
| 15 | +# |
| 16 | + |
| 17 | +import logging |
| 18 | +import os |
| 19 | +from pathlib import Path |
| 20 | +import json |
| 21 | +import torch |
| 22 | +from torch import nn |
| 23 | +from typing import Any, Optional, Tuple |
| 24 | +import math |
| 25 | +import torch.nn.functional as F |
| 26 | + |
| 27 | +from llama.model import ModelArgs, RMSNorm, repeat_kv |
| 28 | + |
| 29 | +from examples.models.model_base import EagerModelBase |
| 30 | + |
| 31 | +# Since ExecuTorch does not support complex Tensor data type, |
| 32 | +# use the following functions to have rotary embedding with real numbers. |
| 33 | +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): |
| 34 | + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) |
| 35 | + t = torch.arange(end, device=freqs.device) |
| 36 | + freqs = torch.outer(t, freqs).float() |
| 37 | + freqs_cos = torch.cos(freqs) |
| 38 | + freqs_sin = torch.sin(freqs) |
| 39 | + return freqs_cos, freqs_sin |
| 40 | + |
| 41 | +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): |
| 42 | + ndim = x.ndim |
| 43 | + assert 0 <= 1 < ndim |
| 44 | + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) |
| 45 | + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] |
| 46 | + return freqs_cis.view(shape) |
| 47 | + |
| 48 | +def apply_rotary_emb( |
| 49 | + xq: torch.Tensor, |
| 50 | + xk: torch.Tensor, |
| 51 | + freqs_cos: torch.Tensor, |
| 52 | + freqs_sin: torch.Tensor |
| 53 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 54 | + |
| 55 | + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) |
| 56 | + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) |
| 57 | + |
| 58 | + freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) |
| 59 | + freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) |
| 60 | + |
| 61 | + xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin |
| 62 | + xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos |
| 63 | + xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin |
| 64 | + xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos |
| 65 | + |
| 66 | + xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) |
| 67 | + xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) |
| 68 | + |
| 69 | + return xq_out.type_as(xq), xk_out.type_as(xk) |
| 70 | + |
| 71 | +class Attention(nn.Module): |
| 72 | + def __init__(self, args: ModelArgs): |
| 73 | + super().__init__() |
| 74 | + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads |
| 75 | + assert args.n_heads % self.n_kv_heads == 0 |
| 76 | + model_parallel_size = 1 |
| 77 | + self.n_local_heads = args.n_heads // model_parallel_size |
| 78 | + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size |
| 79 | + self.n_rep = self.n_local_heads // self.n_local_kv_heads |
| 80 | + self.head_dim = args.dim // args.n_heads |
| 81 | + self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) |
| 82 | + self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) |
| 83 | + self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) |
| 84 | + self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) |
| 85 | + |
| 86 | + mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) |
| 87 | + mask = torch.triu(mask, diagonal=1) |
| 88 | + self.register_buffer("mask", mask) |
| 89 | + |
| 90 | + def forward( |
| 91 | + self, |
| 92 | + x: torch.Tensor, |
| 93 | + freqs_cos: torch.Tensor, |
| 94 | + freqs_sin: torch.Tensor, |
| 95 | + ): |
| 96 | + bsz, seqlen, _ = x.shape |
| 97 | + |
| 98 | + # QKV |
| 99 | + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) |
| 100 | + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 101 | + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) |
| 102 | + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) |
| 103 | + |
| 104 | + # RoPE relative positional embeddings |
| 105 | + xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) |
| 106 | + |
| 107 | + # grouped multiquery attention: expand out keys and values |
| 108 | + xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) |
| 109 | + xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) |
| 110 | + |
| 111 | + # make heads into a batch dimension |
| 112 | + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) |
| 113 | + xk = xk.transpose(1, 2) |
| 114 | + xv = xv.transpose(1, 2) |
| 115 | + |
| 116 | + scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) |
| 117 | + assert hasattr(self, 'mask') |
| 118 | + scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) |
| 119 | + scores = F.softmax(scores.float(), dim=-1).type_as(xq) |
| 120 | + output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) |
| 121 | + |
| 122 | + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
| 123 | + |
| 124 | + output = self.wo(output) |
| 125 | + return output |
| 126 | + |
| 127 | + |
| 128 | +class FeedForward(nn.Module): |
| 129 | + def __init__(self, dim: int, hidden_dim: int, multiple_of: int): |
| 130 | + super().__init__() |
| 131 | + hidden_dim = int(2 * hidden_dim / 3) |
| 132 | + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
| 133 | + self.w1 = nn.Linear(dim, hidden_dim, bias=False) |
| 134 | + self.w2 = nn.Linear(hidden_dim, dim, bias=False) |
| 135 | + self.w3 = nn.Linear(dim, hidden_dim, bias=False) |
| 136 | + |
| 137 | + def forward(self, x): |
| 138 | + return self.w2(F.silu(self.w1(x)) * self.w3(x)) |
| 139 | + |
| 140 | +class TransformerBlock(nn.Module): |
| 141 | + def __init__(self, layer_id: int, args: ModelArgs): |
| 142 | + super().__init__() |
| 143 | + self.n_heads = args.n_heads |
| 144 | + self.dim = args.dim |
| 145 | + self.head_dim = args.dim // args.n_heads |
| 146 | + self.attention = Attention(args) |
| 147 | + self.feed_forward = FeedForward( |
| 148 | + dim=args.dim, |
| 149 | + hidden_dim=4 * args.dim, |
| 150 | + multiple_of=args.multiple_of, |
| 151 | + ) |
| 152 | + self.layer_id = layer_id |
| 153 | + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) |
| 154 | + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) |
| 155 | + |
| 156 | + def forward(self, x, freqs_cos, freqs_sin): |
| 157 | + h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin) |
| 158 | + out = h + self.feed_forward.forward(self.ffn_norm(h)) |
| 159 | + return out |
| 160 | + |
| 161 | + |
| 162 | +class Transformer(nn.Module): |
| 163 | + last_loss: Optional[torch.Tensor] |
| 164 | + |
| 165 | + def __init__(self, params: ModelArgs): |
| 166 | + super().__init__() |
| 167 | + self.params = params |
| 168 | + self.vocab_size = params.vocab_size |
| 169 | + self.n_layers = params.n_layers |
| 170 | + |
| 171 | + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) |
| 172 | + self.layers = torch.nn.ModuleList() |
| 173 | + for layer_id in range(params.n_layers): |
| 174 | + self.layers.append(TransformerBlock(layer_id, params)) |
| 175 | + self.norm = RMSNorm(params.dim, eps=params.norm_eps) |
| 176 | + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) |
| 177 | + |
| 178 | + freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len) |
| 179 | + self.register_buffer("freqs_cos", freqs_cos, persistent=False) |
| 180 | + self.register_buffer("freqs_sin", freqs_sin, persistent=False) |
| 181 | + |
| 182 | + |
| 183 | + def forward(self, tokens: torch.Tensor) -> torch.Tensor: |
| 184 | + _bsz, seqlen = tokens.shape |
| 185 | + h = self.tok_embeddings(tokens) |
| 186 | + freqs_cos = self.freqs_cos[:seqlen] |
| 187 | + freqs_sin = self.freqs_sin[:seqlen] |
| 188 | + |
| 189 | + for layer in self.layers: |
| 190 | + h = layer(h, freqs_cos, freqs_sin) |
| 191 | + # h = self.layers[0](h, freqs_cos, freqs_sin) # myuan: hack one layer for debug |
| 192 | + |
| 193 | + h = self.norm(h) |
| 194 | + |
| 195 | + logits = self.output(h) |
| 196 | + return logits |
| 197 | + |
| 198 | +# cur_path = Path().absolute() |
| 199 | +# # Follow the instruction in https://github.com/facebookresearch/llama to download the model |
| 200 | +# ckpt_dir = cur_path.parent.parent / "third-party/llama/llama-2-7b" |
| 201 | +# device = 'cpu' |
| 202 | +# checkpoint = torch.load(Path(ckpt_dir) / "consolidated.00.pth", map_location=device) |
| 203 | +# with open(Path(ckpt_dir) / "params.json", "r") as f: |
| 204 | +# params = json.loads(f.read()) |
| 205 | +# params['vocab_size'] = 32000 |
| 206 | +# max_seq_len = 128 |
| 207 | +# max_batch_size = 1 |
| 208 | +# model_args: ModelArgs = ModelArgs( |
| 209 | +# max_seq_len=max_seq_len, |
| 210 | +# max_batch_size=max_batch_size, |
| 211 | +# **params, |
| 212 | +# ) |
| 213 | +# model = Transformer(model_args) |
| 214 | +# model.load_state_dict(checkpoint, strict=False) |
| 215 | +# x = torch.tensor([[1]]) |
| 216 | +# y = model.forward(x) |
| 217 | +# print(y) |
| 218 | + |
| 219 | + |
| 220 | +class LlamaModel(EagerModelBase): |
| 221 | + def __init__(self): |
| 222 | + cur_path = Path(__file__).absolute().parent |
| 223 | + # Follow the instruction in https://github.com/facebookresearch/llama to download the model |
| 224 | + ckpt_dir = cur_path.parent.parent / "third-party/llama/llama-2-7b" |
| 225 | + device = 'cpu' |
| 226 | + checkpoint = torch.load(Path(ckpt_dir) / "consolidated.00.pth", map_location=device) |
| 227 | + with open(Path(ckpt_dir) / "params.json", "r") as f: |
| 228 | + params = json.loads(f.read()) |
| 229 | + params['vocab_size'] = 32000 |
| 230 | + max_seq_len = 128 |
| 231 | + max_batch_size = 1 |
| 232 | + model_args: ModelArgs = ModelArgs( |
| 233 | + max_seq_len=max_seq_len, |
| 234 | + max_batch_size=max_batch_size, |
| 235 | + **params, |
| 236 | + ) |
| 237 | + self.model_ = Transformer(model_args) |
| 238 | + self.model_.load_state_dict(checkpoint, strict=False) # self.model_ = Transformer(gptconf) |
| 239 | + |
| 240 | + # @staticmethod |
| 241 | + def get_eager_model(self): |
| 242 | + return self.model_ |
| 243 | + |
| 244 | + @staticmethod |
| 245 | + def get_example_inputs(): |
| 246 | + return (torch.tensor([[1]]),) |
0 commit comments