Skip to content

Commit cd9d536

Browse files
authored
Make convert to linear an export pass
Differential Revision: D62266927 Pull Request resolved: #5133
1 parent b52d4b6 commit cd9d536

File tree

1 file changed

+17
-22
lines changed

1 file changed

+17
-22
lines changed

backends/xnnpack/passes/convert_to_linear.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
from executorch.backends.transforms.addmm_mm_to_linear import (
1414
apply_addmm_mm_to_linear_transform,
1515
)
16-
from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass
17-
from executorch.backends.xnnpack.utils.utils import is_param_node
1816
from executorch.exir.dialects._ops import ops as exir_ops
17+
from executorch.exir.pass_base import ExportPass
1918

2019
from torch.fx.passes.infra.pass_base import PassResult
2120
from torch.fx.passes.utils.source_matcher_utils import (
@@ -27,7 +26,7 @@
2726
logger.setLevel(logging.WARNING)
2827

2928

30-
class ConvertToLinearPass(XNNPACKPass):
29+
class ConvertToLinearPass(ExportPass):
3130
linear_modules = [
3231
torch.nn.Linear,
3332
torch.nn.functional.linear,
@@ -71,28 +70,24 @@ def get_arg(node: torch.fx.Node, arg: str):
7170
map_ = {"input": 0, "weight": 1}
7271
return None if arg == "bias" else node.args[map_[arg]]
7372

74-
def find_bias_for_mm(self, src_partition: SourcePartition, weight: torch.fx.Node):
73+
def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node):
7574
"""
7675
For linear decomposed with mm + add, find bias in src partition
7776
"""
78-
out_channels = get_shape(weight)[0]
79-
bias = None
80-
81-
# Try to find bias node in all nodes
82-
for node in src_partition.nodes:
83-
if is_param_node(self.exported_program, node) and node != weight:
84-
bias = node
85-
86-
if bias is not None:
87-
assert get_shape(bias) == [
88-
out_channels
89-
], f"Expected bias shape {[out_channels]} but got {get_shape(bias)}"
90-
else:
91-
assert exir_ops.edge.aten.add.Tensor not in [
92-
node.target for node in src_partition.nodes
93-
], f"Expecting to find bias for Linear module: {src_partition} but could not find it"
9477

95-
return bias
78+
mm_users = list(mm_node.users.keys())
79+
if len(mm_users) != 1:
80+
return None
81+
82+
add_node = mm_users[0]
83+
if add_node.target != exir_ops.edge.aten.add.Tensor:
84+
return None
85+
86+
for arg in add_node.all_input_nodes:
87+
if arg != mm_node and arg in src_partition.input_nodes:
88+
return arg
89+
90+
return None
9691

9792
def create_linear(
9893
self,
@@ -119,7 +114,7 @@ def create_linear(
119114
src_partition.input_nodes + src_partition.params, # bias can be in params
120115
)
121116
if linear_bias is None and node.target == exir_ops.edge.aten.mm.default:
122-
linear_bias = self.find_bias_for_mm(src_partition, linear_weight)
117+
linear_bias = self.find_bias_for_mm(src_partition, node)
123118

124119
logger.debug(f"Found bias(?): {linear_bias} from node {node}")
125120

0 commit comments

Comments
 (0)