
Commit ba90926

Update phi-3-mini to use the export library
Pull Request resolved: #4190, as titled. Note: exporting the model from HF is currently blocked by this issue: https://fburl.com/qpaapajr. We need to apply the following trick to export the model:

```
model = torch.export._trace._export(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shape,
    strict=False,
    pre_dispatch=False,
)
```

ghstack-source-id: 233306025
@exported-using-ghexport

Differential Revision: [D59503255](https://our.internmc.facebook.com/intern/diff/D59503255/)
1 parent: 6c01d2b
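For context, the `example_inputs` and `dynamic_shape` values that the workaround snippet above expects are not defined in the message itself; a minimal sketch, following the pattern used in the pre-change version of the script in the diff below (the exact values are illustrative):

```python
import torch

# Example input: a batch of one sequence of 100 token IDs, matching the
# pre-change script in this commit.
example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)

# Mark dimension 1 (sequence length) of "input_ids" as dynamic, bounded at 128.
dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
```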

File tree: 1 file changed (+28, -47 lines)


examples/models/phi-3-mini/export_phi-3-mini.py

Lines changed: 28 additions & 47 deletions
```diff
@@ -5,19 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+from executorch.extension.llm.export.builder import DType, LLMEdgeManager

-from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
-    DuplicateDynamicQuantChainPass,
-)
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
-from executorch.exir import to_edge
-from torch._export import capture_pre_autograd_graph
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    get_symmetric_quantization_config,
-    XNNPACKQuantizer,
+from executorch.extension.llm.export.partitioner_lib import get_xnnpack_partitioner
+from executorch.extension.llm.export.quantizer_lib import (
+    DynamicQuantLinearOptions,
+    get_pt2e_quantizers,
+    PT2EQuantOptions,
 )

 from transformers import ( # @manual=fbsource//third-party/pypi/transformers:transformers
@@ -26,46 +20,33 @@


 def main() -> None:
-    torch.random.manual_seed(0)
+    torch.manual_seed(42)

     model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

-    example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
-    dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
-
-    xnnpack_quant_config = get_symmetric_quantization_config(
-        is_per_channel=True, is_dynamic=True
-    )
-    xnnpack_quantizer = XNNPACKQuantizer()
-    xnnpack_quantizer.set_global(xnnpack_quant_config)
-
-    with torch.nn.attention.sdpa_kernel(
-        [torch.nn.attention.SDPBackend.MATH]
-    ), torch.no_grad():
-        model = capture_pre_autograd_graph(
-            model, example_inputs, dynamic_shapes=dynamic_shape
+    modelname = "phi-3-mini"
+
+    (
+        LLMEdgeManager(
+            model=model,
+            modelname=modelname,
+            max_seq_len=128,
+            dtype=DType.fp32,
+            use_kv_cache=False,
+            example_inputs=(torch.randint(0, 100, (1, 100), dtype=torch.long),),
+            enable_dynamic_shape=True,
+            verbose=True,
         )
-        model = prepare_pt2e(model, xnnpack_quantizer)
-        model(*example_inputs)
-        model = convert_pt2e(model, fold_quantize=False)
-        DuplicateDynamicQuantChainPass()(model)
-        # TODO(lunwenh): update it to use export once
-        # https://github.com/pytorch/pytorch/issues/128394 is resolved.
-        model = torch.export._trace._export(
-            model,
-            example_inputs,
-            dynamic_shapes=dynamic_shape,
-            strict=False,
-            pre_dispatch=False,
+        .set_output_dir(".")
+        .capture_pre_autograd_graph()
+        .pt2e_quantize(
+            get_pt2e_quantizers(PT2EQuantOptions(None, DynamicQuantLinearOptions()))
         )
-
-    edge_config = get_xnnpack_edge_compile_config()
-    edge_manager = to_edge(model, compile_config=edge_config)
-    edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
-    et_program = edge_manager.to_executorch()
-
-    with open("phi-3-mini.pte", "wb") as file:
-        file.write(et_program.buffer)
+        .export_to_edge()
+        .to_backend([get_xnnpack_partitioner()])
+        .to_executorch()
+        .save_to_pte(f"{modelname}.pte")
+    )


 if __name__ == "__main__":
```
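To sanity-check the exported artifact, the resulting `phi-3-mini.pte` can be loaded through ExecuTorch's Python bindings. A minimal sketch, assuming the `portable_lib` pybindings were built with the XNNPACK backend registered; the tensor dtype mirrors the export-time `example_inputs`:

```python
import torch

# Assumption: executorch was built with Python bindings and the XNNPACK backend.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the program produced by export_phi-3-mini.py.
module = _load_for_executorch("phi-3-mini.pte")

# Token IDs with the same dtype as the export-time example inputs; with
# dynamic shapes enabled, sequence length may vary up to max_seq_len (128).
tokens = torch.randint(0, 100, (1, 64), dtype=torch.long)

# forward takes a sequence of inputs and returns a list of output tensors;
# the first output is expected to be the logits.
outputs = module.forward((tokens,))
print(outputs[0].shape)
```

This only verifies that the program loads and runs; tokenization, the generation loop, and KV-cache handling live elsewhere in the example.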
