@@ -10,8 +10,7 @@
from typing import Optional, Tuple

import torch

-from executorch.exir.scalar_type import ScalarType
-from torch.library import impl, Library
+from torch.library import Library, register_fake

from .utils import get_conv1d_output_size, get_conv2d_output_size

@@ -69,7 +68,7 @@
m = Library("cadence", "IMPL", "Meta")


-@impl(m, "quantize_per_tensor")
+@register_fake("cadence::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
@@ -81,7 +80,7 @@ def quantize_per_tensor_meta(
    return input.new_empty(input.size(), dtype=dtype)


-@impl(m, "dequantize_per_tensor")
+@register_fake("cadence::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
@@ -93,7 +92,7 @@ def dequantize_per_tensor_meta(
    return input.new_empty(input.size(), dtype=torch.float)


-@impl(m, "quantized_linear")
+@register_fake("cadence::quantized_linear")
def quantized_linear_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
@@ -115,7 +114,7 @@ def quantized_linear_meta(
    return src.new_empty(out_size, dtype=torch.uint8)


-@impl(m, "quantized_conv")
+@register_fake("cadence::quantized_conv")
def quantized_conv_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
@@ -153,7 +152,7 @@ def quantized_conv_meta(
    return input.new_empty(output_size, dtype=input.dtype)


-@impl(m, "quantized_layer_norm")
+@register_fake("cadence::quantized_layer_norm")
def quantized_layer_norm_meta(
    input: torch.Tensor,
    X_scale: torch.Tensor,
@@ -168,7 +167,7 @@ def quantized_layer_norm_meta(
    return input.new_empty(input.size(), dtype=torch.uint8)


-@impl(m, "quantized_relu")
+@register_fake("cadence::quantized_relu")
def quantized_relu_meta(
    X: torch.Tensor,
    X_zero_point: torch.Tensor,
@@ -179,7 +178,7 @@ def quantized_relu_meta(
    return X.new_empty(X.size(), dtype=torch.uint8)


-@impl(m, "quantized_matmul")
+@register_fake("cadence::quantized_matmul")
def quantized_matmul_meta(
    X: torch.Tensor,
    X_zero_point: int,
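
For reference, a minimal self-contained sketch of the registration pattern this diff migrates to: torch.library.register_fake attaches a fake ("meta") kernel to a custom op so that shape and dtype propagation work under tracing and export without a real backend kernel. The "demo" namespace, op schema, and shapes below are hypothetical illustrations, not part of this PR.

import torch
from torch.library import Library, register_fake

# Define a custom op schema in a hypothetical "demo" namespace.
lib = Library("demo", "DEF")
lib.define("scale_add(Tensor x, float s) -> Tensor")

# Register a fake kernel: it only computes output metadata
# (shape/dtype/device); new_empty allocates no meaningful data.
@register_fake("demo::scale_add")
def scale_add_fake(x: torch.Tensor, s: float) -> torch.Tensor:
    return x.new_empty(x.size(), dtype=x.dtype)

# The fake kernel runs whenever the op is called on fake tensors,
# e.g. inside FakeTensorMode during export-time shape inference.
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
    x = torch.empty(2, 3)
    out = torch.ops.demo.scale_add(x, 0.5)
    assert out.shape == (2, 3)

Unlike the old @impl(m, "op") route, which bound each kernel to a Library handle created with the "Meta" dispatch key, register_fake addresses the op by its qualified "namespace::op" name, which is why the decorators in this diff switch to "cadence::..." strings.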