Skip to content

Commit f87940d

Browse files
authored
Update buck deps for new replace_scalar_with_tensor transforms
Differential Revision: D69883148 Pull Request resolved: #8588
1 parent 2b81e6f commit f87940d

File tree

6 files changed

+30
-14
lines changed

6 files changed

+30
-14
lines changed

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ python_library(
77
deps = [
88
"//executorch/backends/arm:tosa_quant_utils",
99
"//executorch/backends/arm:tosa_utils",
10+
"//executorch/backends/transforms:replace_scalar_with_tensor",
1011
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
1112
"//executorch/exir:lib",
1213
],

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ python_library(
256256
"//executorch/backends/cadence/aot:pass_utils",
257257
"//executorch/backends/cadence/aot:remove_ops",
258258
"//executorch/backends/cadence/aot:utils",
259+
"//executorch/backends/transforms:replace_scalar_with_tensor",
259260
"//executorch/exir:pass_base",
260261
"//executorch/exir/dialects:lib",
261262
"//executorch/exir/dialects/edge:lib",

backends/cadence/aot/pass_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from dataclasses import dataclass
10-
from typing import Callable, List, Optional, Set, Union
10+
from typing import Callable, List, Optional, Set, Type, Union
1111

1212
import torch
1313
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -32,33 +32,33 @@ class CadencePassAttribute:
3232

3333

3434
# A dictionary that maps an ExportPass to its attributes.
35-
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
35+
ALL_CADENCE_PASSES: dict[Type[ExportPass], CadencePassAttribute] = {}
3636

3737

38-
def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
38+
def get_cadence_pass_attribute(p: Type[ExportPass]) -> CadencePassAttribute:
3939
return ALL_CADENCE_PASSES[p]
4040

4141

4242
# A decorator that registers a pass.
4343
def register_cadence_pass(
4444
pass_attribute: CadencePassAttribute,
45-
) -> Callable[[ExportPass], ExportPass]:
46-
def wrapper(cls: ExportPass) -> ExportPass:
45+
) -> Callable[[Type[ExportPass]], Type[ExportPass]]:
46+
def wrapper(cls: Type[ExportPass]) -> Type[ExportPass]:
4747
ALL_CADENCE_PASSES[cls] = pass_attribute
4848
return cls
4949

5050
return wrapper
5151

5252

53-
def get_all_available_cadence_passes() -> Set[ExportPass]:
53+
def get_all_available_cadence_passes() -> Set[Type[ExportPass]]:
5454
return set(ALL_CADENCE_PASSES.keys())
5555

5656

5757
# Create a new filter to filter out relevant passes from all passes.
5858
def create_cadence_pass_filter(
5959
opt_level: int, debug: bool = False
60-
) -> Callable[[ExportPass], bool]:
61-
def _filter(p: ExportPass) -> bool:
60+
) -> Callable[[Type[ExportPass]], bool]:
61+
def _filter(p: Type[ExportPass]) -> bool:
6262
pass_attribute = get_cadence_pass_attribute(p)
6363
return (
6464
pass_attribute.opt_level is not None

backends/cadence/aot/passes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Any, List, Optional, Type
9+
from typing import Any, cast, List, Optional, Type
1010

1111
import torch
1212
import torch.fx
@@ -95,9 +95,9 @@ def get_cadence_passes(
9595
passes = get_passes_in_default_order()
9696
pass_filter = create_cadence_pass_filter(opt_level)
9797
filtered_passes = [
98-
# pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
9998
filtered_pass()
10099
# pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
101100
for filtered_pass in list(filter(pass_filter, passes))
102101
]
103-
return filtered_passes
102+
# The type checker can't infer the proper type of the list comprehension.
103+
return cast(List[Optional[PassResult]], filtered_passes)

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,9 +1719,9 @@ def call_operator(self, op, args, kwargs, meta):
17191719
)
17201720

17211721

1722-
@register_cadence_pass(CadencePassAttribute(opt_level=0))(
1723-
ReplaceScalarWithTensorArgPass()
1724-
)
1722+
register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass)
1723+
1724+
17251725
@register_cadence_pass(CadencePassAttribute(opt_level=0))
17261726
class ReplaceScalarTensorWithFullPass(ExportPass):
17271727
"""

backends/transforms/targets.bzl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,20 @@ def define_common_targets():
201201
],
202202
)
203203

204+
runtime.python_library(
205+
name = "replace_scalar_with_tensor",
206+
srcs = [
207+
"replace_scalar_with_tensor.py",
208+
],
209+
visibility = [
210+
"//executorch/backends/...",
211+
],
212+
deps = [
213+
"//caffe2:torch",
214+
"//executorch/exir:pass_base",
215+
],
216+
)
217+
204218
runtime.python_test(
205219
name = "test_duplicate_dynamic_quant_chain",
206220
srcs = [

0 commit comments

Comments
 (0)