Skip to content

Commit f787604

Browse files
committed
Update phi-3-mini to use the export library
Pull Request resolved: #4190

As the title says: update phi-3-mini to use the export library.

Note: exporting the model from HF is currently blocked by this issue: https://fburl.com/qpaapajr. We need to apply the following trick to export the model:

```
model = torch.export._trace._export(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shape,
    strict=False,
    pre_dispatch=False,
)
```

ghstack-source-id: 233317600
@exported-using-ghexport

Differential Revision: [D59503255](https://our.internmc.facebook.com/intern/diff/D59503255/)
1 parent ead9d0a commit f787604

File tree

1 file changed

+28
-47
lines changed

1 file changed

+28
-47
lines changed

examples/models/phi-3-mini/export_phi-3-mini.py

Lines changed: 28 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -5,66 +5,47 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
89

9-
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
10-
DuplicateDynamicQuantChainPass,
11-
)
12-
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
13-
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
14-
from executorch.exir import to_edge
15-
from torch._export import capture_pre_autograd_graph
16-
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
17-
18-
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
19-
get_symmetric_quantization_config,
20-
XNNPACKQuantizer,
10+
from executorch.extension.llm.export.partitioner_lib import get_xnnpack_partitioner
11+
from executorch.extension.llm.export.quantizer_lib import (
12+
DynamicQuantLinearOptions,
13+
get_pt2e_quantizers,
14+
PT2EQuantOptions,
2115
)
2216

2317
from transformers import Phi3ForCausalLM
2418

2519

2620
def main() -> None:
27-
torch.random.manual_seed(0)
21+
torch.manual_seed(42)
2822

2923
# pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
3024
model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
3125

32-
example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
33-
dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
34-
35-
xnnpack_quant_config = get_symmetric_quantization_config(
36-
is_per_channel=True, is_dynamic=True
37-
)
38-
xnnpack_quantizer = XNNPACKQuantizer()
39-
xnnpack_quantizer.set_global(xnnpack_quant_config)
40-
41-
with torch.nn.attention.sdpa_kernel(
42-
[torch.nn.attention.SDPBackend.MATH]
43-
), torch.no_grad():
44-
model = capture_pre_autograd_graph(
45-
model, example_inputs, dynamic_shapes=dynamic_shape
26+
modelname = "phi-3-mini"
27+
28+
(
29+
LLMEdgeManager(
30+
model=model,
31+
modelname=modelname,
32+
max_seq_len=128,
33+
dtype=DType.fp32,
34+
use_kv_cache=False,
35+
example_inputs=(torch.randint(0, 100, (1, 100), dtype=torch.long),),
36+
enable_dynamic_shape=True,
37+
verbose=True,
4638
)
47-
model = prepare_pt2e(model, xnnpack_quantizer)
48-
model(*example_inputs)
49-
model = convert_pt2e(model, fold_quantize=False)
50-
DuplicateDynamicQuantChainPass()(model)
51-
# TODO(lunwenh): update it to use export once
52-
# https://github.com/pytorch/pytorch/issues/128394 is resolved.
53-
model = torch.export._trace._export(
54-
model,
55-
example_inputs,
56-
dynamic_shapes=dynamic_shape,
57-
strict=False,
58-
pre_dispatch=False,
39+
.set_output_dir(".")
40+
.capture_pre_autograd_graph()
41+
.pt2e_quantize(
42+
get_pt2e_quantizers(PT2EQuantOptions(None, DynamicQuantLinearOptions()))
5943
)
60-
61-
edge_config = get_xnnpack_edge_compile_config()
62-
edge_manager = to_edge(model, compile_config=edge_config)
63-
edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
64-
et_program = edge_manager.to_executorch()
65-
66-
with open("phi-3-mini.pte", "wb") as file:
67-
file.write(et_program.buffer)
44+
.export_to_edge()
45+
.to_backend([get_xnnpack_partitioner()])
46+
.to_executorch()
47+
.save_to_pte(f"{modelname}.pte")
48+
)
6849

6950

7051
if __name__ == "__main__":

0 commit comments

Comments
 (0)