
Commit 126abb5

Update the API of registering fake kernels to the new standard (#5084)
Differential Revision: D62206602
Pull Request resolved: #5190
1 parent 1eeded1 commit 126abb5

File tree

1 file changed: +8 -8 lines changed


backends/cadence/aot/ops_registrations.py

Lines changed: 8 additions & 8 deletions
@@ -10,7 +10,7 @@
 from typing import Optional, Tuple
 
 import torch
-from torch.library import impl, Library
+from torch.library import Library, register_fake
 
 from .utils import get_conv1d_output_size, get_conv2d_output_size
 
@@ -68,7 +68,7 @@
 m = Library("cadence", "IMPL", "Meta")
 
 
-@impl(m, "quantize_per_tensor")
+@register_fake("cadence::quantize_per_tensor")
 def quantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -80,7 +80,7 @@ def quantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=dtype)
 
 
-@impl(m, "dequantize_per_tensor")
+@register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -92,7 +92,7 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
-@impl(m, "quantized_linear")
+@register_fake("cadence::quantized_linear")
 def quantized_linear_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
@@ -114,7 +114,7 @@ def quantized_linear_meta(
     return src.new_empty(out_size, dtype=torch.uint8)
 
 
-@impl(m, "quantized_conv")
+@register_fake("cadence::quantized_conv")
 def quantized_conv_meta(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -152,7 +152,7 @@ def quantized_conv_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
-@impl(m, "quantized_layer_norm")
+@register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
@@ -167,7 +167,7 @@ def quantized_layer_norm_meta(
     return input.new_empty(input.size(), dtype=torch.uint8)
 
 
-@impl(m, "quantized_relu")
+@register_fake("cadence::quantized_relu")
 def quantized_relu_meta(
     X: torch.Tensor,
     X_zero_point: torch.Tensor,
@@ -178,7 +178,7 @@ def quantized_relu_meta(
     return X.new_empty(X.size(), dtype=torch.uint8)
 
 
-@impl(m, "quantized_matmul")
+@register_fake("cadence::quantized_matmul")
 def quantized_matmul_meta(
     X: torch.Tensor,
     X_zero_point: int,

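For context on the migration: the old pattern registered meta kernels on a torch.library.Library handle bound to the "Meta" dispatch key via @impl(m, "op_name"), while the newer torch.library.register_fake decorator (available in recent PyTorch releases) takes the fully qualified "namespace::op_name" string directly and needs no IMPL library handle. The sketch below illustrates the same swap on a hypothetical demo::scale_to_uint8 op; the op name, schema, and the FakeTensorMode check are illustrative assumptions, not part of this commit.

import torch
from torch.library import Library, register_fake
from torch._subclasses.fake_tensor import FakeTensorMode

# Define a toy custom op so the sketch is self-contained; the real cadence
# op schemas are declared elsewhere in the ExecuTorch build.
lib = Library("demo", "DEF")
lib.define("scale_to_uint8(Tensor x, float scale) -> Tensor")

# Old style (what this commit removes):
#   m = Library("demo", "IMPL", "Meta")
#   @impl(m, "scale_to_uint8")
#   def scale_to_uint8_meta(x, scale): ...
#
# New style (what this commit adds): register_fake takes the qualified name.
@register_fake("demo::scale_to_uint8")
def scale_to_uint8_fake(x: torch.Tensor, scale: float) -> torch.Tensor:
    # A fake kernel only describes output metadata (shape, dtype, device);
    # it never touches real data.
    return x.new_empty(x.size(), dtype=torch.uint8)

# Tracing-time check: under FakeTensorMode (used by export/compile tracing),
# the registered fake kernel supplies the output shape and dtype.
with FakeTensorMode():
    x = torch.empty(4, 8)
    out = torch.ops.demo.scale_to_uint8(x, 0.5)
    assert out.dtype == torch.uint8 and out.shape == (4, 8)

The diff itself only swaps the decorators; each meta function's body, and the shape and dtype it reports, is unchanged.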
0 commit comments
