
Commit 7af428b

Update phi-3-mini to use the export library
as title. Note: exporting the model from HF is currently blocked by this issue: https://fburl.com/qpaapajr. We need to apply the following trick to export the model.

```
model = torch.export._trace._export(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shape,
    strict=False,
    pre_dispatch=False,
)
```

Differential Revision: [D59503255](https://our.internmc.facebook.com/intern/diff/D59503255/)

ghstack-source-id: 233088631
Pull Request resolved: #4190
1 parent: d2d2b11
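For context, here is a self-contained sketch of that workaround. `TinyLM` is a hypothetical stand-in for `Phi3ForCausalLM` so the snippet runs without downloading weights, and `torch.export._trace._export` is a private API whose signature may change between PyTorch releases; the `example_inputs` and `dynamic_shapes` spec mirror the ones used in this script.

```python
import torch


# Hypothetical stand-in for Phi3ForCausalLM, just to make the sketch runnable.
class TinyLM(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(100, 16)
        self.head = torch.nn.Linear(16, 100)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(input_ids))


model = TinyLM().eval()
example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}

with torch.no_grad():
    # Non-strict (strict=False), non-pre-dispatch export is the trick that
    # sidesteps the HF export blocker referenced above.
    exported_program = torch.export._trace._export(
        model,
        example_inputs,
        dynamic_shapes=dynamic_shape,
        strict=False,
        pre_dispatch=False,
    )
```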


examples/models/phi-3-mini/export_phi-3-mini.py

Lines changed: 28 additions & 47 deletions
```diff
@@ -5,65 +5,46 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
+from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 
-from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
-    DuplicateDynamicQuantChainPass,
-)
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
-from executorch.exir import to_edge
-from torch._export import capture_pre_autograd_graph
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    get_symmetric_quantization_config,
-    XNNPACKQuantizer,
+from executorch.extension.llm.export.partitioner_lib import get_xnnpack_partitioner
+from executorch.extension.llm.export.quantizer_lib import (
+    DynamicQuantLinearOptions,
+    get_pt2e_quantizers,
+    PT2EQuantOptions,
 )
 
 from transformers import Phi3ForCausalLM
 
 
 def main() -> None:
-    torch.random.manual_seed(0)
+    torch.manual_seed(42)
 
     model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
 
-    example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
-    dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
-
-    xnnpack_quant_config = get_symmetric_quantization_config(
-        is_per_channel=True, is_dynamic=True
-    )
-    xnnpack_quantizer = XNNPACKQuantizer()
-    xnnpack_quantizer.set_global(xnnpack_quant_config)
-
-    with torch.nn.attention.sdpa_kernel(
-        [torch.nn.attention.SDPBackend.MATH]
-    ), torch.no_grad():
-        model = capture_pre_autograd_graph(
-            model, example_inputs, dynamic_shapes=dynamic_shape
+    modelname = "phi-3-mini"
+
+    (
+        LLMEdgeManager(
+            model=model,
+            modelname=modelname,
+            max_seq_len=128,
+            dtype=DType.fp32,
+            use_kv_cache=False,
+            example_inputs=(torch.randint(0, 100, (1, 100), dtype=torch.long),),
+            enable_dynamic_shape=True,
+            verbose=True,
         )
-        model = prepare_pt2e(model, xnnpack_quantizer)
-        model(*example_inputs)
-        model = convert_pt2e(model, fold_quantize=False)
-        DuplicateDynamicQuantChainPass()(model)
-        # TODO(lunwenh): update it to use export once
-        # https://github.com/pytorch/pytorch/issues/128394 is resolved.
-        model = torch.export._trace._export(
-            model,
-            example_inputs,
-            dynamic_shapes=dynamic_shape,
-            strict=False,
-            pre_dispatch=False,
+        .set_output_dir(".")
+        .capture_pre_autograd_graph()
+        .pt2e_quantize(
+            get_pt2e_quantizers(PT2EQuantOptions(None, DynamicQuantLinearOptions()))
         )
-
-    edge_config = get_xnnpack_edge_compile_config()
-    edge_manager = to_edge(model, compile_config=edge_config)
-    edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
-    et_program = edge_manager.to_executorch()
-
-    with open("phi-3-mini.pte", "wb") as file:
-        file.write(et_program.buffer)
+        .export_to_edge()
+        .to_backend([get_xnnpack_partitioner()])
+        .to_executorch()
+        .save_to_pte(f"{modelname}.pte")
+    )
 
 
 if __name__ == "__main__":
```
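Reading the diff as a whole: the fluent `LLMEdgeManager` chain replaces the hand-rolled steps on the left (pre-autograd capture, PT2E dynamic quantization, edge lowering, XNNPACK partitioning, and serialization) with a single builder pipeline that ends by writing `phi-3-mini.pte`. Below is a hypothetical smoke test for that artifact, assuming the ExecuTorch Python bindings are built locally with XNNPACK support; the `_load_for_executorch` module path and call convention may differ between ExecuTorch versions.

```python
import torch

# Hypothetical smoke test for the exported program. Assumes the ExecuTorch
# pybindings are available; the module path below may vary across versions.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

program = _load_for_executorch("phi-3-mini.pte")

# Same token shape as the export-time example input (batch 1, sequence 100).
tokens = torch.randint(0, 100, (1, 100), dtype=torch.long)
outputs = program.forward([tokens])
print(outputs[0].shape)  # expected: per-position logits over the vocabulary
```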
