Export Mimi to xnnpack #9303


Merged (1 commit, Mar 21, 2025)
38 changes: 38 additions & 0 deletions examples/models/moshi/mimi/test_mimi.py
@@ -8,9 +8,19 @@
import torch
import torch.nn as nn
import torchaudio
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.exir import to_edge_transform_and_lower

from huggingface_hub import hf_hub_download
from moshi.models import loaders
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
)
from torch.export import export, ExportedProgram


@@ -131,6 +141,34 @@ def forward(self, x):
ep_decode_output = exported_decode.module()(input)
self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))

# PT2E Quantization
quantizer = XNNPACKQuantizer()
# 8 bit by default
Contributor: Why not lower bit? Also how did you decide between pt2e and ao?

Contributor (Author): One step for now. Will add 4 bit via 8da4w.

quantization_config = get_symmetric_quantization_config(
is_per_channel=True,
is_dynamic=True,
)
quantizer.set_global(quantization_config)
m = exported_decode.module()
m = prepare_pt2e(m, quantizer)
m(input)
m = convert_pt2e(m)
print("quantized graph:")
print(m.graph)
# Export quantized module
exported_decode: ExportedProgram = export(m, (input,), strict=False)

# Lower
edge_manager = to_edge_transform_and_lower(
exported_decode,
partitioner=[XnnpackPartitioner()],
)

exec_prog = edge_manager.to_executorch()
print("exec graph:")
print(exec_prog.exported_program().graph)
assert len(exec_prog.exported_program().graph.nodes) > 1

def test_exported_encoding(self):
"""Ensure exported encoding model is consistent with reference output."""

2 changes: 1 addition & 1 deletion examples/qualcomm/utils.py
@@ -205,7 +205,7 @@ def execute(self, custom_runner_cmd=None, method_index=0):
qnn_executor_runner_cmds = " ".join(
[
f"cd {self.workspace} &&",
-                f"chmod +x ./qnn_executor_runner &&",
+                "chmod +x ./qnn_executor_runner &&",
Contributor: is this change related?

Contributor: probably because of the lint runner I forgot to fix...

Contributor: Yeah I saw this on my prs and just bypassed it 😅

Contributor (Author): Yeah lint runner failed on it.

f"./qnn_executor_runner {qnn_executor_runner_args}",
]
)
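The one-line change above drops the f-prefix from a string that contains no placeholders, which is what the lint runner in the thread flagged. A small illustration with hypothetical workspace and argument values (the real ones come from the runner setup in examples/qualcomm/utils.py), showing the assembled command is unchanged:

```python
# Hypothetical values for illustration only.
workspace = "/data/local/tmp/executorch"
runner_args = "--model_path model.pte"

cmd = " ".join(
    [
        f"cd {workspace} &&",
        # No placeholders here, so a plain string suffices; an f-prefix on a
        # placeholder-free string is the typical lint complaint (e.g. F541).
        "chmod +x ./qnn_executor_runner &&",
        f"./qnn_executor_runner {runner_args}",
    ]
)
print(cmd)
```

The f-prefixed and plain versions of the middle string evaluate to the same text, so the change is purely a lint fix with no behavioral effect.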