Skip to content

Commit 3607a3a

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Move executorch_prim_ops_registry into exir
Summary: We need prim ops to be registered at exir level to support sym ops. For some reason this piece of code currently lives in executorch/kernels. With the oss plan we need to build package for exir and this target needs to be part of the package. This diff moves the target into exir. Reviewed By: JacobSzwejbka Differential Revision: D47646479 fbshipit-source-id: 378a1751675218920ddb9d87eb481c8cba9665cc
1 parent 03b8b61 commit 3607a3a

13 files changed

+81
-97
lines changed

exir/emit/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ python_library(
3232
"//executorch/exir/dialects/backend:lib",
3333
"//executorch/exir/dialects/edge:lib",
3434
"//executorch/exir/operator:convert",
35+
"//executorch/exir/passes:prim_ops_py_registry",
3536
"//executorch/extension/pytree:pylib",
36-
"//executorch/kernels/prim_ops:prim_to_executorch_ops",
3737
],
3838
)

exir/emit/_emitter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from executorch.exir.dialects.edge._ops import EdgeOpOverload
4141
from executorch.exir.error import ExportError, ExportErrorType, InternalError
4242
from executorch.exir.operator.convert import is_out_variant
43+
from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
4344
from executorch.exir.print_program import pretty_print_stacktraces
4445
from executorch.exir.schema import (
4546
BackendDelegate,
@@ -83,12 +84,10 @@
8384
TensorSpec,
8485
)
8586
from executorch.exir.types import LeafValueSpec, ValueSpec
86-
from executorch.kernels.prim_ops.prim_to_executorch_ops import is_sym_op
8787
from functorch.experimental import control_flow
8888
from torch._export.exported_program import ExportedProgram
8989
from torch.utils import _pytree as pytree
9090

91-
# @manual=fbsource//third-party/pypi/typing-extensions:typing-extensions
9291
from typing_extensions import TypeAlias
9392

9493

exir/passes/TARGETS

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ python_library(
1313
":memory_planning_pass",
1414
":normalize_transpose_pass",
1515
":pass_registry",
16+
":prim_ops_py_registry",
1617
":quant_fusion_pass",
1718
":remove_mixed_type_operators",
1819
":remove_noop_pass",
@@ -36,7 +37,6 @@ python_library(
3637
"//executorch/exir/dialects/backend:lib",
3738
"//executorch/exir/dialects/edge:lib",
3839
"//executorch/exir/operator:convert",
39-
"//executorch/kernels/prim_ops:prim_to_executorch_ops",
4040
],
4141
)
4242

@@ -203,10 +203,10 @@ python_library(
203203
"replace_edge_with_backend_pass.py",
204204
],
205205
deps = [
206+
":prim_ops_py_registry",
206207
"//caffe2:torch",
207208
"//executorch/exir:pass_base",
208209
"//executorch/exir/dialects:lib",
209-
"//executorch/kernels/prim_ops:prim_to_executorch_ops",
210210
],
211211
)
212212

@@ -227,11 +227,11 @@ python_library(
227227
"replace_aten_with_edge_pass.py",
228228
],
229229
deps = [
230+
":prim_ops_py_registry",
230231
"//caffe2:torch",
231232
"//executorch/exir:pass_base",
232233
"//executorch/exir/dialects:lib",
233234
"//executorch/exir/dialects/edge:lib",
234-
"//executorch/kernels/prim_ops:prim_to_executorch_ops",
235235
],
236236
)
237237

@@ -245,3 +245,12 @@ python_library(
245245
"//executorch/exir:pass_base",
246246
],
247247
)
248+
249+
python_library(
250+
name = "prim_ops_py_registry",
251+
srcs = ["executorch_prim_ops_registry.py"],
252+
deps = [
253+
"//caffe2:torch",
254+
"//executorch/exir/dialects:lib",
255+
],
256+
)

exir/passes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from executorch.exir.pass_manager import PassManager, PassType
2828
from executorch.exir.passes.const_prop_pass import ConstPropPass
2929
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
30+
31+
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
3032
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
3133
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
3234
from executorch.exir.passes.pass_registry import PassRegistry
@@ -42,8 +44,6 @@
4244
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
4345
from executorch.exir.passes.spec_prop_pass import SpecPropPass
4446
from executorch.exir.passes.sym_shape_eval_pass import SymShapeEvalPass
45-
46-
from executorch.kernels.prim_ops.prim_to_executorch_ops import _EXECUTORCH_SYM_OPS
4747
from torch import fx
4848
from torch._subclasses import FakeTensor
4949
from torch.fx.passes.infra.pass_base import PassBase, PassResult

exir/passes/dynamic_shape_prop_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
from executorch.exir.pass_base import Argument, BackendPass
1515
from executorch.exir.pass_infra.node_metadata import NodeMetadata
1616
from executorch.exir.pass_infra.proxy_value import ProxyValue
17+
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
1718
from executorch.exir.schema import TensorShapeDynamism
1819
from executorch.exir.sym_util import collect_free_symbols, eval_expr
1920
from executorch.exir.tensor import TensorSpec
20-
from executorch.kernels.prim_ops.prim_to_executorch_ops import _EXECUTORCH_SYM_OPS
2121
from torch._subclasses import FakeTensor
2222
from torch.fx import GraphModule
2323

kernels/prim_ops/executorch_prim_ops_registry.py renamed to exir/passes/executorch_prim_ops_registry.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1-
from executorch.exir.dialects._ops import bind_pattern_to_op
1+
import operator
2+
from typing import Dict, Set
3+
4+
# necessary to ensure the ops are registered
5+
import torch
6+
from executorch.exir.dialects._ops import bind_pattern_to_op, ops
27
from torch import SymInt
8+
from torch._ops import OpOverload
39
from torch.library import Library
410

11+
512
executorch_prims_lib = Library("executorch_prim", "DEF")
613

714

@@ -50,3 +57,35 @@ def le(a: SymInt, b: SymInt) -> bool:
5057
@bind_pattern_to_op(executorch_prims_lib, "eq.int(SymInt a, SymInt b) -> bool")
5158
def eq(a: SymInt, b: SymInt) -> bool:
5259
return a == b
60+
61+
62+
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
63+
operator.sub: ops.backend.executorch_prim.sub.int,
64+
operator.mul: ops.backend.executorch_prim.mul.int,
65+
operator.add: ops.backend.executorch_prim.add.int,
66+
operator.floordiv: ops.backend.executorch_prim.floordiv.int,
67+
operator.eq: ops.backend.executorch_prim.eq.int,
68+
operator.gt: ops.backend.executorch_prim.gt.int,
69+
operator.lt: ops.backend.executorch_prim.lt.int,
70+
operator.ge: ops.backend.executorch_prim.ge.int,
71+
operator.le: ops.backend.executorch_prim.le.int,
72+
}
73+
74+
75+
_EXECUTORCH_SYM_OPS: Set[OpOverload] = set(
76+
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.values()
77+
)
78+
_EXECUTORCH_SYM_OPS.update(
79+
{
80+
torch.ops.aten.sym_stride.int,
81+
torch.ops.aten.sym_size.int,
82+
torch.ops.aten.sym_numel.default,
83+
}
84+
)
85+
86+
87+
def is_sym_op(target) -> bool:
88+
return (
89+
target in _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.keys()
90+
or target in _EXECUTORCH_SYM_OPS
91+
)

exir/passes/replace_aten_with_edge_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from executorch.exir.dialects._ops import ops
55
from executorch.exir.dialects.edge._ops import EdgeOpOverload
66
from executorch.exir.pass_base import ExportPass
7-
from executorch.kernels.prim_ops.prim_to_executorch_ops import _EXECUTORCH_SYM_OPS
7+
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
88
from torch.fx.node import Target
99

1010

exir/passes/replace_edge_with_backend_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from executorch.exir.dialects._ops import ops
55
from executorch.exir.pass_base import ExportPass
6-
from executorch.kernels.prim_ops.prim_to_executorch_ops import (
6+
from executorch.exir.passes.executorch_prim_ops_registry import (
77
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS,
88
)
99

kernels/prim_ops/TARGETS

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,7 @@
11
# Any targets that should be shared between fbcode and xplat must be defined in
22
# targets.bzl. This file can contain fbcode-only targets.
3-
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
43
load(":targets.bzl", "define_common_targets")
54

65
oncall("executorch")
76

87
define_common_targets()
9-
10-
python_library(
11-
name = "prim_ops_py_registry",
12-
srcs = ["executorch_prim_ops_registry.py"],
13-
deps = [
14-
"//caffe2:torch",
15-
"//executorch/exir/dialects:lib",
16-
],
17-
)
18-
19-
python_library(
20-
name = "prim_to_executorch_ops",
21-
srcs = ["prim_to_executorch_ops.py"],
22-
deps = [
23-
":prim_ops_py_registry",
24-
"//caffe2:torch",
25-
"//executorch/exir/dialects:lib",
26-
],
27-
)

kernels/prim_ops/prim_to_executorch_ops.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

kernels/prim_ops/test/TARGETS

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1+
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
2+
13
# Any targets that should be shared between fbcode and xplat must be defined in
24
# targets.bzl. This file can contain fbcode-only targets.
35
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
4-
load(":targets.bzl", "define_common_targets")
56

67
oncall("executorch")
78

8-
define_common_targets()
9-
109
python_unittest(
1110
name = "test_prim_ops",
1211
srcs = [
@@ -15,6 +14,24 @@ python_unittest(
1514
supports_static_listing = True,
1615
deps = [
1716
"//caffe2:torch",
18-
"//executorch/kernels/prim_ops:prim_ops_py_registry",
17+
"//executorch/exir/passes:prim_ops_py_registry",
18+
],
19+
)
20+
21+
cpp_unittest(
22+
name = "register_prim_ops_test",
23+
srcs = [
24+
"register_prim_ops_test.cpp",
25+
],
26+
deps = [
27+
"//executorch/kernels/prim_ops:prim_ops_registry", # @manual
28+
"//executorch/runtime/core:evalue", # @manual
29+
"//executorch/runtime/core/exec_aten:lib", # @manual
30+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util", # @manual
31+
"//executorch/runtime/core/exec_aten/util:tensor_util", # @manual
32+
"//executorch/runtime/kernel:kernel_runtime_context", # @manual
33+
"//executorch/runtime/kernel:operator_registry",
34+
"//executorch/runtime/platform:platform",
35+
"//executorch/test/utils:utils",
1936
],
2037
)

kernels/prim_ops/test/targets.bzl

Lines changed: 0 additions & 20 deletions
This file was deleted.

kernels/prim_ops/test/test_prim_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22

33
# necessary to ensure the ops are registered
4-
import executorch.kernels.prim_ops.executorch_prim_ops_registry
4+
import executorch.exir.passes.executorch_prim_ops_registry # noqa: F401
55

66
import torch
77

0 commit comments

Comments
 (0)