Move builder.py into extension #4188

Closed · wants to merge 2 commits
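In short: builder.py moves from examples/models/llama2/ to extension/llm/export/, and LlamaEdgeManager is renamed to LLMEdgeManager. A minimal sketch of the import migration this implies for downstream code (both paths and class names are taken from the diff below):

    # before this PR
    from executorch.examples.models.llama2.builder import DType, LlamaEdgeManager
    # after this PR
    from executorch.extension.llm.export.builder import DType, LLMEdgeManager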
1 change: 1 addition & 0 deletions backends/transforms/TARGETS
@@ -119,6 +119,7 @@ runtime.python_library(
visibility = [
"//executorch/backends/...",
"//executorch/examples/...",
"//executorch/extension/llm/...",
],
deps = [
"//caffe2:torch",
4 changes: 0 additions & 4 deletions examples/models/llama2/TARGETS
@@ -64,7 +64,6 @@ runtime.python_binary(
runtime.python_library(
name = "export_library",
srcs = [
"builder.py",
"export_llama.py",
"export_llama_lib.py",
"model.py",
@@ -82,13 +81,10 @@ runtime.python_library(
],
deps = [
"//caffe2:torch",
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
"//executorch/examples/models:model_base",
"//executorch/examples/models:models",
"//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
"//executorch/exir:lib",
"//executorch/extension/llm/export:export_lib",
"//executorch/extension/export_util:export_util",
# one definition has to be included in the user of the libarary
# depending on what library the client wants to use
# "//executorch/extension/pybindings:aten_lib",
5 changes: 3 additions & 2 deletions examples/models/llama2/eval_llama_lib.py
@@ -19,9 +19,10 @@
Tokenizer as SentencePieceTokenizer,
)

+from executorch.extension.llm.export import LLMEdgeManager

from lm_eval.api.model import LM

-from .builder import LlamaEdgeManager
from .export_llama_lib import (
_prepare_for_llama_export,
build_args_parser as _build_args_parser,
@@ -130,7 +131,7 @@ def gen_eval_wrapper(

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
+manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args)

if len(quantizers) != 0:
manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
23 changes: 11 additions & 12 deletions examples/models/llama2/export_llama_lib.py
@@ -22,6 +22,8 @@

from executorch.examples.models.llama2.llama_transformer import ModelArgs

+from executorch.extension.llm.export.builder import DType, LLMEdgeManager

from executorch.extension.llm.export.partitioner_lib import (
get_coreml_partitioner,
get_mps_partitioner,
@@ -40,8 +42,6 @@
from executorch.util.activation_memory_profiler import generate_memory_trace

from ..model_factory import EagerModelFactory

-from .builder import DType, LlamaEdgeManager
from .source_transformation.quantize import (
get_quant_embedding_transform,
get_quant_weight_transform,
@@ -333,12 +333,12 @@ def export_llama(modelname, args) -> str:
return filename


-def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
+def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
"""
Helper function for export_llama. Loads the model from checkpoint and params,
-and sets up a LlamaEdgeManager with initial transforms and dtype conversion.
+and sets up a LLMEdgeManager with initial transforms and dtype conversion.

-Returns a LlamaEdgeManager prior to calling export_to_edge with quantizers
+Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
"""

# load model from checkpoint and params.json
@@ -429,7 +429,7 @@ def _validate_args(args):
)


-def _export_llama(modelname, args) -> LlamaEdgeManager: # noqa: C901
+def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

@@ -579,12 +579,12 @@ def _load_llama_model(
verbose: bool = False,
max_seq_len: int = 128,
metadata_str: Optional[str] = None,
) -> "LlamaEdgeManager":
) -> "LLMEdgeManager":
"""
-A helper util that builds a Llama2 model. It returns a LlamaEdgeManager that
+A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
can help further lower the model to ExecuTorch.
Returns:
-An instance of LlamaEdgeManager which contains the eager mode model.
+An instance of LLMEdgeManager which contains the eager mode model.
"""
assert (
checkpoint or checkpoint_dir
@@ -622,13 +622,12 @@ def _load_llama_model(
else:
raise ValueError(f"Unsupported dtype {dtype}")

-return LlamaEdgeManager(
+return LLMEdgeManager(
model=model,
modelname=modelname,
-weight_type=weight_type,
+max_seq_len=model.params.max_seq_len,
dtype=dtype,
use_kv_cache=use_kv_cache,
-use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
example_inputs=example_inputs,
enable_dynamic_shape=enable_dynamic_shape,
verbose=verbose,
4 changes: 2 additions & 2 deletions examples/models/llama2/source_transformation/quantize.py
@@ -12,9 +12,9 @@
import torch.nn as nn
import torch.nn.functional as F

-from sentencepiece import SentencePieceProcessor
+from executorch.extension.llm.export.builder import DType

-from ..builder import DType
+from sentencepiece import SentencePieceProcessor

try:
from fairseq2.nn.embedding import (
2 changes: 1 addition & 1 deletion examples/qualcomm/llama2/llama.py
@@ -30,7 +30,6 @@
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
)
-from executorch.examples.models.llama2.builder import DType
from executorch.examples.qualcomm.llama2.model.static_llama import LlamaModel, ModelArgs
from executorch.examples.qualcomm.scripts.utils import (
make_output_dir,
@@ -41,6 +40,7 @@
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.program._program import _get_updated_graph_signature
+from executorch.extension.llm.export.builder import DType

from sentencepiece import SentencePieceProcessor
from torch.ao.quantization.observer import MinMaxObserver
4 changes: 4 additions & 0 deletions extension/llm/export/TARGETS
@@ -11,6 +11,7 @@ define_common_targets()
runtime.python_library(
name = "export_lib",
srcs = [
"builder.py",
"partitioner_lib.py",
"quantizer_lib.py",
],
@@ -26,8 +27,11 @@ runtime.python_library(
"//executorch/backends/apple/coreml:backend",
"//executorch/backends/apple/coreml:partitioner",
"//executorch/backends/apple/mps:partitioner",
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_details",
"//executorch/extension/export_util:export_util",
],
)
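For Buck clients, a hedged sketch of how a downstream target would pick up the moved builder; the target name and srcs are illustrative assumptions, and only the export_lib dependency comes from this diff:

    runtime.python_library(
        name = "my_llm_export",  # hypothetical client target
        srcs = ["my_llm_export.py"],
        deps = [
            # builder.py, plus the exir/export_util/transform deps moved here,
            # now arrive transitively through export_lib:
            "//executorch/extension/llm/export:export_lib",
        ],
    )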
examples/models/llama2/builder.py → extension/llm/export/builder.py
@@ -52,19 +52,18 @@ def to_torch_dtype(self) -> torch.dtype:
return mapping[self]


-class LlamaEdgeManager:
+class LLMEdgeManager:
"""
-Host a torch.nn.Module for Llama model and facilitates exporting to ExecuTorch.
+Host a torch.nn.Module for LLM model and facilitates exporting to ExecuTorch.
"""

def __init__(
self,
model,
modelname,
-weight_type,
+max_seq_len,
dtype,
use_kv_cache,
-use_sdpa_with_kv_cache,
example_inputs,
enable_dynamic_shape: bool = False,
verbose: bool = False,
@@ -74,12 +73,11 @@ def __init__(
# graph module returned from capture_pre_autograd_graph
self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
self.modelname = modelname
-self.weight_type = weight_type
+self.max_seq_len = max_seq_len
self.dtype = dtype
self.example_inputs = example_inputs
self.use_kv_cache = use_kv_cache
self.enable_dynamic_shape = enable_dynamic_shape
-self.use_sdpa_with_kv_cache = use_sdpa_with_kv_cache
self.verbose = verbose
self.metadata = metadata
self.applied_source_transforms = []
@@ -88,7 +86,7 @@ def __init__(
self.output_dir = "."
self._saved_pte_filename = None

-def set_output_dir(self, output_dir: str) -> "LlamaEdgeManager":
+def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
"""
Set the directory where the .pte file will be saved.
Args:
@@ -97,7 +95,7 @@ def set_output_dir(self, output_dir: str) -> "LlamaEdgeManager":
self.output_dir = output_dir
return self

-def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
+def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
"""
Convert the model to the specified dtype.
Args:
@@ -115,7 +113,7 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":

def source_transform(
self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
) -> "LlamaEdgeManager":
) -> "LLMEdgeManager":
"""
Apply source transforms to the model. The transforms are callables that
takes nn.Module as input and returns nn.Module.
@@ -132,7 +130,7 @@ def source_transform(
return self

def _get_dynamic_shape(self) -> Any:
-dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
+dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
if self.use_kv_cache:
if self.enable_dynamic_shape:
return ({1: dim}, {0: dim})
@@ -149,7 +147,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
)
return edge_config

-def capture_pre_autograd_graph(self) -> "LlamaEdgeManager":
+def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
dynamic_shape = self._get_dynamic_shape()
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -159,11 +157,9 @@ def capture_pre_autograd_graph(self) -> "LlamaEdgeManager":
)
return self

-def pt2e_quantize(
-self, quantizers: Optional[List[Quantizer]]
-) -> "LlamaEdgeManager":
+def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
"""
-Quantize the model via pt2e flow and retrieve LlamaEdgeManager including the quantized model.
+Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
Args:
quantizers (Optional[List[Quantizer]]): A list of quantizers.
"""
@@ -193,9 +189,9 @@ def pt2e_quantize(
logging.info("No quantizer provided, passing...")
return self

-def export_to_edge(self) -> "LlamaEdgeManager":
+def export_to_edge(self) -> "LLMEdgeManager":
"""
-Export the model to Edge dialect and retrieve a LlamaEdgeManager.
+Export the model to Edge dialect and retrieve a LLMEdgeManager.
"""
dynamic_shape = self._get_dynamic_shape()
edge_config = self._get_edge_config()
@@ -217,9 +213,7 @@ def export_to_edge(self) -> "LlamaEdgeManager":
)
return self

-def to_backend(
-self, partitioners: Optional[List[Partitioner]]
-) -> "LlamaEdgeManager":
+def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
"""
Partition the model and lower to different backends. The signature is
aligned with the signature of `to_backend` method of EdgeManager.
@@ -249,7 +243,7 @@ def to_backend(

return self

-def to_executorch(self) -> "LlamaEdgeManager":
+def to_executorch(self) -> "LLMEdgeManager":
"""
Lower the model to executorch and get an ExecutorchProgram.
"""