Skip to content

Commit e904152

Browse files
author
Martin Yuan
committed
Export Mimi to xnnpack
1 parent fc6d86e commit e904152

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,23 @@
88
import torch
99
import torch.nn as nn
1010
import torchaudio
11+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
12+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
13+
get_symmetric_quantization_config,
14+
XNNPACKQuantizer,
15+
)
16+
from executorch.exir import to_edge_transform_and_lower
17+
18+
from executorch.extension.export_util.utils import save_pte_program
1119

1220
from huggingface_hub import hf_hub_download
1321
from moshi.models import loaders
22+
from torch.ao.quantization.quantize_pt2e import (
23+
_convert_to_reference_decomposed_fx,
24+
convert_pt2e,
25+
prepare_pt2e,
26+
prepare_qat_pt2e,
27+
)
1428
from torch.export import export, ExportedProgram
1529

1630

@@ -131,6 +145,34 @@ def forward(self, x):
131145
ep_decode_output = exported_decode.module()(input)
132146
self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))
133147

148+
# PT2E Quantization
149+
quantizer = XNNPACKQuantizer()
150+
# 8 bit by default
151+
quantization_config = get_symmetric_quantization_config(
152+
is_per_channel=True,
153+
is_dynamic=True,
154+
)
155+
quantizer.set_global(quantization_config)
156+
m = exported_decode.module()
157+
m = prepare_pt2e(m, quantizer)
158+
m(input)
159+
m = convert_pt2e(m)
160+
print("quantized graph:")
161+
print(m.graph)
162+
# Export quantized module
163+
exported_decode: ExportedProgram = export(m, (input,), strict=False)
164+
165+
# Lower
166+
edge_manager = to_edge_transform_and_lower(
167+
exported_decode,
168+
partitioner=[XnnpackPartitioner()],
169+
)
170+
171+
exec_prog = edge_manager.to_executorch()
172+
print("exec graph:")
173+
print(exec_prog.exported_program().graph)
174+
assert len(exec_prog.exported_program().graph.nodes) > 1
175+
134176
def test_exported_encoding(self):
135177
"""Ensure exported encoding model is consistent with reference output."""
136178

0 commit comments

Comments
 (0)