
Add llama model to examples #473


Closed
wants to merge 2 commits
1 change: 1 addition & 0 deletions examples/models/__init__.py
@@ -16,6 +16,7 @@
"emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
"emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
"llama": ("llama", "LlamaModel"),
"mobilebert": ("mobilebert", "MobileBertModelExample"),
"mv2": ("mobilenet_v2", "MV2Model"),
"mv3": ("mobilenet_v3", "MV3Model"),
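For context, the entry added above maps the model name "llama" to a (module, class) pair in the examples.models registry. The sketch below shows one way such an entry could be resolved at export time; it is a minimal illustration, and the MODEL_NAME_TO_MODEL name and the load_model_class helper are assumptions for this example, not necessarily the repository's actual loader.

import importlib

# Hypothetical registry shaped like the dict in examples/models/__init__.py.
MODEL_NAME_TO_MODEL = {
    "llama": ("llama", "LlamaModel"),
}

def load_model_class(model_name: str):
    # Resolve ("llama", "LlamaModel") to examples.models.llama.LlamaModel.
    module_name, class_name = MODEL_NAME_TO_MODEL[model_name]
    module = importlib.import_module(f"examples.models.{module_name}")
    return getattr(module, class_name)

With something like this in place, load_model_class("llama")().get_eager_model() would return the eager Llama Transformer defined in model.py below, assuming the checkpoint described there has already been downloaded.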
11 changes: 11 additions & 0 deletions examples/models/llama/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import LlamaModel

__all__ = [
    "LlamaModel",
]
246 changes: 246 additions & 0 deletions examples/models/llama/model.py
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This example demonstrates how to export a Llama 2 model in ExecuTorch.
# Llama 2 is licensed under the LLAMA 2 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
#
# Instructions:
# 1. Follow https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md
# to set up ExecuTorch.
# 2. cd examples/third-party/llama
# 3. pip install -e .
# 4. Follow the instructions in https://github.com/facebookresearch/llama to download the model
# 5. Go back to executorch/ root, run
# python3 -m examples.export.export_example --model_name="llama"
#

import json
import math
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from llama.model import ModelArgs, RMSNorm, repeat_kv

from examples.models.model_base import EagerModelBase

# Since ExecuTorch does not support the complex tensor data type, the functions
# below implement rotary positional embeddings using real-valued tensors.
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin
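# Note: freqs_cos and freqs_sin are the real and imaginary parts of the complex
# freqs_cis table used by the reference Llama implementation
# (torch.polar(torch.ones_like(freqs), freqs) == cos(freqs) + 1j * sin(freqs)),
# so apply_rotary_emb below can compute the rotation
# (x_r + i * x_i) * (cos + i * sin) with plain real multiplies and adds.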

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

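        # Causal mask: an upper-triangular block of -inf, so that once added to
        # the attention scores, position i can only attend to positions <= i.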
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

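        # repeat_kv (imported from llama.model) tiles each kv head n_rep times,
        # turning (bs, seqlen, n_local_kv_heads, head_dim) into
        # (bs, seqlen, n_local_heads, head_dim) so queries and keys/values line up.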
        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        assert hasattr(self, "mask")
        scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)
        return output


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
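        # SwiGLU sizing example: with dim=4096 and multiple_of=256 (Llama-2-7B),
        # the caller passes hidden_dim=4*4096=16384, which becomes
        # int(2*16384/3)=10922 and then rounds up to the next multiple of 256,
        # i.e. 11008.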
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        freqs_cos, freqs_sin = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)


    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)

        h = self.norm(h)

        logits = self.output(h)
        return logits



class LlamaModel(EagerModelBase):
    def __init__(self):
        cur_path = Path(__file__).absolute().parent
        # Follow the instructions in https://github.com/facebookresearch/llama to download the model
        ckpt_dir = cur_path.parent.parent / "third-party/llama/llama-2-7b"
        device = "cpu"
        checkpoint = torch.load(Path(ckpt_dir) / "consolidated.00.pth", map_location=device)
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        params["vocab_size"] = 32000
        max_seq_len = 128
        max_batch_size = 1
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        self.model_ = Transformer(model_args)
        self.model_.load_state_dict(checkpoint, strict=False)

    def get_eager_model(self):
        return self.model_

    @staticmethod
    def get_example_inputs():
        return (torch.tensor([[1]]),)
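
Once the Llama 2 checkpoint has been downloaded to examples/third-party/llama/llama-2-7b as described in the header comment, the new entry point can be driven either through the export command from step 5 or directly in eager mode. The snippet below is a minimal eager-mode smoke test added here for illustration; it assumes only the LlamaModel class above and the checkpoint files its constructor already expects, and the printed shape is an expectation derived from the code, not output captured from a run.

import torch

from examples.models.llama import LlamaModel

# Builds the Transformer and loads consolidated.00.pth; this requires the
# checkpoint and params.json at the path hard-coded in __init__ above.
wrapper = LlamaModel()
model = wrapper.get_eager_model()
(tokens,) = LlamaModel.get_example_inputs()  # tensor([[1]]), a single token id

with torch.no_grad():
    logits = model(tokens)

# One logit vector per input position over the 32000-entry vocabulary.
print(logits.shape)  # expected: torch.Size([1, 1, 32000])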