Let Llama export using buck #1431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed. Wants to merge 1 commit.
110 changes: 99 additions & 11 deletions examples/models/llama2/model.py
@@ -10,22 +10,95 @@

import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from ..model_base import EagerModelBase


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
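
Since the PR inlines RMSNorm instead of importing it from llama.model, a quick sanity check can confirm the inlined layer behaves as its docstrings describe. A minimal sketch, not part of the diff:

# Hypothetical smoke test: with the default all-ones weight, each output
# vector should keep its input shape and have a root-mean-square close to 1.
norm = RMSNorm(dim=8)
x = torch.randn(2, 4, 8)  # (batch, seq_len, dim)
y = norm(x)
assert y.shape == x.shape
print(y.pow(2).mean(-1).sqrt())  # values close to 1.0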


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048
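
ModelArgs is likewise inlined so the export path no longer depends on the llama package. A sketch of how the fields line up with a params-style JSON, mirroring what Llama2Model does further down; the demo values here are assumed, not taken from the diff:

# Hypothetical usage: merge a params dict (as loaded from demo_config.json)
# into ModelArgs, letting unspecified fields keep their defaults.
params = {"dim": 288, "n_layers": 6, "n_heads": 6, "vocab_size": 32000}
args = ModelArgs(max_seq_len=128, max_batch_size=1, **params)
assert args.n_kv_heads is None  # callers typically fall back to n_heads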


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
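
The docstring pins repeat_kv to a single torch.repeat_interleave call; the expand/reshape formulation keeps the operation in plain view ops. An equivalence check, written for illustration rather than taken from the PR's tests:

# Hypothetical check: repeat_kv matches torch.repeat_interleave on dim=2.
x = torch.randn(2, 5, 4, 16)  # (bs, seq_len, n_kv_heads, head_dim)
out = repeat_kv(x, n_rep=3)
assert out.shape == (2, 5, 12, 16)
assert torch.equal(out, torch.repeat_interleave(x, repeats=3, dim=2))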


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # pyre-ignore
    freqs = torch.outer(t, freqs).float()  # pyre-ignore
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin
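
Returning separate cos and sin tables keeps the rotary embeddings in real arithmetic, which is presumably friendlier to the export path than complex tensors. A usage sketch with assumed dimensions:

# Hypothetical usage: RoPE tables for head_dim=64 over a 128-token context;
# each table has shape (end, dim // 2).
freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=128)
assert freqs_cos.shape == (128, 32) and freqs_sin.shape == (128, 32)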
@@ -155,8 +228,6 @@ def forward(self, x, freqs_cos, freqs_sin):


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
@@ -194,14 +265,31 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:

class Llama2Model(EagerModelBase):
    def __init__(self, **kwargs):
        import pkg_resources

        ckpt_dir = Path(__file__).absolute().parent

        # Get the paths to the resource files, falling back to the bundled
        # demo resources when no explicit files are given.
        params_path = (
            Path(ckpt_dir) / kwargs["params"]
            if "params" in kwargs
            else pkg_resources.resource_filename(
                "executorch.examples.portable.scripts", "demo_config.json"
            )
        )
        checkpoint_path = (
            Path(ckpt_dir) / kwargs["checkpoint"]
            if "checkpoint" in kwargs
            else pkg_resources.resource_filename(
                "executorch.examples.portable.scripts", "demo_rand_params.pth"
            )
        )

        # The example uses a dummy small model with random weights for demo purposes only.
        # Follow the instructions at https://github.com/facebookresearch/llama to download the model.
        device = "cpu"
        checkpoint = torch.load(checkpoint_path, map_location=device)
        with open(params_path, "r") as f:
            params = json.loads(f.read())
        max_seq_len = 128
        max_batch_size = 1
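
Given the fallbacks above, the model can be constructed with no arguments to pick up the packaged demo resources, or pointed at real files that sit next to model.py. A hypothetical invocation; the file names here are assumptions, not part of the PR:

# Demo: random weights (demo_rand_params.pth) and demo_config.json.
model = Llama2Model()
# Real weights: both paths are resolved relative to examples/models/llama2/.
model = Llama2Model(checkpoint="llama2.pth", params="params.json")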