
Commit e74a586

larryliu0820 authored and facebook-github-bot committed
Refactor export_llama_lib
Summary: Separate out "recipe" code from the actual "cooking" code. Introduces a new `LlamaBuilder` class that handles the internal logic of exporting. Takes source transforms, quantizers, and partitioners. Differential Revision: D54027081
1 parent 20714e7 commit e74a586
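As a quick illustration of the new API, here is a minimal usage sketch (method names come from builder.py in this diff; the checkpoint/params paths, the empty quantizer/partitioner lists, and the import path are illustrative placeholders rather than part of this commit):

from examples.models.llama2.builder import LlamaBuilder  # illustrative import path

# Placeholder inputs: a real recipe supplies its own quantizers and partitioners.
quantizers = []    # e.g. pt2e Quantizer instances
partitioners = []  # e.g. backend Partitioner instances

builder = (
    LlamaBuilder(verbose=True)
    .set_checkpoint("/path/to/consolidated.00.pth")  # placeholder path
    .set_params("/path/to/params.json")              # placeholder path
    .set_use_kv_cache(False)     # must be called before load_model()
    .load_model()
    .to_dtype("fp32")
    .source_transform([])        # list of Module -> Module callables
    .export_to_edge(quantizers)  # capture, optionally quantize, lower to Edge IR
    .to_backend(partitioners)    # delegate partitions to backends, if any
    .to_executorch()
)
builder.save("llama2")           # serializes the .pte program via save_pte_program

The "recipe" (which transforms, quantizers, and partitioners to apply) stays in export_llama_lib, while the "cooking" (capture, quantization, lowering, serialization) now lives in LlamaBuilder.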

File tree: 3 files changed, +306 −163 lines changed


examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ runtime.python_binary(
 runtime.python_library(
     name = "export_library",
     srcs = [
+        "builder.py",
         "export_llama.py",
         "export_llama_lib.py",
     ],

examples/models/llama2/builder.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Example script for exporting Llama2 to flatbuffer

import json
import logging
from json import JSONDecodeError
from typing import Callable, List, Optional

import pkg_resources
import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.nn.attention import SDPBackend

from ...portable.utils import export_to_edge, save_pte_program
from ..model_factory import EagerModelFactory


IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)

pkg_name = __name__


def canonical_path(path: str, *, dir: bool = False) -> str:
    print(f"creating canonical path for {path}")
    if not path.startswith("par:"):
        return path

    if not IS_FBCODE:
        print("not FBCODE")
        return path[4:]
    else:
        return_val = pkg_resources.resource_filename(pkg_name, path[4:])
        print(f"canonical name is: {return_val}")
        return return_val


class LlamaBuilder:
    """
    A builder class that builds a Llama2 model, applies source transformations and
    quantization, and exports the result to ExecuTorch.
    If you want to apply different quantization and source transformation schemes,
    consider using these existing hooks before modifying this file:
    * .source_transform()
    * .export_to_edge()
    * .to_backend()
    """

    def __init__(self, verbose: bool = False):
        self.verbose = verbose
        self.checkpoint = None
        self.params = None
        self.output_dir = "."
        self.is_fairseq2 = False
        self.use_kv_cache = False
        self.dynamic_shape = None
        self.model = None
        self.example_inputs = None
        self.dtype = None
        self.applied_source_transforms = []
        self.edge_manager = None
        self.output_name = "llama2"
        self.edge_config = None
        self.metadata = None
        self.export_program = None

    def set_checkpoint(self, checkpoint: str, is_fairseq2=False):
        self.checkpoint = checkpoint
        self.is_fairseq2 = is_fairseq2
        return self

    def set_params(self, params: str):
        self.params = params
        return self

    def set_output_dir(self, output_dir: str):
        self.output_dir = output_dir
        return self

    def set_use_kv_cache(self, use_kv_cache: bool):
        assert self.model is None, (
            "To ensure consistency, set_use_kv_cache can't be called after load_model(). "
            f"Currently the model has use_kv_cache = {self.use_kv_cache}"
        )
        self.use_kv_cache = use_kv_cache
        return self

    def set_metadata(self, metadata: Optional[dict]):
        self.metadata = metadata
        return self

    def load_model(self):
        assert (
            self.checkpoint and self.params
        ), "Both checkpoint and params need to be set"
        if self.model:
            logging.info(f"Reloading model from {self.checkpoint} and {self.params}")
        checkpoint_path = canonical_path(self.checkpoint)
        params_path = canonical_path(self.params)
        logging.info(
            f"Loading model with checkpoint={checkpoint_path}, params={params_path}, use_kv_cache={self.use_kv_cache}, fairseq2={self.is_fairseq2}"
        )
        self.model, self.example_inputs, _ = EagerModelFactory.create_model(
            "llama2",
            "Llama2Model",
            checkpoint=checkpoint_path,
            params=params_path,
            use_kv_cache=self.use_kv_cache,
            fairseq2=self.is_fairseq2,
        )
        state_dict = self.model.state_dict()
        dtype = state_dict[next(iter(state_dict))].dtype
        assert dtype in [torch.float16, torch.float32], "Only support fp16 or fp32"
        logging.info(f"Loaded model with dtype={dtype}")
        self.dtype = "fp16" if dtype == torch.float16 else "fp32"

        return self

    def to_dtype(self, dtype_override: Optional[str]):
        assert self.model, "Need to run load_model() first"
        assert not dtype_override or dtype_override in [
            "fp16",
            "fp32",
        ], "Only support fp16 or fp32"

        if dtype_override == "fp16" and self.dtype != "fp16":
            logging.info("model.to torch.float16")
            self.model = self.model.to(dtype=torch.float16)
            self.dtype = dtype_override
        elif dtype_override == "fp32" and self.dtype != "fp32":
            logging.info("model.to torch.float32")
            self.model = self.model.to(dtype=torch.float32)
            self.dtype = dtype_override
        return self

    def source_transform(
        self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
    ):
        assert self.model, "Need to run load_model() first"
        for transform in transforms:
            self.model = transform(self.model)
        self.applied_source_transforms.extend(transforms)

        if self.verbose:
            logging.info(f"{self.output_name}:")
            logging.info(f"{self.model}")
        return self

    def _get_dynamic_shape(self):
        assert self.model, "Need to run load_model() first"
        if self.use_kv_cache:
            return None
        dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
        self.dynamic_shape = {"tokens": {1: dim}}
        return self.dynamic_shape

    def _get_edge_config(self):
        self.edge_config = EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_type_promotion=bool(self.dtype == "fp16"),
        )
        return self.edge_config

    def _get_metadata(self):
        assert self.model, "Need to run load_model() first"
        params = self.model.params
        metadata = {
            "append_eos_to_prompt": self.is_fairseq2,  # For language llama, tell the runtime to always append EOS token(s) to prompt.
            "get_bos_id": 3 if self.is_fairseq2 else 1,
            "get_dtype": 5 if self.dtype == "fp16" else 6,
            "get_eos_id": 3 if self.is_fairseq2 else 2,
            "get_head_dim": params.dim // params.n_heads,
            "get_max_batch_size": params.max_batch_size,
            "get_max_seq_len": params.max_seq_len,
            "get_n_bos": 1,
            "get_n_eos": 2 if self.is_fairseq2 else 1,
            "get_n_kv_heads": params.n_kv_heads,
            "get_n_layers": params.n_layers,
            "get_vocab_size": params.vocab_size,
            "use_kv_cache": self.use_kv_cache,
        }
        if self.metadata:
            try:
                extra = json.loads(self.metadata)
                for k, v in extra.items():
                    metadata[k] = v
            except JSONDecodeError:
                logging.error("Invalid metadata, should be a valid JSON string")
        self.metadata = metadata
        return self.metadata

    def export_to_edge(self, quantizers: Optional[List[Quantizer]]):
        assert self.model, "Need to run load_model() first"
        dynamic_shape = self._get_dynamic_shape()
        edge_config = self._get_edge_config()
        metadata = self._get_metadata()

        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            m = capture_pre_autograd_graph(
                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
            )
            if quantizers:
                composed_quantizer = ComposableQuantizer(quantizers)
                m = prepare_pt2e(m, composed_quantizer)
                # Calibrate
                m(*self.example_inputs)
                m = convert_pt2e(m)
                DuplicateDynamicQuantChainPass()(m)
            self.edge_manager = export_to_edge(
                m,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                edge_constant_methods=metadata,
                edge_compile_config=edge_config,
            )
        return self

    def to_backend(self, partitioners: Optional[List[Partitioner]]):
        assert self.edge_manager, "Need to run export_to_edge() first"
        if partitioners:
            for partitioner in partitioners:
                self.edge_manager = self.edge_manager.to_backend(partitioner)
        return self

    def to_executorch(self):
        assert self.edge_manager, "Need to run export_to_edge() first"
        self.export_program = self.edge_manager.to_executorch(
            ExecutorchBackendConfig(
                extract_constant_segment=True,
                extract_delegate_segments=True,
                passes=[
                    QuantFusionPass(),
                ],
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        logging.info(
            "Required memory for activation in bytes: %s",
            self.export_program._emitter_output.program.execution_plan[
                0
            ].non_const_buffer_sizes,
        )
        return self

    def save(self, output_name: Optional[str]):
        if output_name:
            self.output_name = output_name
        save_pte_program(self.export_program.buffer, self.output_name, self.output_dir)
