Skip to content

Qualcomm AI Engine Direct - Mimi Enablement Stage 2 #10098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear
from .decompose_any import DecomposeAny
from .decompose_cdist import DecomposeCDist
from .decompose_einsum import DecomposeEinsum
from .decompose_expm1 import DecomposeExpM1
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
Expand All @@ -27,6 +28,7 @@
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
from .recompose_rms_norm import RecomposeRmsNorm
from .reduce_dynamic_range import ReduceDynamicRange
from .remove_0d_tensor import Remove0DTensor
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_index_put_input import ReplaceIndexPutInput
Expand All @@ -40,8 +42,9 @@
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
DecomposeAny,
ConvertUpsampleBicubicWithBilinear,
DecomposeAny,
DecomposeCDist,
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
Expand All @@ -58,6 +61,7 @@
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ReduceDynamicRange,
Remove0DTensor,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
Expand Down
6 changes: 4 additions & 2 deletions backends/qualcomm/_passes/annotate_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ class AnnotateStack(ExportPass):
generated after quantization process.
"""

decomp_ops = [torch.ops.aten.unbind.int]
decomp_ops = [torch.ops.aten.stack.default]

def __init__(self, edge_program: torch.export.ExportedProgram):
super(AnnotateStack, self).__init__()
self.edge_program = edge_program

def _annotate_stack(self, graph_module: torch.fx.GraphModule):
partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"])
partitions = get_source_partitions(
graph_module.graph, [torch.stack, torch.ops.aten.stack.default, "stack"]
)
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
output = src_partition.output_nodes[0]
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/_passes/annotate_unbind.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
self.edge_program = edge_program

def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
partitions = get_source_partitions(
graph_module.graph, [torch.unbind, torch.ops.aten.unbind.int, "unbind"]
)
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
if src_partition.input_nodes[0].target in dq_ops:
Expand Down
10 changes: 10 additions & 0 deletions backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch.nn as nn
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand Down Expand Up @@ -43,6 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule):
unsqueeze_node.meta = copy_meta(
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)

with graph_module.graph.inserting_after(unsqueeze_node):

filter_node = node.args[1]
Expand Down Expand Up @@ -92,6 +94,14 @@ def call(self, graph_module: torch.fx.GraphModule):
),
)
squeeze_node.meta = copy_meta(node.meta)

if QCOM_REQUANTIZE in input_node.meta:
input_node.meta.pop(QCOM_REQUANTIZE)
if QCOM_REQUANTIZE in node.meta:
squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[
QCOM_REQUANTIZE
]
conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
for user in node.users.copy():
user.replace_input_with(node, squeeze_node)
graph.eliminate_dead_code()
Expand Down
81 changes: 81 additions & 0 deletions backends/qualcomm/_passes/decompose_cdist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class CDist(torch.nn.Module):
    """
    Reference module computing pairwise Euclidean (p=2) distances.

    For ``x`` of shape (..., P, M) and ``y`` of shape (..., R, M) the result
    has shape (..., P, R), built only from unsqueeze / sub / pow / sum / sqrt.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Broadcast so every row of x is paired with every row of y.
        pairwise_delta = x.unsqueeze(-2) - y.unsqueeze(-3)

        # Squared L2 norm along the feature axis.
        squared_norm = (pairwise_delta**2).sum(dim=-1)

        return torch.sqrt(squared_norm)


class DecomposeCDist(ExportPass):
    """
    Decompose ``aten.cdist`` into mathematically equivalent primitive ops.

    Each ``torch.ops.aten.cdist.default`` node is replaced by a copy of the
    graph obtained from exporting the ``CDist`` reference module on the node's
    input values. Only the Euclidean case (p=2) is supported; any other ``p``
    raises ``AssertionError``.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite every cdist node in ``graph_module`` in place.

        Returns:
            PassResult wrapping the same ``graph_module``, always flagged as modified.
        """
        graph = graph_module.graph
        # The reference module holds no state, so one instance serves every node.
        model = CDist()
        for node in graph.nodes:
            if node.target != torch.ops.aten.cdist.default:
                continue

            # p may be supplied positionally or by keyword; validate both forms
            # so that an unsupported norm is never silently decomposed as p=2.
            if len(node.args) > 2:
                p = node.args[2]
            else:
                p = node.kwargs.get("p", 2)
            assert (
                p == 2
            ), "Currently only p=2 is supported for CDist Decomposition"

            decomposed_module = torch.export.export(
                model,
                (node.args[0].meta["val"], node.args[1].meta["val"]),
                strict=True,
            ).module()
            with graph.inserting_before(node):
                # remap is used to map original node values to new node values,
                # which ensures that reference to nodes are correctly updated in the new graph
                remap = {"x": node.args[0], "y": node.args[1]}

                for decomposed_node in decomposed_module.graph.nodes:
                    # no need to copy existent 'output'
                    if decomposed_node.op == "output":
                        for user in node.users.copy():
                            # point users of cdist at the copied result node
                            user.replace_input_with(
                                node,
                                remap[decomposed_node.args[0][0]],
                            )
                    # no need to copy existent placeholders
                    elif decomposed_node.op == "placeholder":
                        # replace node map from string to graph node
                        remap[decomposed_node] = remap.pop(decomposed_node.name)
                    else:
                        remap[decomposed_node] = graph.node_copy(
                            decomposed_node,
                            arg_transform=lambda x, remap=remap: remap[x],
                        )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ class TensorOpInfo:
}


SKIP_LIFT_OPS = {aten.full_like.default, aten.arange.start_step}
SKIP_LIFT_OPS = {
aten.full_like.default,
aten.arange.start_step,
aten.arange.default,
aten.scalar_tensor.default,
aten.elu.default,
}


class LiftConstantScalarOperands(ExportPass):
Expand Down
34 changes: 20 additions & 14 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ConvertConv1dToConv2d,
ConvertUpsampleBicubicWithBilinear,
DecomposeAny,
DecomposeCDist,
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
Expand All @@ -32,6 +33,7 @@
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ReduceDynamicRange,
Remove0DTensor,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
Expand Down Expand Up @@ -71,7 +73,7 @@ def get_capture_program_passes():
# If a pass is activated, it will be executed by default.
default_passes_and_setting = [
(AnnotateQuantAttrs, True),
(AnnotateStack, False),
(AnnotateStack, True),
(AnnotateUnbind, True),
(ConvertBmmToMatmul, True),
(ConvertConv1dToConv2d, True),
Expand All @@ -84,6 +86,7 @@ def get_capture_program_passes():
(LayoutTransform, True),
(RecomposePixelUnshuffle, True),
(RecomposeRmsNorm, False),
(Remove0DTensor, True),
(RemoveRedundancy, True),
(ReplaceIndexPutInput, True),
(TagQuantIO, False),
Expand Down Expand Up @@ -176,7 +179,23 @@ def transform_for_to_edge_pipeline(

return exported_program

# Before quantizer
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
    """Register the pre-quantizer passes in order, then run them on ``graph_module``."""
    # Order matters: passes are executed in the sequence they are added.
    annotation_passes = (
        ReduceDynamicRange(),
        RecomposePixelUnshuffle(quantization_capture=True),
        ReplaceArangeArgs(),
        DecomposeCDist(),
        DecomposeScaledDotProductAttention(),
        DecomposeSilu(),
        DecomposeEinsum(),
        DecomposeExpM1(),
        DecomposeLinalgVectorNorm(quantization_capture=True),
        ReplaceInfValues(),
        LiftConstantScalarOperands(),
    )
    for annotation_pass in annotation_passes:
        self.add_pass(annotation_pass)
    return self._transform(graph_module)

def transform_for_export_pipeline(self, exported_program: ExportedProgram):
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
self.add_pass(DecomposeExpM1())
Expand All @@ -191,16 +210,3 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
self.add_pass(FuseConsecutiveTranspose())
return self._transform(exported_program.graph_module)

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
    """Register the annotation-stage passes in order, then run them on ``graph_module``."""
    # Order matters: passes are executed in the sequence they are added.
    annotation_passes = (
        ReduceDynamicRange(),
        RecomposePixelUnshuffle(quantization_capture=True),
        ReplaceArangeArgs(),
        DecomposeScaledDotProductAttention(),
        DecomposeSilu(),
        DecomposeEinsum(),
        DecomposeExpM1(),
        DecomposeLinalgVectorNorm(quantization_capture=True),
        ReplaceInfValues(),
        LiftConstantScalarOperands(),
    )
    for annotation_pass in annotation_passes:
        self.add_pass(annotation_pass)
    return self._transform(graph_module)
36 changes: 36 additions & 0 deletions backends/qualcomm/_passes/remove_0d_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class Remove0DTensor(ExportPass):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @cccclai,
Thanks for reviewing the PR.
I believe the pass you are suggesting changes the input from a 0D to a 1D tensor.
However, in our case, the 0D tensor is produced by a select op in the middle of the graph. We simply remove this select op, since doing so does not affect the logic of the graph.
The exact point where the 0D tensor occurs in mimi is in moshi/quantization/core_vq.py, where it tries to create a 1D tensor and retrieve index 0.
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I guess for this specific logic, it's more efficient to remove the select op.

A more generic way to handle a 0-d tensor is to convert it to a 1-d tensor, like the following:

class Rank0ToRank1Pass(ExportPass):
    """
    Reshape the recorded value of rank-0 (0-d) tensors so they are no longer 0-d.

    Only ``call_function`` nodes whose ``meta["val"]`` is a single tensor with
    shape ``()`` are touched; the graph structure itself is not changed.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            # we can make the pass configurable to decide what kind of node we want to reshape to 1-d tensor
            if node.op != "call_function":
                continue
            val = node.meta.get("val")
            # meta["val"] may be absent or not a single tensor (e.g. for
            # multi-output ops), in which case there is no shape to inspect.
            if isinstance(val, torch.Tensor) and val.shape == ():
                # NOTE(review): reshape(1, 1) yields a rank-2 value although the
                # stated intent is rank-1 - confirm whether reshape(1) was meant.
                node.meta["val"] = val.reshape(1, 1)
        graph_module.recompile()
        return PassResult(graph_module, True)

"""
QNN does not allow 0D tensors, so we remove any node that would output a 0D tensor.
Before adding operations to the list of nodes to be removed, please ensure that it will not change the logic.
"""

remove_ops = {
exir_ops.edge.aten.select.int,
exir_ops.edge.aten.select_copy.int,
}

def __init__(self, quantization_capture=False) -> None:
    """Set up the pass; ``quantization_capture`` is accepted but not used here."""
    super(Remove0DTensor, self).__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
    """
    Drop every node in ``remove_ops`` whose recorded output is a 0-d tensor.

    Users of a removed node are rewired to that node's first argument.
    The pass always reports the graph as modified.
    """
    fx_graph = graph_module.graph
    for candidate in fx_graph.nodes:
        if candidate.target not in self.remove_ops:
            continue
        if len(candidate.meta["val"].shape) != 0:
            continue
        # Bypass the node: consumers read its input directly.
        source = candidate.args[0]
        for consumer in list(candidate.users.keys()):
            consumer.replace_input_with(candidate, source)
        fx_graph.erase_node(candidate)

    fx_graph.eliminate_dead_code()
    graph_module.recompile()
    return PassResult(graph_module, True)
5 changes: 2 additions & 3 deletions backends/qualcomm/partition/qnn_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
not_supported_operator,
to_be_implemented_operator,
)
from .utils import generate_qnn_executorch_option, get_skip_decomp_table
from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table


class QnnOperatorSupport(OperatorSupportBase):
Expand Down Expand Up @@ -181,5 +181,4 @@ def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
do_not_decompose = get_skip_decomp_table()

return do_not_decompose, None
return (do_not_decompose, filter_fn)
17 changes: 16 additions & 1 deletion backends/qualcomm/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ def generate_qnn_executorch_option(
return qnn_compile_spec_buffer


# Decides per node whether decomposition should be skipped; this takes priority over get_skip_decomp_table().
def filter_fn(node: torch.fx.Node) -> bool:
    """
    Tell whether ``node`` passes the int32/int64 output filter.

    Args:
        node: graph node whose ``target`` and ``meta["val"]`` are inspected.

    Returns:
        False when the node is a stack/unbind op and any of its recorded
        output values has dtype int32 or int64; True otherwise.

    Raises:
        KeyError: if a stack/unbind node carries no ``meta["val"]``.
    """
    # QNN does not support int32/int64 IO for the following OPs.
    potential_i32_i64_io_ops = (
        torch.ops.aten.stack.default,
        torch.ops.aten.unbind.int,
    )
    if node.target not in potential_i32_i64_io_ops:
        return True

    val = node.meta["val"]
    # A multi-output op such as unbind records a sequence of values rather
    # than a single tensor, so normalize to a sequence before reading dtypes.
    outputs = val if isinstance(val, (list, tuple)) else (val,)
    unsupported_dtypes = (torch.int32, torch.int64)
    return not any(
        getattr(out, "dtype", None) in unsupported_dtypes for out in outputs
    )


def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
do_not_decompose = [
torch.ops.aten.adaptive_avg_pool2d.default,
Expand All @@ -41,7 +56,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
torch.ops.aten.stack.default,
torch.ops.aten.upsample_bicubic2d.vec,
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
# torch.ops.aten.unbind.int,
torch.ops.aten.unbind.int,
torch.ops.pt2e_quant.quantize_affine.default,
torch.ops.pt2e_quant.dequantize_affine.default,
]
Expand Down
Loading
Loading