Commit a95808e

Yi Li authored and facebook-github-bot committed
Update the API of registering fake kernels to new standard (#5190)
Summary:
Pull Request resolved: #5190
Pull Request resolved: #5084

Update the decorator functions for registering operator abstract implementations (fake tensors) to the newer API.

Reviewed By: zonglinpeng, hsharma35

Differential Revision: D62206602
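For context, the migration replaces the older pattern of registering meta kernels through a `Library(..., "IMPL", "Meta")` handle and `@impl` with `torch.library.register_fake`. Below is a minimal, illustrative sketch of the two styles; `mylib::my_op` is a hypothetical operator used only for demonstration and is not part of this change.

```python
import torch
from torch.library import Library, register_fake

# Define a toy custom operator so there is something to attach a fake kernel to.
# "mylib::my_op" is purely illustrative and does not exist in ExecuTorch.
lib = Library("mylib", "DEF")
lib.define("my_op(Tensor x, float scale) -> Tensor")

# Old style (what this commit migrates away from): register the meta/abstract
# kernel through a Library handle created with the "Meta" dispatch key, e.g.
#   m = Library("mylib", "IMPL", "Meta")
#   @impl(m, "my_op")
#   def my_op_meta(x, scale): ...

# New style (what this commit adopts): register_fake takes the fully qualified
# "namespace::op" string and no longer needs a separate Meta Library handle.
@register_fake("mylib::my_op")
def my_op_fake(x: torch.Tensor, scale: float) -> torch.Tensor:
    # A fake kernel only describes output metadata (shape/dtype); it never
    # touches real data.
    return x.new_empty(x.size(), dtype=x.dtype)
```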
1 parent b3b9162 commit a95808e

File tree

1 file changed: +8 −9 lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 8 additions & 9 deletions
@@ -10,8 +10,7 @@
 from typing import Optional, Tuple
 
 import torch
-from executorch.exir.scalar_type import ScalarType
-from torch.library import impl, Library
+from torch.library import Library, register_fake
 
 from .utils import get_conv1d_output_size, get_conv2d_output_size
 
@@ -69,7 +68,7 @@
 m = Library("cadence", "IMPL", "Meta")
 
 
-@impl(m, "quantize_per_tensor")
+@register_fake("cadence::quantize_per_tensor")
 def quantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -81,7 +80,7 @@ def quantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=dtype)
 
 
-@impl(m, "dequantize_per_tensor")
+@register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -93,7 +92,7 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
-@impl(m, "quantized_linear")
+@register_fake("cadence::quantized_linear")
 def quantized_linear_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
@@ -115,7 +114,7 @@ def quantized_linear_meta(
     return src.new_empty(out_size, dtype=torch.uint8)
 
 
-@impl(m, "quantized_conv")
+@register_fake("cadence::quantized_conv")
 def quantized_conv_meta(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -153,7 +152,7 @@ def quantized_conv_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
-@impl(m, "quantized_layer_norm")
+@register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
@@ -168,7 +167,7 @@ def quantized_layer_norm_meta(
     return input.new_empty(input.size(), dtype=torch.uint8)
 
 
-@impl(m, "quantized_relu")
+@register_fake("cadence::quantized_relu")
 def quantized_relu_meta(
     X: torch.Tensor,
     X_zero_point: torch.Tensor,
@@ -179,7 +178,7 @@ def quantized_relu_meta(
     return X.new_empty(X.size(), dtype=torch.uint8)
 
 
-@impl(m, "quantized_matmul")
+@register_fake("cadence::quantized_matmul")
 def quantized_matmul_meta(
     X: torch.Tensor,
     X_zero_point: int,
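As a rough illustration of how such a registration is exercised, the following sketch (again using the hypothetical `mylib::my_op` from the sketch above, not the Cadence operators in this diff) calls the op under `FakeTensorMode`, where only output shape and dtype are computed:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumes the hypothetical mylib::my_op and its register_fake kernel from the
# sketch above have already been defined in this process.
with FakeTensorMode():
    x = torch.empty(2, 3)                 # a fake tensor: metadata only
    out = torch.ops.mylib.my_op(x, 0.5)   # dispatches to the fake kernel
    print(out.shape, out.dtype)           # torch.Size([2, 3]) torch.float32
```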
