Skip to content

Commit b3f6d44

Browse files
yilifbfacebook-github-bot
authored andcommitted
Add a target rule for ops_registrations (#5083)
Summary: Pull Request resolved: #5083 Add a target rule for ops_registrations in the OSS repo to enable including this in fb repo Differential Revision: D62206605 Reviewed By: hsharma35
1 parent eca9ed5 commit b3f6d44

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

backends/cadence/aot/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ python_library(
6060
],
6161
)
6262

63+
python_library(
64+
name = "ops_registrations",
65+
srcs = [
66+
"ops_registrations.py",
67+
],
68+
deps = [
69+
":utils",
70+
"//caffe2:torch",
71+
"//executorch/exir:scalar_type",
72+
],
73+
)
74+
6375
export_file(name = "functions.yaml")
6476

6577
executorch_generated_lib(

backends/cadence/aot/ops_registrations.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
from math import prod
810
from typing import Optional, Tuple
911

1012
import torch
11-
from executorch.exir.scalar_type import ScalarType
1213
from torch.library import impl, Library
1314

1415
from .utils import get_conv1d_output_size, get_conv2d_output_size
@@ -74,8 +75,8 @@ def quantize_per_tensor_meta(
7475
zero_point: int,
7576
quant_min: int,
7677
quant_max: int,
77-
dtype: ScalarType,
78-
):
78+
dtype: torch.dtype,
79+
) -> torch.Tensor:
7980
return input.new_empty(input.size(), dtype=dtype)
8081

8182

@@ -86,8 +87,8 @@ def dequantize_per_tensor_meta(
8687
zero_point: int,
8788
quant_min: int,
8889
quant_max: int,
89-
dtype: ScalarType,
90-
):
90+
dtype: torch.dtype,
91+
) -> torch.Tensor:
9192
return input.new_empty(input.size(), dtype=torch.float)
9293

9394

@@ -102,7 +103,7 @@ def quantized_linear_meta(
102103
out_shift: torch.Tensor,
103104
out_zero_point: int,
104105
offset: Optional[torch.Tensor],
105-
):
106+
) -> torch.Tensor:
106107
# src comes in shape [leading_dims, in_dim]
107108
# weight comes in shape [out_dim, in_dim]
108109
# output comes in empty with shape [leading_dims, out_dim]
@@ -162,7 +163,7 @@ def quantized_layer_norm_meta(
162163
eps: float,
163164
output_scale: float,
164165
output_zero_point: int,
165-
):
166+
) -> torch.Tensor:
166167
return input.new_empty(input.size(), dtype=torch.uint8)
167168

168169

@@ -173,7 +174,7 @@ def quantized_relu_meta(
173174
out_zero_point: int,
174175
out_multiplier: torch.Tensor,
175176
out_shift: torch.Tensor,
176-
):
177+
) -> torch.Tensor:
177178
return X.new_empty(X.size(), dtype=torch.uint8)
178179

179180

0 commit comments

Comments
 (0)