Skip to content

Commit 26a9cce

Browse files
author
Martin Yuan
committed
Export Mimi to xnnpack
1 parent 9a0c2db commit 26a9cce

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@
1212
from huggingface_hub import hf_hub_download
1313
from moshi.models import loaders
1414
from torch.export import export, ExportedProgram
15+
from executorch.exir import to_edge_transform_and_lower
16+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
17+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
18+
get_symmetric_quantization_config,
19+
XNNPACKQuantizer,
20+
)
21+
from torch.ao.quantization.quantize_pt2e import (
22+
_convert_to_reference_decomposed_fx,
23+
convert_pt2e,
24+
prepare_pt2e,
25+
prepare_qat_pt2e,
26+
)
27+
28+
from executorch.extension.export_util.utils import save_pte_program
1529

1630

1731
def read_mp3_from_url(url):
@@ -131,6 +145,35 @@ def forward(self, x):
131145
ep_decode_output = exported_decode.module()(input)
132146
self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))
133147

148+
# uncomment below for pt2e quantized model
149+
# # PT2E Quantization
150+
# quantizer = XNNPACKQuantizer()
151+
# # 8 bit by default
152+
# quantization_config = get_symmetric_quantization_config(
153+
# is_per_channel=True,
154+
# is_dynamic=True,
155+
# )
156+
# quantizer.set_global(quantization_config)
157+
# m = exported_decode.module()
158+
# m = prepare_pt2e(m, quantizer)
159+
# m(input)
160+
# m = convert_pt2e(m)
161+
# print("quantized graph:")
162+
# print(m.graph)
163+
# # Export quantized module
164+
# exported_decode: ExportedProgram = export(m, (input,), strict=False)
165+
166+
# Lower
167+
edge_manager = to_edge_transform_and_lower(
168+
exported_decode,
169+
partitioner=[XnnpackPartitioner()],
170+
)
171+
172+
exec_prog = edge_manager.to_executorch()
173+
print("exec graph:")
174+
print(exec_prog.exported_program().graph)
175+
save_pte_program(exec_prog, "/tmp/Mimi_decode_fp.pte")
176+
134177
def test_exported_encoding(self):
135178
"""Ensure exported encoding model is consistent with reference output."""
136179

0 commit comments

Comments
 (0)