
Commit c4a1e95

larryliu0820 authored and facebook-github-bot committed
Refactor export_llama_lib (#2030)
Summary:
Pull Request resolved: #2030

Separate out "recipe" code and actual "cooking" code. Introduces a new `LlamaBuilder` class that handles the internal logic of exporting. Takes source transforms, quantizers, and partitioners.

Reviewed By: mergennachin

Differential Revision: D54027081

fbshipit-source-id: 8da52b7a538331389eb148bff1b143373fcefb18
1 parent ca6995b commit c4a1e95
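
The class lands in builder.py below under the name LlamaEdgeManager. A minimal usage sketch of the intended flow (not part of this commit; the import paths, file names, and the XNNPACK quantizer/partitioner choices are illustrative assumptions):

# Hedged sketch: drive the new builder end to end. Any Quantizer/Partitioner
# implementations could be substituted; paths and module locations are placeholders.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.examples.models.llama2.builder import DType, load_llama_model
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

manager = (
    load_llama_model(
        checkpoint="llama2.pt",      # placeholder checkpoint path
        params_path="params.json",   # placeholder params path
        use_kv_cache=False,
        verbose=True,
    )
    .to_dtype(DType.fp32)
    .set_output_dir(".")
    .export_to_edge(quantizers=[quantizer])  # pass [] or None to skip PT2E quantization
    .to_backend(XnnpackPartitioner())        # also accepts a Dict[str, Partitioner]
    .to_executorch()
)
manager.save_to_pte("llama2")  # writes the .pte into the configured output dir

Each step returns self, so the "recipe" code in export_llama_lib can compose source transforms, quantizers, and partitioners without touching the underlying export ("cooking") logic.
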

File tree

4 files changed: +394 -165 lines changed


examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ runtime.python_binary(
 runtime.python_library(
     name = "export_library",
     srcs = [
+        "builder.py",
         "export_llama.py",
         "export_llama_lib.py",
     ],

examples/models/llama2/builder.py

Lines changed: 306 additions & 0 deletions
@@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Providing builders for Llama2 models. These builders help user to build Llama2
# eager models, apply source transformations and quantization and export them to
# ExecuTorch.

import json
import logging
from enum import Enum
from json import JSONDecodeError
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.nn.attention import SDPBackend

from ...portable.utils import export_to_edge, save_pte_program
from ..model_factory import EagerModelFactory


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class WeightType(Enum):
    LLAMA = "LLAMA"
    FAIRSEQ2 = "FAIRSEQ2"


class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"


def load_llama_model(
    *,
    checkpoint: str,
    params_path: str,
    use_kv_cache: bool = False,
    weight_type: WeightType = WeightType.LLAMA,
    verbose: bool = False,
) -> "LlamaEdgeManager":
    """
    A helper util that builds a Llama2 model. It returns a LlamaEdgeManager that
    can help further lower the model to ExecuTorch.
    Returns:
        An instance of LlamaEdgeManager which contains the eager mode model.
    """
    assert checkpoint and params_path, "Both checkpoint and params can't be empty"
    logging.info(
        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
    )
    model, example_inputs, _ = EagerModelFactory.create_model(
        "llama2",
        "Llama2Model",
        checkpoint=checkpoint,
        params=params_path,
        use_kv_cache=use_kv_cache,
        fairseq2=weight_type == WeightType.FAIRSEQ2,
    )
    state_dict = model.state_dict()
    dtype = state_dict[next(iter(state_dict))].dtype
    assert dtype in [torch.float16, torch.float32], "Only support fp16 or fp32"
    logging.info(f"Loaded model with dtype={dtype}")

    return LlamaEdgeManager(
        model=model,
        weight_type=weight_type,
        dtype=DType.fp16 if dtype == torch.float16 else DType.fp32,
        use_kv_cache=use_kv_cache,
        example_inputs=example_inputs,
        verbose=verbose,
    )


class LlamaEdgeManager:
    """
    Host a torch.nn.Module for Llama model and facilitates exporting to ExecuTorch.
    """

    def __init__(
        self,
        model,
        weight_type,
        dtype,
        use_kv_cache,
        example_inputs,
        verbose: bool = False,
    ):
        self.model = model
        self.weight_type = weight_type
        self.dtype = dtype
        self.example_inputs = example_inputs
        self.use_kv_cache = use_kv_cache
        self.metadata = None
        self.verbose = verbose
        self.applied_source_transforms = []
        self.edge_manager: Optional[EdgeProgramManager] = None
        self.export_program = None
        self.output_dir = "."

    def set_metadata(self, metadata: Optional[dict]) -> "LlamaEdgeManager":
        """
        Set the metadata that will be serialized into .pte file.
        Args:
            metadata (Optional[dict]): Metadata for the model.
        """
        self.metadata = metadata
        return self

    def set_output_dir(self, output_dir: str) -> "LlamaEdgeManager":
        """
        Set the directory where the .pte file will be saved.
        Args:
            output_dir (str): The directory to store the .pte file.
        """
        self.output_dir = output_dir
        return self

    def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
        """
        Convert the model to the specified dtype.
        Args:
            dtype_override (Optional[DType]): Override the dtype of the model.
        """
        assert not dtype_override or isinstance(
            dtype_override, DType
        ), "Override dtype needs to be of type <DType>"
        if dtype_override == DType.fp16 and self.dtype != DType.fp16:
            logging.info("model.to torch.float16")
            self.model = self.model.to(dtype=torch.float16)
            self.dtype = dtype_override
        elif dtype_override == DType.fp32 and self.dtype != DType.fp32:
            logging.info("model.to torch.float32")
            self.model = self.model.to(dtype=torch.float32)
            self.dtype = dtype_override
        return self

    def source_transform(
        self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
    ) -> "LlamaEdgeManager":
        """
        Apply source transforms to the model. The transforms are callables that
        takes nn.Module as input and returns nn.Module.
        Args:
            transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A
                list of source transforms.
        """
        for transform in transforms:
            self.model = transform(self.model)
        self.applied_source_transforms.extend(transforms)

        if self.verbose:
            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
        return self

    def _get_dynamic_shape(self) -> Optional[Dict[str, Any]]:
        if self.use_kv_cache:
            return None
        dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
        dynamic_shape = {"tokens": {1: dim}}
        return dynamic_shape

    def _get_edge_config(self) -> EdgeCompileConfig:
        edge_config = EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_type_promotion=bool(self.dtype == DType.fp16),
        )
        return edge_config

    def _get_metadata(self):
        params = self.model.params
        is_fairseq2 = self.weight_type == WeightType.FAIRSEQ2
        metadata = {
            "append_eos_to_prompt": is_fairseq2,  # For language llama, tell the runtime to always append EOS token(s) to prompt.
            "get_bos_id": 3 if is_fairseq2 else 1,
            "get_dtype": 5 if self.dtype == DType.fp16 else 6,
            "get_eos_id": 3 if is_fairseq2 else 2,
            "get_head_dim": params.dim // params.n_heads,
            "get_max_batch_size": params.max_batch_size,
            "get_max_seq_len": params.max_seq_len,
            "get_n_bos": 1,
            "get_n_eos": 2 if is_fairseq2 else 1,
            "get_n_kv_heads": params.n_kv_heads,
            "get_n_layers": params.n_layers,
            "get_vocab_size": params.vocab_size,
            "use_kv_cache": self.use_kv_cache,
        }
        if self.metadata:
            try:
                extra = json.loads(self.metadata)
                for k, v in extra.items():
                    metadata[k] = v
            except JSONDecodeError:
                logging.error("Invalid metadata, should be a valid JSON string")
        self.metadata = metadata
        return self.metadata

    def export_to_edge(
        self, quantizers: Optional[List[Quantizer]]
    ) -> "LlamaEdgeManager":
        """
        Export the model to Edge dialect and retrieve a EdgeManager.
        Args:
            quantizers (Optional[List[Quantizer]]): A list of quantizers.
        """
        dynamic_shape = self._get_dynamic_shape()
        edge_config = self._get_edge_config()
        metadata = self._get_metadata()

        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            m = capture_pre_autograd_graph(
                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
            )
            if quantizers:
                if self.verbose:
                    logging.info(f"Applied quantizers: {quantizers}")
                composed_quantizer = ComposableQuantizer(quantizers)
                m = prepare_pt2e(m, composed_quantizer)
                # Calibrate
                m(*self.example_inputs)
                m = convert_pt2e(m)
                DuplicateDynamicQuantChainPass()(m)
            self.edge_manager = export_to_edge(
                m,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                edge_constant_methods=metadata,
                edge_compile_config=edge_config,
                verbose=True,
            )
        return self

    def to_backend(
        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
    ) -> "LlamaEdgeManager":
        """
        Partition the model and lower to different backends. The signature is
        aligned with the signature of `to_backend` method of EdgeManager.
        Args:
            partitioner (Union[Partitioner, Dict[str, Partitioner]]): One or more
                partitioner to be sent to EdgeManager.to_backend().
        """
        assert self.edge_manager is not None, "Need to run export_to_edge() first"
        if isinstance(partitioner, dict):
            for key, p in partitioner.items():
                assert self.edge_manager is not None
                self.edge_manager = self.edge_manager.to_backend(p)
                if self.verbose:
                    logging.info(f"Applied partitioners: {key}")
        elif isinstance(partitioner, Partitioner):
            assert self.edge_manager is not None
            self.edge_manager = self.edge_manager.to_backend(partitioner)
            if self.verbose:
                logging.info(f"Applied partitioners: {partitioner}")
        else:
            logging.warning("Invalid partitioner, skipping...")
        return self

    def to_executorch(self) -> "LlamaEdgeManager":
        """
        Lower the model to executorch and get an ExecutorchProgram.
        """
        assert self.edge_manager, "Need to run export_to_edge() first"
        self.export_program = self.edge_manager.to_executorch(
            ExecutorchBackendConfig(
                extract_constant_segment=True,
                extract_delegate_segments=True,
                passes=[
                    QuantFusionPass(),
                ],
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        logging.info(
            "Required memory for activation in bytes: {}".format(
                self.export_program._emitter_output.program.execution_plan[
                    0
                ].non_const_buffer_sizes
            ),
        )
        return self

    def save_to_pte(self, output_name: str) -> None:
        """
        Save the model to a .pte file.
        Args:
            output_name (Optional[str]): The name of the .pte file.
        """
        assert output_name, "Need a valid output name"
        save_pte_program(self.export_program.buffer, output_name, self.output_dir)
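
One detail worth calling out from _get_metadata() above: extra metadata registered via set_metadata() is run through json.loads() and merged over the defaults derived from the model params, so in practice the value that works is a JSON string rather than the dict suggested by the annotation. A small illustrative sketch (assumed values, not part of this commit):

# Override a couple of the constant methods serialized into the .pte.
# The keys shown exist in _get_metadata(); the values here are assumptions.
manager = load_llama_model(checkpoint="llama2.pt", params_path="params.json")
manager.set_metadata('{"get_max_seq_len": 2048, "get_n_eos": 1}')
# The merged metadata is attached as edge_constant_methods during export_to_edge().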
