13
13
from executorch .backends .transforms .addmm_mm_to_linear import (
14
14
apply_addmm_mm_to_linear_transform ,
15
15
)
16
- from executorch .backends .xnnpack .passes .xnnpack_pass import XNNPACKPass
17
16
from executorch .backends .xnnpack .utils .utils import is_param_node
18
17
from executorch .exir .dialects ._ops import ops as exir_ops
18
+ from executorch .exir .pass_base import ExportPass
19
19
20
20
from torch .fx .passes .infra .pass_base import PassResult
21
21
from torch .fx .passes .utils .source_matcher_utils import (
27
27
logger .setLevel (logging .WARNING )
28
28
29
29
30
- class ConvertToLinearPass (XNNPACKPass ):
30
+ class ConvertToLinearPass (ExportPass ):
31
31
linear_modules = [
32
32
torch .nn .Linear ,
33
33
torch .nn .functional .linear ,
@@ -71,28 +71,25 @@ def get_arg(node: torch.fx.Node, arg: str):
71
71
map_ = {"input" : 0 , "weight" : 1 }
72
72
return None if arg == "bias" else node .args [map_ [arg ]]
73
73
74
- def find_bias_for_mm (self , src_partition : SourcePartition , weight : torch .fx .Node ):
74
+ def find_bias_for_mm (self , src_partition : SourcePartition , mm_node : torch .fx .Node ):
75
75
"""
76
76
For linear decomposed with mm + add, find bias in src partition
77
77
"""
78
- out_channels = get_shape (weight )[0 ]
79
- bias = None
80
-
81
- # Try to find bias node in all nodes
82
- for node in src_partition .nodes :
83
- if is_param_node (self .exported_program , node ) and node != weight :
84
- bias = node
85
-
86
- if bias is not None :
87
- assert get_shape (bias ) == [
88
- out_channels
89
- ], f"Expected bias shape { [out_channels ]} but got { get_shape (bias )} "
90
- else :
91
- assert exir_ops .edge .aten .add .Tensor not in [
92
- node .target for node in src_partition .nodes
93
- ], f"Expecting to find bias for Linear module: { src_partition } but could not find it"
94
78
95
- return bias
79
+ mm_users = list (mm_node .users .keys ())
80
+ if len (mm_users ) != 1 :
81
+ return None
82
+
83
+ add_node = mm_users [0 ]
84
+ if add_node .target != exir_ops .edge .aten .add .Tensor :
85
+ return None
86
+
87
+ for arg in add_node .all_input_nodes :
88
+ if arg != mm_node and arg in src_partition .input_nodes :
89
+ return arg
90
+
91
+ return None
92
+
96
93
97
94
def create_linear (
98
95
self ,
@@ -119,7 +116,7 @@ def create_linear(
119
116
src_partition .input_nodes + src_partition .params , # bias can be in params
120
117
)
121
118
if linear_bias is None and node .target == exir_ops .edge .aten .mm .default :
122
- linear_bias = self .find_bias_for_mm (src_partition , linear_weight )
119
+ linear_bias = self .find_bias_for_mm (src_partition , node )
123
120
124
121
logger .debug (f"Found bias(?): { linear_bias } from node { node } " )
125
122
0 commit comments