 # LICENSE file in the root directory of this source tree.

 import torch
+from executorch.extension.llm.export.builder import DType, LLMEdgeManager

-from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
-    DuplicateDynamicQuantChainPass,
-)
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
-from executorch.exir import to_edge
-from torch._export import capture_pre_autograd_graph
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    get_symmetric_quantization_config,
-    XNNPACKQuantizer,
+from executorch.extension.llm.export.partitioner_lib import get_xnnpack_partitioner
+from executorch.extension.llm.export.quantizer_lib import (
+    DynamicQuantLinearOptions,
+    get_pt2e_quantizers,
+    PT2EQuantOptions,
 )

 from transformers import Phi3ForCausalLM


 def main() -> None:
-    torch.random.manual_seed(0)
+    torch.manual_seed(42)

     model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

-    example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
-    dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
-
-    xnnpack_quant_config = get_symmetric_quantization_config(
-        is_per_channel=True, is_dynamic=True
-    )
-    xnnpack_quantizer = XNNPACKQuantizer()
-    xnnpack_quantizer.set_global(xnnpack_quant_config)
-
-    with torch.nn.attention.sdpa_kernel(
-        [torch.nn.attention.SDPBackend.MATH]
-    ), torch.no_grad():
-        model = capture_pre_autograd_graph(
-            model, example_inputs, dynamic_shapes=dynamic_shape
+    modelname = "phi-3-mini"
+
+    (
+        LLMEdgeManager(
+            model=model,
+            modelname=modelname,
+            max_seq_len=128,
+            dtype=DType.fp32,
+            use_kv_cache=False,
+            example_inputs=(torch.randint(0, 100, (1, 100), dtype=torch.long),),
+            enable_dynamic_shape=True,
+            verbose=True,
         )
-        model = prepare_pt2e(model, xnnpack_quantizer)
-        model(*example_inputs)
-        model = convert_pt2e(model, fold_quantize=False)
-        DuplicateDynamicQuantChainPass()(model)
-        # TODO(lunwenh): update it to use export once
-        # https://github.com/pytorch/pytorch/issues/128394 is resolved.
-        model = torch.export._trace._export(
-            model,
-            example_inputs,
-            dynamic_shapes=dynamic_shape,
-            strict=False,
-            pre_dispatch=False,
+        .set_output_dir(".")
+        .capture_pre_autograd_graph()
+        .pt2e_quantize(
+            get_pt2e_quantizers(PT2EQuantOptions(None, DynamicQuantLinearOptions()))
         )
-
-    edge_config = get_xnnpack_edge_compile_config()
-    edge_manager = to_edge(model, compile_config=edge_config)
-    edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
-    et_program = edge_manager.to_executorch()
-
-    with open("phi-3-mini.pte", "wb") as file:
-        file.write(et_program.buffer)
+        .export_to_edge()
+        .to_backend([get_xnnpack_partitioner()])
+        .to_executorch()
+        .save_to_pte(f"{modelname}.pte")
+    )


 if __name__ == "__main__":
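For readers mapping the old flow onto the new one: the pt2e_quantize(get_pt2e_quantizers(PT2EQuantOptions(None, DynamicQuantLinearOptions()))) step is intended to cover the quantizer setup the removed lines built by hand. Below is a minimal sketch of that manual setup, using the same torch.ao quantizer APIs the old code imported; it is shown for orientation only, and the helper's exact internals may differ:

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

# Dynamic, per-channel symmetric quantization applied globally:
# the configuration the removed code constructed explicitly
# before calling prepare_pt2e/convert_pt2e.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)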
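As a quick smoke test of the saved artifact (not part of this commit), the exported phi-3-mini.pte can be loaded back through ExecuTorch's Python bindings. This sketch assumes the pybindings are available in your build and that a plain token-ID tensor matching the export-time example_inputs is a valid input:

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the exported program and run one forward pass with dummy token IDs
# shaped like the example inputs used at export time (1 x 100, int64).
program = _load_for_executorch("phi-3-mini.pte")
tokens = torch.randint(0, 100, (1, 100), dtype=torch.long)
outputs = program.forward((tokens,))
print(outputs[0].shape)  # expected: logits over the vocabulary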