|
12 | 12 | from huggingface_hub import hf_hub_download
|
13 | 13 | from moshi.models import loaders
|
14 | 14 | from torch.export import export, ExportedProgram
|
| 15 | +from executorch.exir import to_edge_transform_and_lower |
| 16 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
| 17 | +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( |
| 18 | + get_symmetric_quantization_config, |
| 19 | + XNNPACKQuantizer, |
| 20 | +) |
| 21 | +from torch.ao.quantization.quantize_pt2e import ( |
| 22 | + _convert_to_reference_decomposed_fx, |
| 23 | + convert_pt2e, |
| 24 | + prepare_pt2e, |
| 25 | + prepare_qat_pt2e, |
| 26 | +) |
| 27 | + |
| 28 | +from executorch.extension.export_util.utils import save_pte_program |
15 | 29 |
|
16 | 30 |
|
17 | 31 | def read_mp3_from_url(url):
|
@@ -131,6 +145,35 @@ def forward(self, x):
|
131 | 145 | ep_decode_output = exported_decode.module()(input)
|
132 | 146 | self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))
|
133 | 147 |
|
| 148 | + # Uncomment the block below to export a PT2E-quantized model instead of the floating-point one. |
| 149 | + # # PT2E Quantization |
| 150 | + # quantizer = XNNPACKQuantizer() |
| 151 | + # # 8 bit by default |
| 152 | + # quantization_config = get_symmetric_quantization_config( |
| 153 | + # is_per_channel=True, |
| 154 | + # is_dynamic=True, |
| 155 | + # ) |
| 156 | + # quantizer.set_global(quantization_config) |
| 157 | + # m = exported_decode.module() |
| 158 | + # m = prepare_pt2e(m, quantizer) |
| 159 | + # m(input) |
| 160 | + # m = convert_pt2e(m) |
| 161 | + # print("quantized graph:") |
| 162 | + # print(m.graph) |
| 163 | + # # Export quantized module |
| 164 | + # exported_decode: ExportedProgram = export(m, (input,), strict=False) |
| 165 | + |
| 166 | + # Lower the exported program to the edge dialect, partitioning with XnnpackPartitioner |
| 167 | + edge_manager = to_edge_transform_and_lower( |
| 168 | + exported_decode, |
| 169 | + partitioner=[XnnpackPartitioner()], |
| 170 | + ) |
| 171 | + |
| 172 | + exec_prog = edge_manager.to_executorch() |
| 173 | + print("exec graph:") |
| 174 | + print(exec_prog.exported_program().graph) |
| 175 | + save_pte_program(exec_prog, "/tmp/Mimi_decode_fp.pte") |
| 176 | + |
134 | 177 | def test_exported_encoding(self):
|
135 | 178 | """Ensure exported encoding model is consistent with reference output."""
|
136 | 179 |
|
|
0 commit comments