Commit cfda47d

Yi Li authored and facebook-github-bot committed
Update the API of registering fake kernels to new standard (#5190)
Summary: Pull Request resolved: #5190. Pull Request resolved: #5084. Update the decorator functions used to register operator abstract implementations (fake tensors) to the newer API. Reviewed By: zonglinpeng, hsharma35. Differential Revision: D62206602
1 parent b3f6d44 commit cfda47d
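For context, the old and new registration patterns differ as sketched below. This is a minimal illustration using a hypothetical demo::scale operator (the lib object and op name are made up for the example; the real change targets ops in the "cadence" namespace shown in the diff). It assumes torch.library.register_fake is available, as the updated import in this commit indicates; the old style used torch.library.impl against a Library handle opened with the "Meta" dispatch key.

import torch
from torch.library import Library, register_fake

# Hypothetical library/operator for illustration only.
lib = Library("demo", "DEF")
lib.define("scale(Tensor x, float s) -> Tensor")

# Old style (what the diff removes):
#   meta_lib = Library("demo", "IMPL", "Meta")
#   @impl(meta_lib, "scale")
#   def scale_meta(x, s): ...

# New style (what the diff adds): register a fake kernel by fully
# qualified op name; no Library handle for the Meta key is needed.
@register_fake("demo::scale")
def scale_fake(x: torch.Tensor, s: float) -> torch.Tensor:
    # Fake kernels only describe output metadata (shape/dtype), never values.
    return x.new_empty(x.size())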

File tree: 2 files changed, +10 −11 lines

backends/cadence/aot/TARGETS

Lines changed: 2 additions & 3 deletions
@@ -66,9 +66,8 @@ python_library(
         "ops_registrations.py",
     ],
     deps = [
-        ":utils",
-        "//caffe2:torch",
-        "//executorch/exir:scalar_type",
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/backends/cadence/aot:utils",
     ],
 )

backends/cadence/aot/ops_registrations.py

Lines changed: 8 additions & 8 deletions
@@ -10,7 +10,7 @@
 from typing import Optional, Tuple

 import torch
-from torch.library import impl, Library
+from torch.library import Library, register_fake

 from .utils import get_conv1d_output_size, get_conv2d_output_size

@@ -68,7 +68,7 @@
 m = Library("cadence", "IMPL", "Meta")


-@impl(m, "quantize_per_tensor")
+@register_fake("cadence::quantize_per_tensor")
 def quantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -80,7 +80,7 @@ def quantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=dtype)


-@impl(m, "dequantize_per_tensor")
+@register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,
     scale: float,
@@ -92,7 +92,7 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)


-@impl(m, "quantized_linear")
+@register_fake("cadence::quantized_linear")
 def quantized_linear_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
@@ -114,7 +114,7 @@ def quantized_linear_meta(
     return src.new_empty(out_size, dtype=torch.uint8)


-@impl(m, "quantized_conv")
+@register_fake("cadence::quantized_conv")
 def quantized_conv_meta(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -152,7 +152,7 @@ def quantized_conv_meta(
     return input.new_empty(output_size, dtype=input.dtype)


-@impl(m, "quantized_layer_norm")
+@register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
@@ -167,7 +167,7 @@ def quantized_layer_norm_meta(
     return input.new_empty(input.size(), dtype=torch.uint8)


-@impl(m, "quantized_relu")
+@register_fake("cadence::quantized_relu")
 def quantized_relu_meta(
     X: torch.Tensor,
     X_zero_point: torch.Tensor,
@@ -178,7 +178,7 @@ def quantized_relu_meta(
     return X.new_empty(X.size(), dtype=torch.uint8)


-@impl(m, "quantized_matmul")
+@register_fake("cadence::quantized_matmul")
 def quantized_matmul_meta(
     X: torch.Tensor,
     X_zero_point: int,
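As a quick sanity check (a sketch, not part of this commit), a fake kernel registered this way can be exercised under FakeTensorMode, which dispatches to the fake implementation instead of a real kernel. The example continues the hypothetical demo::scale operator from above rather than the actual cadence ops:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(2, 3)
    out = torch.ops.demo.scale(x, 2.0)
    # Only output metadata is computed; no real kernel runs.
    print(out.shape, out.dtype)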
