
Qualcomm AI Engine Direct - Mimi Enablement Stage 1 #9570


Merged: 2 commits, Mar 26, 2025
16 changes: 14 additions & 2 deletions backends/qualcomm/_passes/__init__.py
@@ -1,10 +1,18 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .annotate_decomposed import AnnotateDecomposed
from .annotate_quant_attrs import AnnotateQuantAttrs
from .constant_i64_to_i32 import ConstantI64toI32
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_to_linear import ConvertToLinear
from .decompose_any import DecomposeAny
from .decompose_einsum import DecomposeEinsum
from .decompose_expm1 import DecomposeExpM1
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
from .decompose_silu import DecomposeSilu
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
@@ -19,8 +27,9 @@
from .recompose_rms_norm import RecomposeRmsNorm
from .reduce_dynamic_range import ReduceDynamicRange
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_index_put_input import ReplaceIndexPutInput
from .replace_inf_buffer import ReplaceInfBuffer
from .replace_inf_values import ReplaceInfValues
from .tensor_i64_to_i32 import TensorI64toI32


@@ -29,10 +38,12 @@
AnnotateQuantAttrs,
ConstantI64toI32,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
RecomposePReLU,
ConvertToLinear,
DecomposeAny,
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
DecomposeSilu,
ExpandBroadcastTensorShape,
@@ -46,7 +57,8 @@
RecomposeRmsNorm,
ReduceDynamicRange,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
ReplaceInfBuffer,
ReplaceInfValues,
TensorI64toI32,
]
4 changes: 3 additions & 1 deletion backends/qualcomm/_passes/annotate_decomposed.py
@@ -17,6 +17,8 @@ class AnnotateDecomposed(ExportPass):
generated after the quantization process.
"""

decomp_ops = [torch.ops.aten.stack.default, torch.ops.aten.unbind.int]

def __init__(self, edge_program: torch.export.ExportedProgram):
super(AnnotateDecomposed, self).__init__()
self.edge_program = edge_program
@@ -32,7 +34,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

def _annotate_stack(self, graph_module: torch.fx.GraphModule):
partitions = get_source_partitions(graph_module.graph, [torch.stack])
partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"])
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
output = src_partition.output_nodes[0]
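
For context on the widened match, here is a hedged sketch (assuming an `ExportedProgram` named `ep`; the variable name is illustrative, not from the diff). Depending on how the model was captured, a node's `source_fn_stack` may record either the `torch.stack` callable or just the string `"stack"`, so the pass now matches both:

```python
# Illustrative only: enumerate stack partitions the way the pass does,
# matching both the callable and its string name.
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

def list_stack_partitions(ep: torch.export.ExportedProgram):
    partitions = get_source_partitions(ep.graph_module.graph, [torch.stack, "stack"])
    for source, src_partitions in partitions.items():
        for partition in src_partitions:
            # output_nodes[0] is the node the pass annotates with quant attrs
            print(source, partition.output_nodes[0])
```
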
99 changes: 99 additions & 0 deletions backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
@@ -0,0 +1,99 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class ConvertConv1dToConv2d(ExportPass):
"""
Conv1d is not supported by QNN.
Change it to input -> unsqueeze -> conv2d -> squeeze -> output
"""

def __init__(self, edge_program: torch.export.ExportedProgram):
super(ConvertConv1dToConv2d, self).__init__()
self.edge_program = edge_program

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
conv_op = exir_ops.edge.aten.convolution.default
for node in graph.nodes:
if node.target == conv_op and node.meta["val"].dim() == 3:

input_node = node.args[0]
with graph_module.graph.inserting_after(input_node):
unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
unsqueeze_node = graph.create_node(
"call_function",
unsqueeze_op,
(
input_node,
2,
),
)
unsqueeze_node.meta = copy_meta(
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
with graph_module.graph.inserting_after(unsqueeze_node):

filter_node = node.args[1]
filter_node.meta["val"] = (
filter_node.meta["val"].unsqueeze(2).contiguous()
)
filter_tensor = get_parameter(filter_node, self.edge_program)
# Ensure the tensor is an nn.Parameter so the program does not fail during edge_program._validate()
filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
set_parameter(filter_tensor, filter_node, self.edge_program)

bias_node = node.args[2]
stride = [1] + node.args[3]
padding = [0] + node.args[4]
dilation = [1] + node.args[5]
transpose = node.args[6]
output_padding = [0] + node.args[7]
groups = node.args[8]

conv2d_node = graph.create_node(
"call_function",
conv_op,
(
unsqueeze_node,
filter_node,
bias_node,
stride,
padding,
dilation,
transpose,
output_padding,
groups,
),
)
conv2d_node.meta = copy_meta(
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)

with graph_module.graph.inserting_after(conv2d_node):
squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
squeeze_node = graph.create_node(
"call_function",
squeeze_op,
(
conv2d_node,
[2],
),
)
squeeze_node.meta = copy_meta(node.meta)
for user in node.users.copy():
user.replace_input_with(node, squeeze_node)
graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
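
As a quick sanity check of the rewrite (a minimal sketch, not part of the diff): a conv1d over an (N, C, W) tensor is numerically equivalent to unsqueezing to (N, C, 1, W), running conv2d with the stride/padding/dilation prefixed by dummy height values, and squeezing the height back out — the same shape bookkeeping the pass performs on the graph:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16)  # (N, C, W)
w = torch.randn(8, 3, 5)   # (out_channels, in_channels, kW)
b = torch.randn(8)

out_1d = F.conv1d(x, w, b, stride=1, padding=2)

# unsqueeze -> conv2d -> squeeze, mirroring the pass: the dummy height dim
# gets stride 1, padding 0, dilation 1, matching the [1] + ... / [0] + ... args.
out_2d = F.conv2d(
    x.unsqueeze(2),   # (N, C, 1, W)
    w.unsqueeze(2),   # (out_channels, in_channels, 1, kW)
    b,
    stride=(1, 1),
    padding=(0, 2),
).squeeze(2)

assert torch.allclose(out_1d, out_2d, atol=1e-5)
```
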
46 changes: 46 additions & 0 deletions backends/qualcomm/_passes/decompose_expm1.py
@@ -0,0 +1,46 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class DecomposeExpM1(ExportPass):
"""
Decompose expm1 into an exponential followed by a subtraction of 1.
"""

def __init__(self, quantization_capture=False) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
for node in graph.nodes:
if node.target == torch.ops.aten.special_expm1.default:
input_node = node.args[0]
with graph_module.graph.inserting_after(input_node):
exp_op = torch.ops.aten.exp.default
exp_node = graph.create_node("call_function", exp_op, (input_node,))
exp_node.meta = copy_meta(node.meta)
with graph_module.graph.inserting_after(exp_node):
sub_op = torch.ops.aten.sub.Tensor
sub_node = graph.create_node(
"call_function",
sub_op,
(
exp_node,
1,
),
)
sub_node.meta = copy_meta(node.meta)
for user in node.users.copy():
user.replace_input_with(node, sub_node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
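
The identity behind the decomposition is simply expm1(x) = exp(x) - 1; a small eager-mode check (illustrative, not from the diff) is below. One caveat: expm1 exists because it is more accurate than exp(x) - 1 for x near zero, so the decomposition trades a little precision for QNN operator coverage.

```python
import torch

x = torch.randn(8)
# expm1 and its exp-minus-one decomposition agree to numerical tolerance
assert torch.allclose(torch.special.expm1(x), torch.exp(x) - 1, atol=1e-6)
```
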
13 changes: 4 additions & 9 deletions backends/qualcomm/_passes/decompose_silu.py
@@ -3,22 +3,17 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class DecomposeSilu(ExportPass):
def __init__(self):
super(DecomposeSilu, self).__init__()

def _copy_meta(self, meta: Dict):
copied = {}
for k, v in meta.items():
copied[k] = v
return copied

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
for node in graph.nodes:
@@ -34,14 +29,14 @@ def call(self, graph_module: torch.fx.GraphModule):
torch.ops.aten.sigmoid.default,
(silu_node_input,),
)
sigmoid_node.meta = self._copy_meta(silu_node.meta)
sigmoid_node.meta = copy_meta(silu_node.meta)
with graph_module.graph.inserting_after(sigmoid_node):
mul_node = graph.create_node(
"call_function",
torch.ops.aten.mul.Tensor,
(silu_node_input, sigmoid_node),
)
mul_node.meta = self._copy_meta(silu_node.meta)
mul_node.meta = copy_meta(silu_node.meta)
for user in silu_node.users.copy():
user.replace_input_with(silu_node, mul_node)

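
The decomposition relies on the definition silu(x) = x * sigmoid(x), which the sigmoid-then-mul node pair above implements; a one-line check (illustrative) follows.

```python
import torch
import torch.nn.functional as F

x = torch.randn(8)
# silu matches its sigmoid/mul decomposition exactly
assert torch.allclose(F.silu(x), x * torch.sigmoid(x))
```
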
6 changes: 6 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -49,12 +49,15 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ge.Tensor,
@@ -87,10 +90,13 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.stack.default,
exir_ops.edge.aten.topk.default,
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.aten.unbind.int,
exir_ops.edge.aten.where.self,
_operator.getitem,
torch.ops.aten.scalar_tensor.default,
}

layout_type = {
34 changes: 20 additions & 14 deletions backends/qualcomm/_passes/lift_constant_scalar_operands.py
@@ -28,24 +28,27 @@ class TensorConstant:
class TensorOpInfo:
target: torch._ops.OpOverload
use_schema_args: bool
use_self_dtype: bool


SCALAR_OPS = {
aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False),
aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False),
aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False),
aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False),
aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False),
aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False),
aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False),
aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False),
aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False),
aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False),
aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False),
aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False),
aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False),
aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False, False),
aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False, False),
aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False, False),
aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False, False),
aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False, False),
aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False, False),
aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False, False),
aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False, False),
aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False, False),
aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False, False),
aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False, False),
aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False),
aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
# The scalar arg[1] is missing when the default variant is used, resulting in a corner case to handle
aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True),
aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
}


@@ -63,11 +66,14 @@ def __init__(self):
def _build_tensor_constant(
self, gm: torch.fx.GraphModule, node: fx.Node, const_val
) -> TensorConstant:
# In some cases we cannot take the scalar dtype from node.args[0].
# Ex: for the where op, args[0] can be bool, but args[1] and args[2] should usually match the dtype of node.meta["val"] rather than bool
tensor = torch.tensor(
[const_val],
dtype=(
node.args[0].meta["val"].dtype
if not is_float_tensor(node)
and not SCALAR_OPS.get(node.target).use_self_dtype
else node.meta["val"].dtype
),
device=node.meta["val"].device,
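
To illustrate the lifting (a sketch only; the values and names below are illustrative, not the pass's internals): a Scalar overload is replaced by its Tensor overload with the scalar materialized as a constant tensor, and the new `use_self_dtype` flag covers ops like `where`, whose args[0] is a bool condition and therefore cannot supply the scalar's dtype.

```python
import torch

x = torch.randn(4)

# aten.add.Scalar(x, 3) ...
y_scalar = torch.ops.aten.add.Scalar(x, 3)

# ... becomes aten.add.Tensor(x, lifted) with the scalar lifted to a tensor.
lifted = torch.tensor([3], dtype=x.dtype, device=x.device)
y_tensor = torch.ops.aten.add.Tensor(x, lifted)
assert torch.allclose(y_scalar, y_tensor)

# For where-style ops the condition is bool, so the lifted constant must
# follow the output dtype (node.meta["val"]) instead of args[0]'s dtype.
cond = x > 0
out = torch.where(cond, x, 0.0)  # the 0.0 should lift as float, not bool
```
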
48 changes: 48 additions & 0 deletions backends/qualcomm/_passes/replace_arange_args.py
@@ -0,0 +1,48 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class ReplaceArangeArgs(ExportPass):
"""
During annotation, kwargs for arange are removed due to quantizer restrictions.
This leaves arange without a dtype, so FP nodes may be inferred as INT nodes during calibration.
Calibration can then fail, since QDQ can only be applied to FP nodes, not INT nodes.
To hint the dtype, we provide the step size as 1.0 instead of 1, which makes the node an FP node.
"""

def __init__(self, quantization_capture=False) -> None:
super().__init__()
self.quantization_capture = quantization_capture

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
for node in graph.nodes:
if node.target == torch.ops.aten.arange.default:
if torch.is_floating_point(node.meta["val"]) and len(node.args) == 1:
with graph_module.graph.inserting_after(node):
step_arange_op = torch.ops.aten.arange.start_step
step_arange_node = graph.create_node(
"call_function",
step_arange_op,
(
0,
node.args[0],
1.0,
),
)
step_arange_node.meta = copy_meta(node.meta)

for user in node.users.copy():
user.replace_input_with(node, step_arange_node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
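
The dtype hint can be seen directly in eager mode (illustrative check): with integer-only arguments arange infers int64, while a 1.0 step yields the default floating-point dtype — which is what the rewrite from `aten.arange.default` to `aten.arange.start_step` with step 1.0 achieves.

```python
import torch

# Integer-only args infer an integer dtype ...
assert torch.arange(5).dtype == torch.int64
# ... while a float step hints a floating-point dtype.
assert torch.arange(0, 5, 1.0).dtype == torch.get_default_dtype()  # float32 by default
```
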