6
6
7
7
# pyre-unsafe
8
8
9
- from typing import Callable , Dict
9
+ from typing import Callable , cast , Dict , Set
10
10
11
11
import torch
12
12
from executorch .backends .arm ._passes .arm_pass_utils import create_node
13
13
from executorch .backends .arm .tosa_quant_utils import QuantArgs
14
+ from executorch .backends .transforms .utils import delete_constant_placeholder
14
15
from executorch .exir import ExportedProgram
15
16
16
17
from executorch .exir .dialects ._ops import ops as exir_ops
17
18
from executorch .exir .dialects .edge ._ops import EdgeOpOverload
18
19
19
20
from executorch .exir .pass_base import ExportPass , PassResult
20
21
from torch .fx import GraphModule
22
+ from torch .fx .node import Node
21
23
from torch .library import impl , Library
22
24
23
25
lib = Library ("tosa" , "DEF" )
@@ -29,6 +31,59 @@ def _table_impl(*args, **kwargs): # pyre-ignore
29
31
return args [0 ]
30
32
31
33
34
+ class TableOps :
35
+ """
36
+ Helper class for finding the corresponding table operator for a given Node.
37
+ """
38
+
39
+ def __init__ (self , exported_program : ExportedProgram ):
40
+ self .exported_program = exported_program
41
+
42
+ # Targets that follow a straigtforward one-to-one mapping to their table op
43
+ self .unary_table_ops : Dict [
44
+ EdgeOpOverload , Callable [[torch .Tensor ], torch .Tensor ]
45
+ ] = {
46
+ exir_ops .edge .aten .exp .default : torch .exp ,
47
+ exir_ops .edge .aten .floor .default : torch .floor ,
48
+ exir_ops .edge .aten .log .default : torch .log ,
49
+ exir_ops .edge .aten .reciprocal .default : torch .reciprocal ,
50
+ exir_ops .edge .aten .rsqrt .default : torch .rsqrt ,
51
+ exir_ops .edge .aten .sigmoid .default : torch .sigmoid ,
52
+ exir_ops .edge .aten .tanh .default : torch .tanh ,
53
+ exir_ops .edge .aten .hardsigmoid .default : torch .nn .functional .hardsigmoid ,
54
+ exir_ops .edge .aten .hardswish .default : torch .nn .functional .hardswish ,
55
+ }
56
+
57
+ # Targets that must be treated explicitly
58
+ self .special_table_ops : Set [EdgeOpOverload ] = {
59
+ exir_ops .edge .aten .pow .Tensor_Tensor ,
60
+ }
61
+
62
+ def __contains__ (self , node : Node ) -> bool :
63
+ return (
64
+ node .target in self .unary_table_ops or node .target in self .special_table_ops
65
+ )
66
+
67
+ def __getitem__ (self , node : Node ):
68
+ target = cast (EdgeOpOverload , node .target )
69
+ if target in self .unary_table_ops :
70
+ return self .unary_table_ops [target ]
71
+ elif target in self .special_table_ops :
72
+ match target :
73
+ case exir_ops .edge .aten .pow .Tensor_Tensor :
74
+ # Exponent is a constant. Retrieve it from the graph and embed it into a lambda.
75
+ exp_node = cast (Node , node .args [1 ])
76
+ exp_name = self .exported_program .graph_signature .inputs_to_buffers [
77
+ exp_node .name
78
+ ]
79
+ exp = self .exported_program .state_dict [exp_name ]
80
+ return lambda x : torch .pow (x , exp ).flatten ()
81
+ case _:
82
+ raise NotImplementedError ("Unhandled table operation" )
83
+ else :
84
+ raise KeyError ("Table op for {target} does not exist" )
85
+
86
+
32
87
class InsertTableOpsPass (ExportPass ):
33
88
"""
34
89
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
@@ -37,21 +92,10 @@ class InsertTableOpsPass(ExportPass):
37
92
which will be used to produce the table values in operators/op_table.py.
38
93
"""
39
94
40
- table_ops : Dict [EdgeOpOverload , Callable [[torch .Tensor ], torch .Tensor ]] = {
41
- exir_ops .edge .aten .exp .default : torch .exp ,
42
- exir_ops .edge .aten .floor .default : torch .floor ,
43
- exir_ops .edge .aten .log .default : torch .log ,
44
- exir_ops .edge .aten .reciprocal .default : torch .reciprocal ,
45
- exir_ops .edge .aten .rsqrt .default : torch .rsqrt ,
46
- exir_ops .edge .aten .sigmoid .default : torch .sigmoid ,
47
- exir_ops .edge .aten .tanh .default : torch .tanh ,
48
- exir_ops .edge .aten .hardsigmoid .default : torch .nn .functional .hardsigmoid ,
49
- exir_ops .edge .aten .hardswish .default : torch .nn .functional .hardswish ,
50
- }
51
-
52
95
def __init__ (self , exported_program : ExportedProgram ) -> None :
53
96
super ().__init__ ()
54
97
self .exported_program = exported_program
98
+ self .table_ops = TableOps (exported_program )
55
99
56
100
def register_buffer (self , buffer_name : str , buffer : torch .Tensor ) -> None :
57
101
"""
@@ -86,7 +130,7 @@ def f(x: torch.Tensor) -> torch.Tensor:
86
130
def call (self , graph_module : GraphModule ) -> PassResult :
87
131
modified = False
88
132
for node in graph_module .graph .nodes :
89
- if node .op != "call_function" or node . target not in self .table_ops :
133
+ if node .op != "call_function" or node not in self .table_ops :
90
134
continue
91
135
input_qparams = node .meta ["input_qparams" ]
92
136
output_qparams = node .meta ["output_qparams" ]
@@ -104,7 +148,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
104
148
assert len (output_qparams ) == 1
105
149
# Generate table buffer
106
150
buffer = self .generate_table_values (
107
- torch_op = self .table_ops [node . target ],
151
+ torch_op = self .table_ops [node ],
108
152
in_quantargs = input_qparams [0 ],
109
153
out_quantargs = output_qparams [0 ],
110
154
)
@@ -115,7 +159,19 @@ def call(self, graph_module: GraphModule) -> PassResult:
115
159
buffer_name = table_node .name .replace ("_default" , "" ), buffer = buffer
116
160
)
117
161
node .replace_all_uses_with (table_node )
118
- graph_module .graph .erase_node (node )
162
+
163
+ if node .target in self .table_ops .special_table_ops :
164
+ # The node must be treated explicitly
165
+ match node .target :
166
+ case exir_ops .edge .aten .pow .Tensor_Tensor :
167
+ exp_node = node .args [1 ]
168
+ graph_module .graph .erase_node (node )
169
+ delete_constant_placeholder (self .exported_program , exp_node )
170
+ case _:
171
+ raise NotImplementedError ("Unhandled table operation" )
172
+ else :
173
+ graph_module .graph .erase_node (node )
174
+
119
175
table_node .meta ["input_qparams" ] = input_qparams
120
176
table_node .meta ["output_qparams" ] = output_qparams
121
177
modified = True
0 commit comments