Skip to content

Make convert to linear an export pass #5133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions backends/xnnpack/passes/convert_to_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from executorch.backends.transforms.addmm_mm_to_linear import (
apply_addmm_mm_to_linear_transform,
)
from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.source_matcher_utils import (
Expand All @@ -27,7 +26,7 @@
logger.setLevel(logging.WARNING)


class ConvertToLinearPass(XNNPACKPass):
class ConvertToLinearPass(ExportPass):
linear_modules = [
torch.nn.Linear,
torch.nn.functional.linear,
Expand Down Expand Up @@ -71,28 +70,24 @@ def get_arg(node: torch.fx.Node, arg: str):
map_ = {"input": 0, "weight": 1}
return None if arg == "bias" else node.args[map_[arg]]

def find_bias_for_mm(self, src_partition: SourcePartition, weight: torch.fx.Node):
def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node):
"""
For linear decomposed with mm + add, find bias in src partition
"""
out_channels = get_shape(weight)[0]
bias = None

# Try to find bias node in all nodes
for node in src_partition.nodes:
if is_param_node(self.exported_program, node) and node != weight:
bias = node

if bias is not None:
assert get_shape(bias) == [
out_channels
], f"Expected bias shape {[out_channels]} but got {get_shape(bias)}"
else:
assert exir_ops.edge.aten.add.Tensor not in [
node.target for node in src_partition.nodes
], f"Expecting to find bias for Linear module: {src_partition} but could not find it"

return bias
mm_users = list(mm_node.users.keys())
if len(mm_users) != 1:
return None

add_node = mm_users[0]
if add_node.target != exir_ops.edge.aten.add.Tensor:
return None

for arg in add_node.all_input_nodes:
if arg != mm_node and arg in src_partition.input_nodes:
return arg

return None

def create_linear(
self,
Expand All @@ -119,7 +114,7 @@ def create_linear(
src_partition.input_nodes + src_partition.params, # bias can be in params
)
if linear_bias is None and node.target == exir_ops.edge.aten.mm.default:
linear_bias = self.find_bias_for_mm(src_partition, linear_weight)
linear_bias = self.find_bias_for_mm(src_partition, node)

logger.debug(f"Found bias(?): {linear_bias} from node {node}")

Expand Down
Loading