@@ -10,7 +10,7 @@
 from typing import Optional, Tuple
 
 import torch
-from torch.library import impl, Library
+from torch.library import Library, register_fake
 
 from .utils import get_conv1d_output_size, get_conv2d_output_size
 
@@ -68,7 +68,7 @@
 m = Library("cadence", "IMPL", "Meta")
 
 
-@impl(m, "quantize_per_tensor")
+@register_fake("cadence::quantize_per_tensor")
 def quantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -80,7 +80,7 @@ def quantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=dtype)
 
 
-@impl(m, "dequantize_per_tensor")
+@register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -92,7 +92,7 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
-@impl(m, "quantized_linear")
+@register_fake("cadence::quantized_linear")
 def quantized_linear_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
@@ -114,7 +114,7 @@ def quantized_linear_meta(
     return src.new_empty(out_size, dtype=torch.uint8)
 
 
-@impl(m, "quantized_conv")
+@register_fake("cadence::quantized_conv")
 def quantized_conv_meta(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -152,7 +152,7 @@ def quantized_conv_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
-@impl(m, "quantized_layer_norm")
+@register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
@@ -167,7 +167,7 @@ def quantized_layer_norm_meta(
     return input.new_empty(input.size(), dtype=torch.uint8)
 
 
-@impl(m, "quantized_relu")
+@register_fake("cadence::quantized_relu")
 def quantized_relu_meta(
     X: torch.Tensor,
     X_zero_point: torch.Tensor,
@@ -178,7 +178,7 @@ def quantized_relu_meta(
     return X.new_empty(X.size(), dtype=torch.uint8)
 
 
-@impl(m, "quantized_matmul")
+@register_fake("cadence::quantized_matmul")
 def quantized_matmul_meta(
     X: torch.Tensor,
     X_zero_point: int,
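For reference, below is a minimal, self-contained sketch of the registration pattern this commit switches to: torch.library.register_fake (available in PyTorch 2.4 and later) attaches the fake/meta kernel directly to an op's qualified name, rather than going through a Library("cadence", "IMPL", "Meta") handle plus @impl. The "demo::" namespace and op here are illustrative stand-ins only, not part of this change; the commit itself just rewires the existing "cadence::" registrations shown above.

import torch

# Declare a custom-op schema (a stand-in for the Cadence quantized ops).
torch.library.define(
    "demo::quantize",
    "(Tensor input, float scale, int zero_point, ScalarType dtype) -> Tensor",
)

# Fake (meta) kernel: only the output's shape and dtype are computed;
# no real data is touched and no backend kernel needs to exist.
@torch.library.register_fake("demo::quantize")
def _(input: torch.Tensor, scale: float, zero_point: int, dtype: torch.dtype) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=dtype)

Fake kernels like these are what tracing and export use to propagate shapes and dtypes, so output metadata can be computed without running any real (on-device) implementation of the op.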