@@ -11,7 +11,7 @@

import torch
from executorch.exir.scalar_type import ScalarType
- from torch.library import impl, Library
+ from torch.library import Library, register_fake

from .utils import get_conv1d_output_size, get_conv2d_output_size

@@ -69,7 +69,7 @@
m = Library("cadence", "IMPL", "Meta")


- @impl(m, "quantize_per_tensor")
+ @register_fake("cadence::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
@@ -81,7 +81,7 @@ def quantize_per_tensor_meta(
    return input.new_empty(input.size(), dtype=dtype)


- @impl(m, "dequantize_per_tensor")
+ @register_fake("cadence::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
@@ -93,7 +93,7 @@ def dequantize_per_tensor_meta(
    return input.new_empty(input.size(), dtype=torch.float)


- @impl(m, "quantized_linear")
+ @register_fake("cadence::quantized_linear")
def quantized_linear_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
@@ -115,7 +115,7 @@ def quantized_linear_meta(
    return src.new_empty(out_size, dtype=torch.uint8)


- @impl(m, "quantized_conv")
+ @register_fake("cadence::quantized_conv")
def quantized_conv_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
@@ -153,7 +153,7 @@ def quantized_conv_meta(
    return input.new_empty(output_size, dtype=input.dtype)


- @impl(m, "quantized_layer_norm")
+ @register_fake("cadence::quantized_layer_norm")
def quantized_layer_norm_meta(
    input: torch.Tensor,
    X_scale: torch.Tensor,
@@ -168,7 +168,7 @@ def quantized_layer_norm_meta(
    return input.new_empty(input.size(), dtype=torch.uint8)


- @impl(m, "quantized_relu")
+ @register_fake("cadence::quantized_relu")
def quantized_relu_meta(
    X: torch.Tensor,
    X_zero_point: torch.Tensor,
@@ -179,7 +179,7 @@ def quantized_relu_meta(
    return X.new_empty(X.size(), dtype=torch.uint8)


- @impl(m, "quantized_matmul")
+ @register_fake("cadence::quantized_matmul")
def quantized_matmul_meta(
    X: torch.Tensor,
    X_zero_point: int,
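Every hunk above makes the same mechanical swap, so a single sketch illustrates the pattern being migrated to. This is a minimal, self-contained example, assuming PyTorch ≥ 2.4, where `torch.library.register_fake` (the renamed `impl_abstract`) is available; the `mylib::my_quantize` op and its schema are hypothetical stand-ins, not code from this PR:

```python
import torch
from torch.library import Library, register_fake

# Define a custom op schema in its own namespace (normally done once,
# alongside the real backend kernel registrations).
lib = Library("mylib", "DEF")
lib.define("my_quantize(Tensor input, float scale, int zero_point) -> Tensor")

# Old style (removed by this diff):
#   m = Library("mylib", "IMPL", "Meta")
#   @impl(m, "my_quantize")
# New style (added by this diff): reference the op by "namespace::name".
@register_fake("mylib::my_quantize")
def my_quantize_fake(input: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # A fake kernel only propagates metadata (shape/dtype/device);
    # it must not read or write real tensor data.
    return input.new_empty(input.size(), dtype=torch.uint8)

# Under FakeTensorMode (what torch.export / torch.compile trace with),
# calling the op dispatches to the fake kernel above.
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(2, 3)
    y = torch.ops.mylib.my_quantize(x, 0.1, 0)
    assert y.shape == (2, 3) and y.dtype == torch.uint8
```

The practical difference: `@impl(m, ...)` needs a `Library(..., "IMPL", "Meta")` handle at each registration site, while `register_fake` binds the fake kernel to the op's qualified name directly, which is the style the current PyTorch custom-ops documentation recommends.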