
Commit 79a5edb

mcr229 authored and facebook-github-bot committed
Make convert to linear an export pass
Summary: Let's remove the convert-to-linear pass's reliance on the exported program so it can be a plain ExportPass. The exported program was only used in one place: to find the bias for an mm node.

Differential Revision: D62266927
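The new lookup no longer consults the exported program's parameters; it walks the graph from the mm node to its single add.Tensor user and returns the add input coming from the partition's inputs. Below is a minimal standalone sketch of that traversal on a plain torch.fx trace, not the pass itself: it uses operator.add and torch.mm targets instead of the exir edge dialect and skips the source-partition membership check.

import operator

import torch
from torch.fx import symbolic_trace


def linear_decomposed(x, weight, bias):
    # Stand-in for a linear that was decomposed into mm + add.
    return torch.mm(x, weight.t()) + bias


def find_bias_for_mm(mm_node):
    # Mirror of the updated traversal: the bias is the other input of the
    # mm node's single add user.
    users = list(mm_node.users.keys())
    if len(users) != 1:
        return None

    add_node = users[0]
    # The real pass checks against exir_ops.edge.aten.add.Tensor; a plain
    # fx trace of `+` produces operator.add instead.
    if add_node.target is not operator.add:
        return None

    for arg in add_node.all_input_nodes:
        # The real pass additionally requires arg to be a partition input.
        if arg is not mm_node:
            return arg
    return None


gm = symbolic_trace(linear_decomposed)
mm_node = next(n for n in gm.graph.nodes if n.target is torch.mm)
print(find_bias_for_mm(mm_node))  # -> the `bias` placeholder node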
1 parent 41ec7fa commit 79a5edb

File tree

1 file changed: +18 −21 lines changed


backends/xnnpack/passes/convert_to_linear.py

Lines changed: 18 additions & 21 deletions
@@ -13,9 +13,9 @@
 from executorch.backends.transforms.addmm_mm_to_linear import (
     apply_addmm_mm_to_linear_transform,
 )
-from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass
 from executorch.backends.xnnpack.utils.utils import is_param_node
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
 
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.utils.source_matcher_utils import (
@@ -27,7 +27,7 @@
 logger.setLevel(logging.WARNING)
 
 
-class ConvertToLinearPass(XNNPACKPass):
+class ConvertToLinearPass(ExportPass):
     linear_modules = [
         torch.nn.Linear,
         torch.nn.functional.linear,
@@ -71,28 +71,25 @@ def get_arg(node: torch.fx.Node, arg: str):
         map_ = {"input": 0, "weight": 1}
         return None if arg == "bias" else node.args[map_[arg]]
 
-    def find_bias_for_mm(self, src_partition: SourcePartition, weight: torch.fx.Node):
+    def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node):
         """
         For linear decomposed with mm + add, find bias in src partition
         """
-        out_channels = get_shape(weight)[0]
-        bias = None
-
-        # Try to find bias node in all nodes
-        for node in src_partition.nodes:
-            if is_param_node(self.exported_program, node) and node != weight:
-                bias = node
-
-        if bias is not None:
-            assert get_shape(bias) == [
-                out_channels
-            ], f"Expected bias shape {[out_channels]} but got {get_shape(bias)}"
-        else:
-            assert exir_ops.edge.aten.add.Tensor not in [
-                node.target for node in src_partition.nodes
-            ], f"Expecting to find bias for Linear module: {src_partition} but could not find it"
 
-        return bias
+        mm_users = list(mm_node.users.keys())
+        if len(mm_users) != 1:
+            return None
+
+        add_node = mm_users[0]
+        if add_node.target != exir_ops.edge.aten.add.Tensor:
+            return None
+
+        for arg in add_node.all_input_nodes:
+            if arg != mm_node and arg in src_partition.input_nodes:
+                return arg
+
+        return None
+
 
     def create_linear(
         self,
@@ -119,7 +116,7 @@ def create_linear(
             src_partition.input_nodes + src_partition.params,  # bias can be in params
         )
         if linear_bias is None and node.target == exir_ops.edge.aten.mm.default:
-            linear_bias = self.find_bias_for_mm(src_partition, linear_weight)
+            linear_bias = self.find_bias_for_mm(src_partition, node)
 
         logger.debug(f"Found bias(?): {linear_bias} from node {node}")
 
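With the base class switched to ExportPass, the pass no longer needs the exported program handed in at construction time. A hypothetical usage sketch follows; the toy model, the to_edge call, and the direct invocation on the edge program's graph module are illustrative assumptions, not taken from this commit.

import torch
from executorch.backends.xnnpack.passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import to_edge

# Export a tiny linear model and lower it to the edge dialect (illustrative only).
model = torch.nn.Linear(4, 8).eval()
exported = torch.export.export(model, (torch.randn(2, 4),))
edge = to_edge(exported)

# The pass can now be constructed without an exported program and run over the
# graph module like any other ExportPass, returning a torch.fx PassResult.
graph_module = edge.exported_program().graph_module
result = ConvertToLinearPass()(graph_module)
print(result.modified)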
