4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from typing import Any , Dict , Tuple
8
+
7
9
import torch
8
10
from executorch .backends .cadence .aot .utils import get_edge_overload_packet
9
11
from executorch .exir .dialects ._ops import ops as exir_ops
10
- from executorch .exir .pass_base import ExportPass , ProxyValue
12
+ from executorch .exir .pass_base import ExportPass , NodeMetadata , ProxyValue
11
13
from torch ._subclasses import FakeTensor
12
14
from torch .utils ._pytree import tree_map_only
13
15
14
16
17
+ # pyre-strict
18
+
19
+ # Similar to what's done in executorch/exir/pass_base.py
20
+ Argument = Any # pyre-ignore
21
+
22
+
15
23
class ReplacePT2QuantWithCadenceQuantPass (ExportPass ):
16
24
"""
17
25
Replace the pt2 quantization ops with custom cadence quantization ops.
18
26
"""
19
27
20
- def call_operator (self , op , args , kwargs , meta ):
28
+ def call_operator (
29
+ self ,
30
+ op , # pyre-ignore
31
+ args : Tuple [Argument , ...],
32
+ kwargs : Dict [str , Argument ],
33
+ meta : NodeMetadata ,
34
+ ) -> ProxyValue :
21
35
if op not in {exir_ops .edge .quantized_decomposed .quantize_per_tensor .default }:
22
36
return super ().call_operator (op , args , kwargs , meta )
23
37
@@ -34,7 +48,13 @@ class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
34
48
Replace the pt2 dequantization ops with custom cadence dequantization ops.
35
49
"""
36
50
37
- def call_operator (self , op , args , kwargs , meta ):
51
+ def call_operator (
52
+ self ,
53
+ op , # pyre-ignore
54
+ args : Tuple [Argument , ...],
55
+ kwargs : Dict [str , Argument ],
56
+ meta : NodeMetadata ,
57
+ ) -> ProxyValue :
38
58
if op not in {exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default }:
39
59
return super ().call_operator (op , args , kwargs , meta )
40
60
@@ -51,7 +71,13 @@ class ReplaceScalarTensorWithFullPass(ExportPass):
51
71
aten.scalar_tensor can be replaced by aten.full with a shape of [1].
52
72
"""
53
73
54
- def call_operator (self , op , args , kwargs , meta ):
74
+ def call_operator (
75
+ self ,
76
+ op , # pyre-ignore
77
+ args : Tuple [Argument , ...],
78
+ kwargs : Dict [str , Argument ],
79
+ meta : NodeMetadata ,
80
+ ) -> ProxyValue :
55
81
if op not in {
56
82
exir_ops .edge .aten .scalar_tensor .default ,
57
83
torch .ops .aten .scalar_tensor .default ,
@@ -64,7 +90,7 @@ def call_operator(self, op, args, kwargs, meta):
64
90
[1 ],
65
91
args [0 ],
66
92
),
67
- {},
93
+ {"dtype" : torch . float32 },
68
94
meta ,
69
95
)
70
96
@@ -75,7 +101,13 @@ class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
75
101
view_copy op
76
102
"""
77
103
78
- def call_operator (self , op , args , kwargs , meta ):
104
+ def call_operator (
105
+ self ,
106
+ op , # pyre-ignore
107
+ args : Tuple [Argument , ...],
108
+ kwargs : Dict [str , Argument ],
109
+ meta : NodeMetadata ,
110
+ ) -> ProxyValue :
79
111
# Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
80
112
# which allows us to cover all overloads.
81
113
if get_edge_overload_packet (op ) not in {
@@ -99,7 +131,13 @@ def call_operator(self, op, args, kwargs, meta):
99
131
100
132
101
133
class RemoveZeroSizedCatArgsPass (ExportPass ):
102
- def call_operator (self , op , args , kwargs , meta ):
134
+ def call_operator (
135
+ self ,
136
+ op , # pyre-ignore
137
+ args : Tuple [Argument , ...],
138
+ kwargs : Dict [str , Argument ],
139
+ meta : NodeMetadata ,
140
+ ) -> ProxyValue :
103
141
if op != exir_ops .edge .aten .cat .default :
104
142
return super ().call_operator (op , args , kwargs , meta )
105
143
@@ -122,6 +160,7 @@ def call_operator(self, op, args, kwargs, meta):
122
160
# TODO(matthiascremon): confirm this is the best way to do this.
123
161
if isinstance (result , FakeTensor ):
124
162
result .constant = result
163
+ # pyre-ignore[7]: Incompatible return type.
125
164
return torch .empty_like (result )
126
165
127
166
# If there was only one tensor in the new_args list,
@@ -130,7 +169,7 @@ def call_operator(self, op, args, kwargs, meta):
130
169
return new_args [0 ]
131
170
132
171
# Otherwise, we replace args[0] with new_args.
133
- args = list (args )
134
- args [0 ] = new_args
172
+ init_args = list (args )
173
+ init_args [0 ] = new_args
135
174
args = tuple (args )
136
175
return super ().call_operator (op , args , kwargs , meta )
0 commit comments