[DRAFT] Changes to native runner to run TorchTune Llama #6075


Closed
wants to merge 37 commits into from
Changes from all commits

Commits (37)
7f81e00
Changes to native runner to run tt
jackzhxng Oct 9, 2024
0b5a9a7
Add kwarg example inputs to eager model base
jackzhxng Sep 30, 2024
a9647d2
Create create new method for example kwarg inputs instead
jackzhxng Oct 7, 2024
fa3b1d2
Add kwarg example inputs to eager model base
jackzhxng Sep 30, 2024
e8715ba
Lint
jackzhxng Oct 8, 2024
a6f96a2
Accept model type parameter in export_llama
jackzhxng Oct 5, 2024
328c72c
Remove future implementation
jackzhxng Oct 5, 2024
ec80bba
Lint
jackzhxng Oct 15, 2024
c9bbe12
Create create new method for example kwarg inputs instead
jackzhxng Oct 7, 2024
99d5bfb
Accept model type parameter in export_llama
jackzhxng Oct 5, 2024
1fb2236
Torchtune llama3_2_vision model in ET, no quantization
jackzhxng Oct 5, 2024
e0c4b8a
Fix vision model example input
jackzhxng Oct 8, 2024
e145bd1
Lint
jackzhxng Oct 22, 2024
ed906cb
Kv cache
jackzhxng Oct 25, 2024
6dd47e7
Merge branch 'main' into jz/tt-llama
jackzhxng Oct 25, 2024
1825972
Update READMEs
jackzhxng Oct 25, 2024
196499a
Change model default arg
jackzhxng Oct 25, 2024
96ba40b
Update eager runner and eval llama
jackzhxng Oct 25, 2024
18a82e1
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 25, 2024
0f3035d
Fix tests
jackzhxng Oct 25, 2024
e677e14
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 25, 2024
b1f6678
Fix tests again
jackzhxng Oct 28, 2024
13d004b
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 28, 2024
c79b773
Strict = True
jackzhxng Oct 31, 2024
b8ff8e2
Things work
jackzhxng Oct 31, 2024
25ec7ce
Merge branch 'jz/tt-llama-rebased' into jz/native-runner-tt
jackzhxng Oct 31, 2024
6e38763
Clip logits if torchtune
jackzhxng Oct 31, 2024
7a7041d
Merge branch 'jz/tt-llama-2' into jz/native-runner-tt
jackzhxng Oct 31, 2024
96d5798
Fix
jackzhxng Oct 31, 2024
f275e2e
Kv cache by default is false
jackzhxng Nov 1, 2024
37011d3
Clean up
jackzhxng Nov 1, 2024
de45c48
Strict = True
jackzhxng Oct 31, 2024
2fe7bd8
Merge branch 'main' into jz/tt-llama-2
jackzhxng Nov 13, 2024
64dcbda
Lint
jackzhxng Nov 13, 2024
a89d6b2
Fix merge
jackzhxng Nov 13, 2024
e1ec74c
Merge branch 'jz/tt-llama-2' into jz/native-runner-tt
jackzhxng Nov 13, 2024
84422d9
Fixes
jackzhxng Nov 13, 2024
61 changes: 33 additions & 28 deletions examples/models/llama/export_llama_lib.py
@@ -24,8 +24,6 @@

from executorch.devtools.etrecord import generate_etrecord

from executorch.examples.models.llama.llama_transformer import ModelArgs

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

from executorch.extension.llm.export.partitioner_lib import (
@@ -82,7 +80,7 @@


EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = []
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


class WeightType(Enum):
@@ -138,7 +136,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--model",
default="llama3",
choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
help="The Lllama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
help="The Lllama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
)
parser.add_argument(
"-E",
@@ -815,16 +813,18 @@ def _load_llama_model_metadata(
use_kv_cache: bool,
use_sdpa_with_kv_cache: bool,
enable_dynamic_shape: bool,
model_args: ModelArgs,
max_seq_len: int,
n_layers: int,
vocab_size: int,
metadata_str: Optional[str] = None,
):
is_fairseq2 = weight_type == WeightType.FAIRSEQ2
metadata = {
"get_bos_id": 3 if is_fairseq2 else 1,
"get_eos_ids": [3] if is_fairseq2 else [2],
"get_max_seq_len": model_args.max_seq_len,
"get_n_layers": model_args.n_layers,
"get_vocab_size": model_args.vocab_size,
"get_max_seq_len": max_seq_len,
"get_n_layers": n_layers,
"get_vocab_size": vocab_size,
"use_kv_cache": use_kv_cache,
"use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
"enable_dynamic_shape": enable_dynamic_shape,
@@ -881,27 +881,29 @@ def _load_llama_model(
module_name = "llama"
model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
elif modelname in TORCHTUNE_DEFINED_MODELS:
raise NotImplementedError(
"Torchtune Llama models are not yet supported in ExecuTorch export."
)
if modelname == "llama3_2_vision":
module_name = "llama3_2_vision"
model_class_name = "Llama3_2Decoder"
else:
raise ValueError(f"{modelname} is not a valid Llama model.")

model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
module_name,
model_class_name,
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
params=params_path,
use_kv_cache=use_kv_cache,
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
generate_full_logits=generate_full_logits,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
args=args,
model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
EagerModelFactory.create_model(
module_name,
model_class_name,
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
params=params_path,
use_kv_cache=use_kv_cache,
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
generate_full_logits=generate_full_logits,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
args=args,
)
)
if dtype_override:
assert isinstance(
@@ -933,12 +935,13 @@ def _load_llama_model(
return LLMEdgeManager(
model=model,
modelname=modelname,
max_seq_len=model.params.max_seq_len,
max_seq_len=model.max_seq_len,
dtype=dtype,
use_kv_cache=use_kv_cache,
generate_full_logits=generate_full_logits,
example_inputs=example_inputs,
example_kwarg_inputs=example_kwarg_inputs,
dynamic_shapes=dynamic_shapes,
enable_dynamic_shape=enable_dynamic_shape,
calibration_tasks=calibration_tasks,
calibration_limit=calibration_limit,
@@ -951,7 +954,9 @@
use_kv_cache,
use_sdpa_with_kv_cache,
enable_dynamic_shape,
model.params,
model.max_seq_len,
model.n_layers,
model.vocab_size,
metadata_str,
),
args=args,
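To make the new model routing concrete, here is a small self-contained sketch (not part of the diff; the mapping is copied from the _load_llama_model branch above, and the helper name resolve_model_class is hypothetical) of how a --model value now resolves to a module and model class:

EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


def resolve_model_class(modelname: str) -> tuple:
    # Mirrors the dispatch added to _load_llama_model above.
    if modelname in EXECUTORCH_DEFINED_MODELS:
        return "llama", "Llama2Model"
    if modelname in TORCHTUNE_DEFINED_MODELS and modelname == "llama3_2_vision":
        return "llama3_2_vision", "Llama3_2Decoder"
    raise ValueError(f"{modelname} is not a valid Llama model.")


assert resolve_model_class("llama3_2_vision") == ("llama3_2_vision", "Llama3_2Decoder")
assert resolve_model_class("llama3") == ("llama", "Llama2Model")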
13 changes: 6 additions & 7 deletions examples/models/llama/runner/eager.py
@@ -9,11 +9,12 @@
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
_prepare_for_llama_export,
build_args_parser as _build_args_parser,
TORCHTUNE_DEFINED_MODELS,
)
from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.runner.generation import LlamaRunner
from executorch.extension.llm.export.builder import LLMEdgeManager

@@ -26,15 +27,13 @@ class EagerLlamaRunner(LlamaRunner):
def __init__(self, args):
with open(args.params, "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
super().__init__(
tokenizer_path=args.tokenizer_path,
max_seq_len=args.max_seq_length,
max_batch_size=1,
use_kv_cache=args.use_kv_cache,
**params,
)
super().__init__(
tokenizer_path=args.tokenizer_path,
model_args=model_args,
vocab_size=params["vocab_size"],
has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
device="cuda" if torch.cuda.is_available() else "cpu",
)
manager: LLMEdgeManager = _prepare_for_llama_export(args)
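For context on what the runner now reads: with ModelArgs no longer constructed here, only vocab_size is consumed from the params file. A hypothetical params.json of the shape these runners load (field names other than vocab_size are illustrative assumptions, not taken from this PR):

import json

params = json.loads(
    '{"dim": 4096, "n_layers": 32, "n_heads": 32, "vocab_size": 128256}'
)
# Only vocab_size is passed through to LlamaRunner; the other fields are
# typical of a Llama params file and are assumptions here.
assert params["vocab_size"] == 128256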
56 changes: 46 additions & 10 deletions examples/models/llama/runner/generation.py
@@ -9,7 +9,6 @@

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.extension.llm.tokenizer.utils import get_tokenizer


@@ -47,11 +46,35 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:


class LlamaRunner(ABC):
def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
self.params = model_args
def __init__(
self,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
use_kv_cache: bool,
vocab_size: int,
has_full_logits: bool = False,
device: str = "cpu",
):
"""
Constructor.

Args:
tokenizer_path: path to tokenizer.model file.
max_seq_len: max length of the output sequence, after which the output will be clipped.
max_batch_size: max batch size.
use_kv_cache: whether to use a KV cache.
vocab_size: number of items in the vocab.
has_full_logits: whether the model returns the full logits or only returns the last logit.
device: device to run the runner on.
"""
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.use_kv_cache = use_kv_cache
self.tokenizer = get_tokenizer(tokenizer_path)
assert model_args.vocab_size == self.tokenizer.n_words
self.has_full_logits = has_full_logits
self.device = device
assert vocab_size == self.tokenizer.n_words

@abstractmethod
def forward(
@@ -75,17 +98,22 @@ def generate( # noqa: C901
tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
input_pos=(
torch.tensor([pos_base], dtype=torch.long, device=self.device)
if self.params.use_kv_cache
if self.use_kv_cache
else None
),
)

current_token = next_token(logits, temperature, top_p)
if self.has_full_logits:
current_token = next_token(logits[:, -1, :], temperature, top_p)
else:
current_token = next_token(logits, temperature, top_p)
print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
tokens = prompt_tokens + [current_token]

i = 0
while len(tokens) < max_seq_len:
if self.params.use_kv_cache:
print(f"{i} out of {self.max_seq_len} max tokens generated")
if self.use_kv_cache:
logits = self.forward(
tokens=torch.tensor(
[[current_token]], dtype=torch.long, device=self.device
Expand All @@ -100,13 +128,21 @@ def generate( # noqa: C901
logits = self.forward(
tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
)
current_token = next_token(logits, temperature, top_p)

# If the logits aren't already clipped to only contain the last logit, clip them.
if self.has_full_logits:
current_token = next_token(logits[:, -1, :], temperature, top_p)
else:
current_token = next_token(logits, temperature, top_p)
tokens.append(current_token)

if current_token == self.tokenizer.eos_id or (
hasattr(self.tokenizer, "stop_tokens")
and current_token in self.tokenizer.stop_tokens
):
break

i += 1
print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
print("\n")

@@ -136,7 +172,7 @@ def text_completion(
"""
return self.generate(
prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
max_seq_len=self.params.max_seq_len,
max_seq_len=self.max_seq_len,
temperature=temperature,
top_p=top_p,
echo=echo,
@@ -171,7 +207,7 @@ def chat_completion(
prompt_tokens=self.tokenizer.encode(
self._format_prompt(prompt), bos=True, eos=False
),
max_seq_len=self.params.max_seq_len,
max_seq_len=self.max_seq_len,
temperature=temperature,
top_p=top_p,
echo=True,
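As a usage illustration of the reworked constructor (a sketch under assumptions, not code from this PR; the forward signature is inferred from how generate() calls it above, and ToyRunner is a hypothetical name), a concrete runner now passes plain scalars plus has_full_logits instead of a ModelArgs object:

from typing import Optional

import torch

from executorch.examples.models.llama.runner.generation import LlamaRunner


class ToyRunner(LlamaRunner):
    # Wraps an eager nn.Module that returns full [batch, seq_len, vocab] logits.

    def __init__(self, model: torch.nn.Module, tokenizer_path: str, vocab_size: int):
        super().__init__(
            tokenizer_path=tokenizer_path,
            max_seq_len=128,
            max_batch_size=1,
            use_kv_cache=False,
            vocab_size=vocab_size,
            has_full_logits=True,  # generate() will index logits[:, -1, :]
        )
        self.model = model

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # No KV cache here, so input_pos is unused; return logits for every position.
        return self.model(tokens)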
25 changes: 18 additions & 7 deletions examples/models/llama/runner/native.py
@@ -10,18 +10,22 @@

import torch

from examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.export_llama_lib import (
EXECUTORCH_DEFINED_MODELS,
TORCHTUNE_DEFINED_MODELS,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

from executorch.examples.models.llama.runner.generation import LlamaRunner

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
from executorch.kernels import quantized # noqa

from .generation import LlamaRunner


class NativeLlamaRunner(LlamaRunner):
"""
@@ -31,13 +35,14 @@ class NativeLlamaRunner(LlamaRunner):
def __init__(self, args):
with open(args.params, "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
super().__init__(
tokenizer_path=args.tokenizer,
max_seq_len=args.max_len,
max_batch_size=1,
use_kv_cache=args.kv_cache,
**params,
vocab_size=params["vocab_size"],
has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
)
super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
self.model = _load_for_executorch(args.pte)

def forward(
@@ -53,8 +58,15 @@ def forward(


def build_args_parser() -> argparse.ArgumentParser:
# TODO: merge these with build_args_parser from export_llama_lib.
parser = argparse.ArgumentParser()

parser.add_argument(
"--model",
default="llama3",
choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
)

parser.add_argument(
"-f",
"--pte",
@@ -89,7 +101,6 @@ def build_args_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-kv",
"--kv_cache",
default=True,
action="store_true",
)

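The reason both runners now pass has_full_logits for TorchTune models is the logits shape: the llama3_2_vision decoder returns logits for every position, while the ExecuTorch llama models return only the last position's logits. A small sketch of the indexing that generate() performs (tensor shapes and the vocab size are assumptions for illustration):

import torch

TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


def next_token_logits(logits: torch.Tensor, model: str) -> torch.Tensor:
    # Clip full-sequence logits down to the last position, as generate() does.
    has_full_logits = model in TORCHTUNE_DEFINED_MODELS
    return logits[:, -1, :] if has_full_logits else logits


full = torch.randn(1, 7, 128256)  # [batch, seq_len, vocab]: TorchTune-style output
last = torch.randn(1, 128256)     # [batch, vocab]: ExecuTorch llama-style output
assert next_token_logits(full, "llama3_2_vision").shape == (1, 128256)
assert next_token_logits(last, "llama3").shape == (1, 128256)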
9 changes: 9 additions & 0 deletions examples/models/llama3_2_vision/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama3_2Decoder

__all__ = ["Llama3_2Decoder"]