Qualcomm AI Engine Direct - ConvFormer Enablement #6654

Merged: 1 commit, Jan 30, 2025

85 changes: 61 additions & 24 deletions backends/qualcomm/_passes/fuse_consecutive_transpose.py
@@ -15,8 +15,18 @@

class FuseConsecutiveTranspose(ExportPass):
"""
This pass fuses consecutive transpose / permute into one to reduce runtime
overhead
This pass fuses consecutive transpose / permute into one or none to reduce runtime
overhead.
To simplify the fusion logic, we ensure each permute node's output feeds at most one permute node
by cloning transposes.
Example:
Before clone transpose:
relu -> permute1 ─> permute2
|──────> permute3

After clone transpose:
relu ─> permute1 ──────> permute2
|───> permute4(new) ─> permute3
"""

def __init__(self):
@@ -27,54 +37,81 @@ def __init__(self):
self.visited = set()
self.nodes = []

def _clone_transpose(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
graph = graph_module.graph
for n in graph_module.graph.nodes:
if n.target in self.op_map:
users = [user for user in list(n.users) if user.target in self.op_map]
if len(users) > 1:
for i in range(1, len(users)):
with graph.inserting_after(n):
clone_permute_node = graph.create_node(
"call_function",
exir_ops.edge.aten.permute_copy.default,
(n.args[0], n.args[1]),
)
clone_permute_node.meta = n.meta
users[i].replace_input_with(n, clone_permute_node)

def _is_dispensable(self, axis_order):
for index, value in enumerate(axis_order):
if index != value:
return False
return True

def _traverse(self, node):
if node in self.visited or node.target not in self.op_map:
return

self.nodes.append(node)
self.visited.add(node)
next_users = [n for n in list(node.users) if n.target in self.op_map]

assert (
len(next_users) <= 1
), "Each permute node should have at most 1 permute output node after _clone_transpose"
if not next_users:
return

if len(next_users) == 1:
self._traverse(list(node.users)[0])
else:
raise NotImplementedError(
f"Check the node {node}, wich encounter mutilple permute output case"
)
self._traverse(list(node.users)[0])

def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
graph = graph_module.graph
for n in graph_module.graph.nodes:
self._traverse(n)
if len(self.nodes) > 1:
permute_order = []
input_node, output_node = self.nodes[0].args[0], self.nodes[-1]
input_shape = input_node.meta["val"].shape
axis_order = torch.arange(len(input_shape)).tolist()
for node in self.nodes:
permute_order.append(node.args[1])
axis_order = [axis_order[i] for i in node.args[1]]
with graph.inserting_after(input_node):
permute_op = exir_ops.edge.aten.permute_copy.default
permute_node = graph.create_node(
"call_function", permute_op, (input_node, axis_order)
)
users = output_node.users.copy()
for user in users:
user.replace_input_with(output_node, permute_node)

# copy metadata
permute_node.meta = output_node.meta
# Without "qnn_permute", we might obtain wrong input shape
if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
permute_node.meta[QCOM_INSERTED_PERMUTE] = True
# If axis order is just [0,1,2,3], we ignore permute node
if self._is_dispensable(axis_order):
for user in output_node.users.copy():
user.replace_input_with(output_node, n.args[0])
else:
with graph.inserting_after(input_node):
permute_op = exir_ops.edge.aten.permute_copy.default
permute_node = graph.create_node(
"call_function", permute_op, (input_node, axis_order)
)
users = output_node.users.copy()
for user in users:
user.replace_input_with(output_node, permute_node)

# copy metadata
permute_node.meta = output_node.meta
# Without "qnn_permute", we might obtain wrong input shape
if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
permute_node.meta[QCOM_INSERTED_PERMUTE] = True

# clear current stack
self.nodes = []

def call(self, graph_module: torch.fx.GraphModule):
self._clone_transpose(graph_module)
self._fuse(graph_module)
graph_module.recompile()
dead_code_elimination_pass(graph_module)
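
For reference, here is a minimal sketch (not part of the pass; shapes and orders are illustrative) of the axis-order composition the `_fuse` step relies on: applying two permutes back to back is equivalent to one permute whose order is obtained by indexing the accumulated order with each node's order, exactly as in `axis_order = [axis_order[i] for i in node.args[1]]`.

```python
import torch

# Two consecutive permutes the pass would fuse: NCHW -> NHWC, then NHWC -> NCHW.
x = torch.randn(1, 3, 8, 8)
p1 = [0, 2, 3, 1]
p2 = [0, 3, 1, 2]

# Compose the orders the same way _fuse does, starting from the identity order.
axis_order = list(range(x.dim()))
for order in (p1, p2):
    axis_order = [axis_order[i] for i in order]

assert torch.equal(x.permute(p1).permute(p2), x.permute(axis_order))
# Here axis_order == [0, 1, 2, 3], so _is_dispensable returns True and the
# fused permute is dropped entirely instead of being inserted.
```
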
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -30,6 +30,7 @@ class LayoutTransform(ExportPass):
"""

layout_sensitive_ops = {
exir_ops.edge.aten.adaptive_avg_pool2d.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -7,6 +7,7 @@
from . import (
node_visitor,
op_abs,
op_adaptive_avg_pool2d,
op_add,
op_arange,
op_avg_pool2d,
@@ -78,6 +79,7 @@
__all__ = [
node_visitor,
op_abs,
op_adaptive_avg_pool2d,
op_add,
op_arange,
op_avg_pool2d,
125 changes: 125 additions & 0 deletions backends/qualcomm/builders/op_adaptive_avg_pool2d.py
@@ -0,0 +1,125 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class AdaptiveAvgPool2D(NodeVisitor):
target = ["aten.adaptive_avg_pool2d.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:

input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

input_height = input_tensor.shape[1]
input_width = input_tensor.shape[2]

output_height = node.args[1][0]
output_width = node.args[1][1]

filter_height = input_height // output_height
filter_width = input_width // output_width
filter = [filter_height, filter_width]
filter_shape = [len(filter)]

stride_height = filter_height
stride_width = filter_width
stride = [stride_height, stride_width]
stride_shape = [len(stride)]

height = (output_height - 1) * stride_height + filter_height - input_height
width = (output_width - 1) * stride_width + filter_width - input_width
if height % 2 != 0 or width % 2 != 0:
warnings.warn(
"[QNN Delegate Op Builder]: Height or Width is not divisble by 2 with no remainder, fall back op",
stacklevel=1,
)
return

padding_height = height / 2
padding_width = width / 2
padding = [padding_height, padding_width]
padding_shape = [2, 2]

out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

adaptive_avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpPoolAvg2d.op_name,
)

adaptive_avg_pool2d_op.AddInputTensors([input_tensor_wrapper])
adaptive_avg_pool2d_op.AddOutputTensors([output_tensor_wrapper])

adaptive_avg_pool2d_op.AddTensorParam(
OpPoolAvg2d.param_filter_size,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(filter_shape),
filter_shape,
np.array(
filter,
dtype=np.uint32,
),
True,
)

adaptive_avg_pool2d_op.AddTensorParam(
OpPoolAvg2d.param_stride,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(stride_shape),
stride_shape,
np.array(
stride,
dtype=np.uint32,
),
True,
)

adaptive_avg_pool2d_op.AddTensorParam(
OpPoolAvg2d.param_pad_amount,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(padding_shape),
padding_shape,
np.array(
[[padding[0], padding[0]], [padding[1], padding[1]]],
dtype=np.uint32,
),
True,
)

return adaptive_avg_pool2d_op
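
A minimal sketch of the mapping this visitor performs, assuming the input spatial size divides evenly by the requested output size: adaptive average pooling then reduces to a fixed average pool with filter = input // output, stride = filter, and zero padding, which is why the padding formula above evaluates to 0 in that case. The example uses plain NCHW torch tensors for clarity; the builder itself reads height and width from shape[1] and shape[2], since the op is layout-sensitive and sees the NHWC tensor.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 56, 56)  # N, C, H, W (illustrative shape)
out_h, out_w = 7, 7

filter_h, filter_w = x.shape[2] // out_h, x.shape[3] // out_w  # 8, 8
stride_h, stride_w = filter_h, filter_w

# Padding formula from the builder: 0 when the sizes divide evenly.
pad_h = (out_h - 1) * stride_h + filter_h - x.shape[2]
pad_w = (out_w - 1) * stride_w + filter_w - x.shape[3]

adaptive = F.adaptive_avg_pool2d(x, (out_h, out_w))
fixed = F.avg_pool2d(
    x,
    kernel_size=(filter_h, filter_w),
    stride=(stride_h, stride_w),
    padding=(pad_h // 2, pad_w // 2),
)

assert torch.allclose(adaptive, fixed)
```
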
24 changes: 13 additions & 11 deletions backends/qualcomm/builders/op_layer_norm.py
@@ -63,15 +63,19 @@ def define_node(
nodes_to_wrappers,
)

layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]

bias_node = node.args[3]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
)
if bias_node is not None:
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
)
layer_norm_input_tensors.append(bias_tensor_wrapper)

epsilon = node.args[4]

@@ -89,9 +93,7 @@
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpLayerNorm.op_name,
)
layer_norm_op.AddInputTensors(
[input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
)
layer_norm_op.AddInputTensors(layer_norm_input_tensors)
layer_norm_op.AddOutputTensors([output_tensor_wrapper])
layer_norm_op.AddScalarParam(
OpLayerNorm.param_epsilon,
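
For context, a small sketch of why the bias tensor has to be optional here: a LayerNorm built with bias=False exposes no bias parameter, so the bias argument in the exported graph is expected to be None and the builder can only append the bias wrapper when it actually exists.

```python
import torch

ln = torch.nn.LayerNorm([512], eps=1e-6, bias=False)
print(ln.bias)  # None
print([name for name, _ in ln.named_parameters()])  # ['weight'] only
```
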
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_rms_norm.py
@@ -66,7 +66,7 @@ def define_node(
nodes_to_wrappers,
)

# Fake node, nn moudle seems to be inconsistant with document
# Fake node, nn module seems to be inconsistent with document
bias_tensor = torch.zeros(weight_tensor.shape)
bias_node = torch.fx.Node(
node.graph,
5 changes: 5 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
@@ -512,6 +512,11 @@ def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.square.default])
def annotate_square(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.gelu.default])
def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
22 changes: 20 additions & 2 deletions backends/qualcomm/tests/models.py
@@ -16,6 +16,15 @@ def forward(self, x):
return torch.abs(x)


class AdaptiveAvgPool2D(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
return adaptive_avg_pool(x)


class Add(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -685,15 +694,24 @@ def forward(self, x):


class LayerNorm(torch.nn.Module):
def __init__(self):
def __init__(self, bias=True):
super().__init__()
self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
self.linear = torch.nn.Linear(768, 196)

def forward(self, x):
return self.linear(self.layer_norm(x))


class LayerNormAdd(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer_norm = torch.nn.LayerNorm([512], eps=1e-6, bias=False)

def forward(self, x, y):
return self.layer_norm(x) + y


class LeakyReLUDefault(torch.nn.Module):
def __init__(self):
super().__init__()
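
A hypothetical smoke run of the new test modules (the import path and shapes are assumptions for illustration, not taken from the test suite):

```python
import torch

# Assumed module path for the test models shown above.
from executorch.backends.qualcomm.tests.models import AdaptiveAvgPool2D, LayerNormAdd

x = torch.randn(4, 512, 7, 7)
print(AdaptiveAvgPool2D()(x).shape)  # torch.Size([4, 512, 1, 1])

a = torch.randn(2, 197, 512)
b = torch.randn(2, 197, 512)
print(LayerNormAdd()(a, b).shape)  # torch.Size([2, 197, 512])
```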