Llama2 model cleanup #5859

Closed · wants to merge 1 commit
73 changes: 73 additions & 0 deletions examples/models/checkpoint.py
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from pathlib import Path
from typing import Any, Dict, Optional


def get_default_model_resource_dir(model_file_path: str) -> Path:
"""
Get the default path to resouce files (which contain files such as the
checkpoint and param files), either:
1. Uses the path from pkg_resources, only works with buck2
2. Uses default path located in examples/models/llama2/params

Expected to be called from with a `model.py` file located in a
`executorch/examples/models/<model_name>` directory.

Args:
model_file_path: The file path to the eager model definition.
For example, `executorch/examples/models/llama2/model.py`,
where `executorch/examples/models/llama2` contains all
the llama2-related files.

Returns:
The path to the resource directory containing checkpoint, params, etc.
"""

    try:
        import pkg_resources

        # 1st way: If we can import this path, we are running with buck2 and
        # all resources can be accessed with pkg_resources.
        # pyre-ignore
        from executorch.examples.models.llama2 import params  # noqa

        # Get the model name from the cwd, assuming that this module is called
        # from a path such as examples/models/<model_name>/model.py.
        model_name = Path(model_file_path).parent.name
        resource_dir = Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except:
        # 2nd way.
        resource_dir = Path(model_file_path).absolute().parent / "params"

    return resource_dir


def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
"""
Get the dtype of the checkpoint, returning "None" if the checkpoint is empty.
"""
dtype = None
if len(checkpoint) > 0:
first_key = next(iter(checkpoint))
first = checkpoint[first_key]
dtype = first.dtype
mismatched_dtypes = [
(key, value.dtype)
for key, value in checkpoint.items()
if value.dtype != dtype
]
if len(mismatched_dtypes) > 0:
raise ValueError(
f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
)
return dtype
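
For context, a minimal usage sketch of the two helpers above, assuming a caller shaped like the llama2 `model.py` further down and the demo file names from `examples/models/llama2/params` (the file name and the print are illustrative, not part of this PR):

import torch

from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

# Resolve examples/models/<model_name>/params (or the buck2 resource path)
# relative to the calling model.py.
resource_dir = get_default_model_resource_dir(__file__)

# Load a single (non-sharded) checkpoint and inspect its dtype; the helper
# returns None for an empty checkpoint and raises ValueError on mixed dtypes.
checkpoint = torch.load(
    resource_dir / "demo_rand_params.pth", map_location="cpu", mmap=True
)
dtype = get_checkpoint_dtype(checkpoint)
print(f"Checkpoint dtype: {dtype}")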
1 change: 1 addition & 0 deletions examples/models/llama2/TARGETS
@@ -46,6 +46,7 @@ runtime.python_library(
"//caffe2:torch",
"//executorch/examples/models:model_base",
"//executorch/examples/models/llama2:llama_transformer",
"//executorch/examples/models:checkpoint",
],
)

84 changes: 29 additions & 55 deletions examples/models/llama2/model.py
Expand Up @@ -8,9 +8,13 @@

import json
import os
from pathlib import Path
from typing import Dict, Tuple

import torch
from executorch.examples.models.checkpoint import (
get_checkpoint_dtype,
get_default_model_resource_dir,
)

from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer

@@ -30,48 +34,31 @@ def convert_to_llama_checkpoint(**kwargs):

class Llama2Model(EagerModelBase):
def __init__(self, **kwargs):
import pkg_resources

# default path to the resource file
# It currently supports 3 ways of specifying the checkpoint location:
# 1. Using default path locates in examples/models/llama2/params
# 2. Passing in the checkpoint path and params via kwargs
# 3. Using the path from pkg_resources, only works with buck2
try:
# The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
# pyre-ignore
from executorch.examples.models.llama2 import params

ckpt_dir = Path(
pkg_resources.resource_filename(
"executorch.examples.models.llama2", "params"
)
)
except:
# The 1st way
ckpt_dir = Path(__file__).absolute().parent / "params"

# Check if checkpoint_dir was provided for a sharded checkpoint.
checkpoint_dir = kwargs.get("checkpoint_dir", None)
resource_dir = get_default_model_resource_dir(__file__)

# Use single checkpoint file.
checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
checkpoint_path = kwargs.get(
"checkpoint", resource_dir / "demo_rand_params.pth"
)
params_path = kwargs.get("params", resource_dir / "demo_config.json")

params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
# Check if checkpoint_dir was provided for a sharded checkpoint.
checkpoint_dir = kwargs.get("checkpoint_dir", None)

self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
self.generate_full_logits = kwargs.get("generate_full_logits", False)
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)

self.max_seq_len = kwargs.get("max_seq_len", 128)
self.args = kwargs.get("args", None)

# The example is using a dummy small model with random weights for demo purpose only.
# Follow the instruction in https://github.com/facebookresearch/llama to download the model
# Follow the instruction in https://github.com/facebookresearch/llama to download the model.
device = "cpu"
# flake8: noqa: TOR102
cps = []
# Load sharded checkpoint.
if checkpoint_dir is not None:
# Load multiple checkpoint; ignore the single path.
checkpoint_path = None
@@ -98,8 +85,11 @@ def __init__(self, **kwargs):
else:
# Do not duplicate layers shared between each checkpoint.
checkpoint[key] = cps[0][key]
# Load single checkpoint.
else:
checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

# If given checkpoint is fairseq, convert to llama checkpoint.
fairseq2_checkpoint = kwargs.get("fairseq2", False)
if fairseq2_checkpoint:
print("Using fairseq2 checkpoint")
@@ -108,12 +98,12 @@ def __init__(self, **kwargs):
# NB: some checkpoint contains a "model" field, which is the actual weights dict
checkpoint = checkpoint["model"]

# Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
if (not fairseq2_checkpoint) and checkpoint.get(
"final_proj.weight", None
) is not None:
print(
raise ValueError(
"""

************************************************************
This looks like a Fairseq2 checkpoint (based on the presence
of `final_proj.weight`.
@@ -125,44 +115,28 @@ def __init__(self, **kwargs):
"""
)

# get checkpoint dtype
self.dtype = None
if len(checkpoint) > 0:
first_key = next(iter(checkpoint))
first = checkpoint[first_key]
self.dtype = first.dtype
mismatched_dtypes = [
(key, value.dtype)
for key, value in checkpoint.items()
if value.dtype != self.dtype
]
if len(mismatched_dtypes) > 0:
print(
f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
)
# Get checkpoint dtype.
self.dtype = get_checkpoint_dtype(checkpoint)

with open(params_path, "r") as f:
params = json.loads(f.read())
output_prune_map = None
if self.output_prune_map_path is not None:
with open(self.output_prune_map_path, "r") as f:
output_prune_map = json.load(f)
# change keys from string to int (json only supports string keys)
# Change keys from string to int (json only supports string keys).
output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
max_seq_len = self.max_seq_len
max_batch_size = 1

model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_seq_len=self.max_seq_len,
max_batch_size=1,
use_kv_cache=self.use_kv_cache,
use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
generate_full_logits=self.generate_full_logits,
output_prune_map=output_prune_map,
enable_dynamic_shape=self.enable_dynamic_shape,
**params,
)
if kwargs.get("fairseq2", False):
print("Using fairseq2 checkpoint")
checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
if kwargs.get("verbose", False):
print("============= weights ================")
print("{key} : {weights.numel()} : {weights.size()}")
@@ -234,13 +208,13 @@ def __init__(self, **kwargs):
print(unexpected)
print("============= /unexpected ================")

# prune the output layer if output_prune_map is provided
# Prune the output layer if output_prune_map is provided
if output_prune_map is not None:
from .source_transformation.prune_output import prune_output_vocab

self.model_ = prune_output_vocab(self.model_, output_prune_map)

def get_eager_model(self):
def get_eager_model(self) -> torch.nn.Module:
if self.dtype:
# convert to the type of the provided checkpoint
# input and output are torch.long, so signature unchanged
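
End to end, the refactored constructor keeps the same keyword interface, so a hedged sketch of how the cleaned-up `Llama2Model` might be driven (the paths are placeholders and the kwarg names are taken from the `kwargs.get(...)` calls visible above):

from executorch.examples.models.llama2.model import Llama2Model

# Placeholder paths; real runs point at a downloaded Llama 2 checkpoint and
# its params.json, or omit both kwargs to fall back to the bundled demo files.
model_wrapper = Llama2Model(
    checkpoint="/path/to/consolidated.00.pth",
    params="/path/to/params.json",
    use_kv_cache=True,
    max_seq_len=128,
)

# get_eager_model() is now annotated to return torch.nn.Module; the model is
# converted to the checkpoint dtype when one was detected.
eager_model = model_wrapper.get_eager_model()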