Skip to content

Commit fa99b25

Browse files
Author: Martin Yuan
Committed: "Export Mimi to xnnpack"
1 parent: fc6d86e · commit: fa99b25

File tree

2 files changed: +39 -1 lines changed

2 files changed

+39
-1
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,19 @@
88
import torch
99
import torch.nn as nn
1010
import torchaudio
11+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
12+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
13+
get_symmetric_quantization_config,
14+
XNNPACKQuantizer,
15+
)
16+
from executorch.exir import to_edge_transform_and_lower
1117

1218
from huggingface_hub import hf_hub_download
1319
from moshi.models import loaders
20+
from torch.ao.quantization.quantize_pt2e import (
21+
convert_pt2e,
22+
prepare_pt2e,
23+
)
1424
from torch.export import export, ExportedProgram
1525

1626

@@ -131,6 +141,34 @@ def forward(self, x):
131141
ep_decode_output = exported_decode.module()(input)
132142
self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))
133143

144+
# PT2E Quantization
145+
quantizer = XNNPACKQuantizer()
146+
# 8 bit by default
147+
quantization_config = get_symmetric_quantization_config(
148+
is_per_channel=True,
149+
is_dynamic=True,
150+
)
151+
quantizer.set_global(quantization_config)
152+
m = exported_decode.module()
153+
m = prepare_pt2e(m, quantizer)
154+
m(input)
155+
m = convert_pt2e(m)
156+
print("quantized graph:")
157+
print(m.graph)
158+
# Export quantized module
159+
exported_decode: ExportedProgram = export(m, (input,), strict=False)
160+
161+
# Lower
162+
edge_manager = to_edge_transform_and_lower(
163+
exported_decode,
164+
partitioner=[XnnpackPartitioner()],
165+
)
166+
167+
exec_prog = edge_manager.to_executorch()
168+
print("exec graph:")
169+
print(exec_prog.exported_program().graph)
170+
assert len(exec_prog.exported_program().graph.nodes) > 1
171+
134172
def test_exported_encoding(self):
135173
"""Ensure exported encoding model is consistent with reference output."""
136174

examples/qualcomm/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -205,7 +205,7 @@ def execute(self, custom_runner_cmd=None, method_index=0):
205205
qnn_executor_runner_cmds = " ".join(
206206
[
207207
f"cd {self.workspace} &&",
208-
f"chmod +x ./qnn_executor_runner &&",
208+
"chmod +x ./qnn_executor_runner &&",
209209
f"./qnn_executor_runner {qnn_executor_runner_args}",
210210
]
211211
)

0 commit comments

Comments
 (0)