
Commit d2014e3

Add a target rule for ops_registrations (#5083)
Differential Revision: D62206605
Pull Request resolved: #5191
1 parent 85410e4 commit d2014e3

2 files changed: +20 additions, -8 deletions


backends/cadence/aot/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -60,6 +60,17 @@ python_library(
     ],
 )
 
+python_library(
+    name = "ops_registrations",
+    srcs = [
+        "ops_registrations.py",
+    ],
+    deps = [
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/backends/cadence/aot:utils",
+    ],
+)
+
 export_file(name = "functions.yaml")
 
 executorch_generated_lib(
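
For context, the new rule lets other Buck targets depend on the op registrations as a standalone library rather than pulling the file in transitively. A minimal sketch of a downstream consumer, assuming a hypothetical target name my_cadence_pass (not part of this commit):

python_library(
    name = "my_cadence_pass",  # hypothetical consumer, for illustration only
    srcs = ["my_cadence_pass.py"],
    deps = [
        # the target added by this commit
        "fbcode//executorch/backends/cadence/aot:ops_registrations",
    ],
)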

backends/cadence/aot/ops_registrations.py

Lines changed: 9 additions & 8 deletions
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from math import prod
 from typing import Optional, Tuple
 
 import torch
-from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library
 
 from .utils import get_conv1d_output_size, get_conv2d_output_size
@@ -74,8 +75,8 @@ def quantize_per_tensor_meta(
     zero_point: int,
     quant_min: int,
     quant_max: int,
-    dtype: ScalarType,
-):
+    dtype: torch.dtype,
+) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=dtype)
 
 

@@ -86,8 +87,8 @@ def dequantize_per_tensor_meta(
     zero_point: int,
     quant_min: int,
     quant_max: int,
-    dtype: ScalarType,
-):
+    dtype: torch.dtype,
+) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=torch.float)
 
 

@@ -102,7 +103,7 @@ def quantized_linear_meta(
     out_shift: torch.Tensor,
     out_zero_point: int,
     offset: Optional[torch.Tensor],
-):
+) -> torch.Tensor:
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -162,7 +163,7 @@ def quantized_layer_norm_meta(
     eps: float,
     output_scale: float,
     output_zero_point: int,
-):
+) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=torch.uint8)
 
 

@@ -173,7 +174,7 @@ def quantized_relu_meta(
     out_zero_point: int,
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
-):
+) -> torch.Tensor:
     return X.new_empty(X.size(), dtype=torch.uint8)
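
The signature changes above (dtype: ScalarType to torch.dtype, plus explicit -> torch.Tensor returns) tighten these meta kernels for pyre-strict. A meta kernel never computes values; it only allocates an empty tensor with the right shape and dtype, which is what lets export/tracing run without real data. A minimal runnable sketch of the registration pattern, assuming a "cadence" namespace and a simplified demo schema (the real file defines fuller signatures for many ops):

import torch
from torch.library import Library, impl

# Assumed namespace and simplified schema, for illustration only.
lib = Library("cadence", "DEF")
lib.define(
    "quantized_relu_demo(Tensor X, Tensor X_zero_point, int out_zero_point, "
    "Tensor out_multiplier, Tensor out_shift) -> Tensor"
)

# Registering under the "Meta" dispatch key: the kernel runs on meta tensors
# and only reports the output's shape and dtype.
@impl(lib, "quantized_relu_demo", "Meta")
def quantized_relu_demo_meta(
    X: torch.Tensor,
    X_zero_point: torch.Tensor,
    out_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    # Output matches the input's shape; quantized activations here are uint8.
    return X.new_empty(X.size(), dtype=torch.uint8)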