Skip to content

Commit b3b9162

Browse files
yilifb authored and facebook-github-bot committed
Add a target rule for ops_registrations (#5083)
Summary: Pull Request resolved: #5083. Add a target rule for ops_registrations in the OSS repo to enable including this in the fb repo. Differential Revision: D62206605. Reviewed By: hsharma35
1 parent 6b1e328 commit b3b9162

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

backends/cadence/aot/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ python_library(
6060
],
6161
)
6262

63+
python_library(
64+
name = "ops_registrations",
65+
srcs = [
66+
"ops_registrations.py",
67+
],
68+
deps = [
69+
":utils",
70+
"//caffe2:torch",
71+
"//executorch/exir:scalar_type",
72+
],
73+
)
74+
6375
export_file(name = "functions.yaml")
6476

6577
executorch_generated_lib(

backends/cadence/aot/ops_registrations.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
from math import prod
810
from typing import Optional, Tuple
911

@@ -74,8 +76,8 @@ def quantize_per_tensor_meta(
7476
zero_point: int,
7577
quant_min: int,
7678
quant_max: int,
77-
dtype: ScalarType,
78-
):
79+
dtype: torch.dtype,
80+
) -> torch.Tensor:
7981
return input.new_empty(input.size(), dtype=dtype)
8082

8183

@@ -86,8 +88,8 @@ def dequantize_per_tensor_meta(
8688
zero_point: int,
8789
quant_min: int,
8890
quant_max: int,
89-
dtype: ScalarType,
90-
):
91+
dtype: torch.dtype,
92+
) -> torch.Tensor:
9193
return input.new_empty(input.size(), dtype=torch.float)
9294

9395

@@ -102,7 +104,7 @@ def quantized_linear_meta(
102104
out_shift: torch.Tensor,
103105
out_zero_point: int,
104106
offset: Optional[torch.Tensor],
105-
):
107+
) -> torch.Tensor:
106108
# src comes in shape [leading_dims, in_dim]
107109
# weight comes in shape [out_dim, in_dim]
108110
# output comes in empty with shape [leading_dims, out_dim]
@@ -162,7 +164,7 @@ def quantized_layer_norm_meta(
162164
eps: float,
163165
output_scale: float,
164166
output_zero_point: int,
165-
):
167+
) -> torch.Tensor:
166168
return input.new_empty(input.size(), dtype=torch.uint8)
167169

168170

@@ -173,7 +175,7 @@ def quantized_relu_meta(
173175
out_zero_point: int,
174176
out_multiplier: torch.Tensor,
175177
out_shift: torch.Tensor,
176-
):
178+
) -> torch.Tensor:
177179
return X.new_empty(X.size(), dtype=torch.uint8)
178180

179181

0 commit comments

Comments (0)