Commit a72afe2 (1 parent: 2d9c6b5)

Move builder.py into extension

This PR does the following things:
- rename `LlamaEdgeManager` to `LLMEdgeManager`
- clean up some unused parameters: `weight_type`, `use_sdpa_with_kv_cache`
- move `model.params.max_seq_len` out of `builder.py`
- move `builder.py` into `extension/llm/export`

Differential Revision: [D59493975](https://our.internmc.facebook.com/intern/diff/D59493975/)

[ghstack-poisoned]
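For code that imported the manager from its old location, the change amounts to a new import path and class name. A minimal sketch of the before/after, using only names stated in this commit:

# Before this commit (from within examples/models/llama2):
#   from .builder import DType, LlamaEdgeManager
#
# After this commit, the same classes live in the extension:
from executorch.extension.llm.export.builder import DType, LLMEdgeManager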

File tree

8 files changed: +37 / -42 lines

backends/transforms/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -119,6 +119,7 @@ runtime.python_library(
     visibility = [
         "//executorch/backends/...",
         "//executorch/examples/...",
+        "//executorch/extension/llm/...",
     ],
     deps = [
         "//caffe2:torch",

examples/models/llama2/TARGETS

Lines changed: 0 additions & 4 deletions

@@ -64,7 +64,6 @@ runtime.python_binary(
 runtime.python_library(
     name = "export_library",
     srcs = [
-        "builder.py",
         "export_llama.py",
         "export_llama_lib.py",
         "model.py",
@@ -82,13 +81,10 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
         "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
-        "//executorch/exir:lib",
         "//executorch/extension/llm/export:export_lib",
-        "//executorch/extension/export_util:export_util",
         # one definition has to be included in the user of the libarary
         # depending on what library the client wants to use
         # "//executorch/extension/pybindings:aten_lib",

examples/models/llama2/eval_llama_lib.py

Lines changed: 3 additions & 2 deletions

@@ -19,9 +19,10 @@
     Tokenizer as SentencePieceTokenizer,
 )

+from executorch.extension.llm.export import LLMEdgeManager
+
 from lm_eval.api.model import LM

-from .builder import LlamaEdgeManager
 from .export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
@@ -130,7 +131,7 @@ def gen_eval_wrapper(

     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args)

     if len(quantizers) != 0:
         manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)

examples/models/llama2/export_llama_lib.py

Lines changed: 11 additions & 12 deletions

@@ -22,6 +22,8 @@

 from executorch.examples.models.llama2.llama_transformer import ModelArgs

+from executorch.extension.llm.export.builder import DType, LLMEdgeManager
+
 from executorch.extension.llm.export.partitioner_lib import (
     get_coreml_partitioner,
     get_mps_partitioner,
@@ -40,8 +42,6 @@
 from executorch.util.activation_memory_profiler import generate_memory_trace

 from ..model_factory import EagerModelFactory
-
-from .builder import DType, LlamaEdgeManager
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
@@ -333,12 +333,12 @@ def export_llama(modelname, args) -> str:
     return filename


-def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
+def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
-    and sets up a LlamaEdgeManager with initial transforms and dtype conversion.
+    and sets up a LLMEdgeManager with initial transforms and dtype conversion.

-    Returns a LlamaEdgeManager prior to calling export_to_edge with quantizers
+    Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """

     # load model from checkpoint and params.json
@@ -429,7 +429,7 @@ def _validate_args(args):
         )


-def _export_llama(modelname, args) -> LlamaEdgeManager:  # noqa: C901
+def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

@@ -579,12 +579,12 @@ def _load_llama_model(
     verbose: bool = False,
     max_seq_len: int = 128,
     metadata_str: Optional[str] = None,
-) -> "LlamaEdgeManager":
+) -> "LLMEdgeManager":
     """
-    A helper util that builds a Llama2 model. It returns a LlamaEdgeManager that
+    A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
     can help further lower the model to ExecuTorch.
     Returns:
-        An instance of LlamaEdgeManager which contains the eager mode model.
+        An instance of LLMEdgeManager which contains the eager mode model.
     """
     assert (
         checkpoint or checkpoint_dir
@@ -622,13 +622,12 @@ def _load_llama_model(
     else:
         raise ValueError(f"Unsupported dtype {dtype}")

-    return LlamaEdgeManager(
+    return LLMEdgeManager(
         model=model,
         modelname=modelname,
-        weight_type=weight_type,
+        max_seq_len=model.params.max_seq_len,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
         example_inputs=example_inputs,
         enable_dynamic_shape=enable_dynamic_shape,
         verbose=verbose,
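The practical upshot for _load_llama_model is that max_seq_len is now resolved at the call site and handed to the manager, rather than read from model.params inside builder.py. Below is a minimal, self-contained sketch of the new constructor call; the toy module and all values are illustrative only (a real caller passes the loaded Llama transformer, and remaining keyword arguments such as metadata are assumed to keep their defaults):

import torch

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

# Toy stand-in for the eager Llama model; values here are placeholders.
toy_model = torch.nn.Linear(8, 8)
example_inputs = (torch.randn(1, 8),)

manager = LLMEdgeManager(
    model=toy_model,
    modelname="toy",
    max_seq_len=128,  # now passed explicitly; was model.params.max_seq_len inside builder.py
    dtype=DType.fp32,
    use_kv_cache=False,
    example_inputs=example_inputs,
    enable_dynamic_shape=False,
    verbose=True,
)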

examples/models/llama2/source_transformation/quantize.py

Lines changed: 2 additions & 2 deletions

@@ -12,9 +12,9 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from sentencepiece import SentencePieceProcessor
+from executorch.extension.llm.export.builder import DType

-from ..builder import DType
+from sentencepiece import SentencePieceProcessor

 try:
     from fairseq2.nn.embedding import (

examples/qualcomm/llama2/llama.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,6 @@
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
 )
-from executorch.examples.models.llama2.builder import DType
 from executorch.examples.qualcomm.llama2.model.static_llama import LlamaModel, ModelArgs
 from executorch.examples.qualcomm.scripts.utils import (
     make_output_dir,
@@ -41,6 +40,7 @@
 from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from executorch.exir.program._program import _get_updated_graph_signature
+from executorch.extension.llm.export.builder import DType

 from sentencepiece import SentencePieceProcessor
 from torch.ao.quantization.observer import MinMaxObserver

extension/llm/export/TARGETS

Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,7 @@ define_common_targets()
 runtime.python_library(
     name = "export_lib",
     srcs = [
+        "builder.py",
         "partitioner_lib.py",
         "quantizer_lib.py",
     ],
@@ -26,8 +27,11 @@ runtime.python_library(
         "//executorch/backends/apple/coreml:backend",
         "//executorch/backends/apple/coreml:partitioner",
         "//executorch/backends/apple/mps:partitioner",
+        "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
         "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
+        "//executorch/extension/export_util:export_util",
     ],
 )

examples/models/llama2/builder.py renamed to extension/llm/export/builder.py

Lines changed: 15 additions & 21 deletions

@@ -52,19 +52,18 @@ def to_torch_dtype(self) -> torch.dtype:
         return mapping[self]


-class LlamaEdgeManager:
+class LLMEdgeManager:
     """
-    Host a torch.nn.Module for Llama model and facilitates exporting to ExecuTorch.
+    Host a torch.nn.Module for LLM model and facilitates exporting to ExecuTorch.
     """

     def __init__(
         self,
         model,
         modelname,
-        weight_type,
+        max_seq_len,
         dtype,
         use_kv_cache,
-        use_sdpa_with_kv_cache,
         example_inputs,
         enable_dynamic_shape: bool = False,
         verbose: bool = False,
@@ -74,12 +73,11 @@ def __init__(
         # graph module returned from capture_pre_autograd_graph
         self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
         self.modelname = modelname
-        self.weight_type = weight_type
+        self.max_seq_len = max_seq_len
         self.dtype = dtype
         self.example_inputs = example_inputs
         self.use_kv_cache = use_kv_cache
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.use_sdpa_with_kv_cache = use_sdpa_with_kv_cache
         self.verbose = verbose
         self.metadata = metadata
         self.applied_source_transforms = []
@@ -88,7 +86,7 @@ def __init__(
         self.output_dir = "."
         self._saved_pte_filename = None

-    def set_output_dir(self, output_dir: str) -> "LlamaEdgeManager":
+    def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
         Set the directory where the .pte file will be saved.
         Args:
@@ -97,7 +95,7 @@ def set_output_dir(self, output_dir: str) -> "LlamaEdgeManager":
         self.output_dir = output_dir
         return self

-    def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
+    def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
         """
         Convert the model to the specified dtype.
         Args:
@@ -115,7 +113,7 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":

     def source_transform(
         self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
-    ) -> "LlamaEdgeManager":
+    ) -> "LLMEdgeManager":
         """
         Apply source transforms to the model. The transforms are callables that
         takes nn.Module as input and returns nn.Module.
@@ -132,7 +130,7 @@ def source_transform(
         return self

     def _get_dynamic_shape(self) -> Any:
-        dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
+        dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
         if self.use_kv_cache:
             if self.enable_dynamic_shape:
                 return ({1: dim}, {0: dim})
@@ -149,7 +147,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def capture_pre_autograd_graph(self) -> "LlamaEdgeManager":
+    def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -159,11 +157,9 @@ def capture_pre_autograd_graph(self) -> "LlamaEdgeManager":
         )
         return self

-    def pt2e_quantize(
-        self, quantizers: Optional[List[Quantizer]]
-    ) -> "LlamaEdgeManager":
+    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
-        Quantize the model via pt2e flow and retrieve LlamaEdgeManager including the quantized model.
+        Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
         Args:
             quantizers (Optional[List[Quantizer]]): A list of quantizers.
         """
@@ -193,9 +189,9 @@ def pt2e_quantize(
             logging.info("No quantizer provided, passing...")
         return self

-    def export_to_edge(self) -> "LlamaEdgeManager":
+    def export_to_edge(self) -> "LLMEdgeManager":
         """
-        Export the model to Edge dialect and retrieve a LlamaEdgeManager.
+        Export the model to Edge dialect and retrieve a LLMEdgeManager.
         """
         dynamic_shape = self._get_dynamic_shape()
         edge_config = self._get_edge_config()
@@ -217,9 +213,7 @@ def export_to_edge(self) -> "LlamaEdgeManager":
         )
         return self

-    def to_backend(
-        self, partitioners: Optional[List[Partitioner]]
-    ) -> "LlamaEdgeManager":
+    def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
         """
         Partition the model and lower to different backends. The signature is
         aligned with the signature of `to_backend` method of EdgeManager.
@@ -249,7 +243,7 @@ def to_backend(

         return self

-    def to_executorch(self) -> "LlamaEdgeManager":
+    def to_executorch(self) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
         """
