4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ # pyre-strict
8
+
7
9
from math import prod
8
10
from typing import Optional , Tuple
9
11
10
12
import torch
11
- from executorch .exir .scalar_type import ScalarType
12
13
from torch .library import impl , Library
13
14
14
15
from .utils import get_conv1d_output_size , get_conv2d_output_size
@@ -74,8 +75,8 @@ def quantize_per_tensor_meta(
74
75
zero_point : int ,
75
76
quant_min : int ,
76
77
quant_max : int ,
77
- dtype : ScalarType ,
78
- ):
78
+ dtype : torch . dtype ,
79
+ ) -> torch . Tensor :
79
80
return input .new_empty (input .size (), dtype = dtype )
80
81
81
82
@@ -86,8 +87,8 @@ def dequantize_per_tensor_meta(
86
87
zero_point : int ,
87
88
quant_min : int ,
88
89
quant_max : int ,
89
- dtype : ScalarType ,
90
- ):
90
+ dtype : torch . dtype ,
91
+ ) -> torch . Tensor :
91
92
return input .new_empty (input .size (), dtype = torch .float )
92
93
93
94
@@ -102,7 +103,7 @@ def quantized_linear_meta(
102
103
out_shift : torch .Tensor ,
103
104
out_zero_point : int ,
104
105
offset : Optional [torch .Tensor ],
105
- ):
106
+ ) -> torch . Tensor :
106
107
# src comes in shape [leading_dims, in_dim]
107
108
# weight comes in shape [out_dim, in_dim]
108
109
# output comes in empty with shape [leading_dims, out_dim]
@@ -162,7 +163,7 @@ def quantized_layer_norm_meta(
162
163
eps : float ,
163
164
output_scale : float ,
164
165
output_zero_point : int ,
165
- ):
166
+ ) -> torch . Tensor :
166
167
return input .new_empty (input .size (), dtype = torch .uint8 )
167
168
168
169
@@ -173,7 +174,7 @@ def quantized_relu_meta(
173
174
out_zero_point : int ,
174
175
out_multiplier : torch .Tensor ,
175
176
out_shift : torch .Tensor ,
176
- ):
177
+ ) -> torch . Tensor :
177
178
return X .new_empty (X .size (), dtype = torch .uint8 )
178
179
179
180
0 commit comments