13
13
from executorch .backends .transforms .addmm_mm_to_linear import (
14
14
apply_addmm_mm_to_linear_transform ,
15
15
)
16
- from executorch .backends .xnnpack .passes .xnnpack_pass import XNNPACKPass
17
- from executorch .backends .xnnpack .utils .utils import is_param_node
18
16
from executorch .exir .dialects ._ops import ops as exir_ops
17
+ from executorch .exir .pass_base import ExportPass
19
18
20
19
from torch .fx .passes .infra .pass_base import PassResult
21
20
from torch .fx .passes .utils .source_matcher_utils import (
27
26
logger .setLevel (logging .WARNING )
28
27
29
28
30
- class ConvertToLinearPass (XNNPACKPass ):
29
+ class ConvertToLinearPass (ExportPass ):
31
30
linear_modules = [
32
31
torch .nn .Linear ,
33
32
torch .nn .functional .linear ,
@@ -71,28 +70,24 @@ def get_arg(node: torch.fx.Node, arg: str):
71
70
map_ = {"input" : 0 , "weight" : 1 }
72
71
return None if arg == "bias" else node .args [map_ [arg ]]
73
72
74
- def find_bias_for_mm (self , src_partition : SourcePartition , weight : torch .fx .Node ):
73
+ def find_bias_for_mm (self , src_partition : SourcePartition , mm_node : torch .fx .Node ):
75
74
"""
76
75
For linear decomposed with mm + add, find bias in src partition
77
76
"""
78
- out_channels = get_shape (weight )[0 ]
79
- bias = None
80
-
81
- # Try to find bias node in all nodes
82
- for node in src_partition .nodes :
83
- if is_param_node (self .exported_program , node ) and node != weight :
84
- bias = node
85
-
86
- if bias is not None :
87
- assert get_shape (bias ) == [
88
- out_channels
89
- ], f"Expected bias shape { [out_channels ]} but got { get_shape (bias )} "
90
- else :
91
- assert exir_ops .edge .aten .add .Tensor not in [
92
- node .target for node in src_partition .nodes
93
- ], f"Expecting to find bias for Linear module: { src_partition } but could not find it"
94
77
95
- return bias
78
+ mm_users = list (mm_node .users .keys ())
79
+ if len (mm_users ) != 1 :
80
+ return None
81
+
82
+ add_node = mm_users [0 ]
83
+ if add_node .target != exir_ops .edge .aten .add .Tensor :
84
+ return None
85
+
86
+ for arg in add_node .all_input_nodes :
87
+ if arg != mm_node and arg in src_partition .input_nodes :
88
+ return arg
89
+
90
+ return None
96
91
97
92
def create_linear (
98
93
self ,
@@ -119,7 +114,7 @@ def create_linear(
119
114
src_partition .input_nodes + src_partition .params , # bias can be in params
120
115
)
121
116
if linear_bias is None and node .target == exir_ops .edge .aten .mm .default :
122
- linear_bias = self .find_bias_for_mm (src_partition , linear_weight )
117
+ linear_bias = self .find_bias_for_mm (src_partition , node )
123
118
124
119
logger .debug (f"Found bias(?): { linear_bias } from node { node } " )
125
120
0 commit comments