Commit 91945f6

Update base for Update on "[ExecuTorch] Add broadcasting support to optimized op_div"
Summary: Similar to broadcast support in op_mul
Test Plan: Tests added
Reviewers: Subscribers: Tasks: Tags:
cc larryliu0820 manuelcandales
Differential Revision: [D69491815](https://our.internmc.facebook.com/intern/diff/D69491815)
[ghstack-poisoned]
2 parents: a4a55a9 + 5cf0106 · commit 91945f6
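For context, "broadcasting support" means the optimized div kernel must accept inputs whose shapes differ but are broadcast-compatible, expanding size-1 dimensions the same way op_mul already does. The following PyTorch-level sketch only illustrates the kind of case the new tests presumably exercise; the shapes are hypothetical and not taken from this commit.

import torch

# Hypothetical broadcast-compatible shapes, for illustration only.
a = torch.randn(2, 1, 4)   # size-1 dim 1 is expanded to 3
b = torch.randn(1, 3, 1)   # size-1 dims 0 and 2 are expanded to 2 and 4

out = torch.div(a, b)      # aten.div.Tensor with broadcasting
assert out.shape == (2, 3, 4)

# A model that exports this pattern is what the optimized op_div kernel
# must now handle instead of requiring identical input shapes.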

18 files changed (+1085, -22 lines)

backends/apple/coreml/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@ runtime.cxx_python_extension(
     base_module = "",
     visibility = [
         "//executorch/examples/apple/coreml/...",
+        "@EXECUTORCH_CLIENTS",
     ],
     external_deps = [
         "pybind11",

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ python_library(
     deps = [
         "//executorch/backends/arm:tosa_quant_utils",
         "//executorch/backends/arm:tosa_utils",
+        "//executorch/backends/transforms:replace_scalar_with_tensor",
         "//executorch/backends/xnnpack/_passes:xnnpack_passes",
         "//executorch/exir:lib",
     ],

backends/arm/operator_support/TARGETS

Lines changed: 2 additions & 1 deletion

@@ -5,8 +5,9 @@ python_library(
     srcs = glob(["*.py"]),
     typing = True,
     deps = [
+        "//executorch/backends/arm/_passes:passes",
+        "//executorch/backends/arm:tosa_specification",
         "//executorch/backends/xnnpack/_passes:xnnpack_passes",
         "//executorch/exir:lib",
-        "//executorch/backends/arm:tosa_specification"
     ],
 )

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -256,6 +256,7 @@ python_library(
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:remove_ops",
         "//executorch/backends/cadence/aot:utils",
+        "//executorch/backends/transforms:replace_scalar_with_tensor",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",

backends/cadence/aot/pass_utils.py

Lines changed: 8 additions & 8 deletions

@@ -7,7 +7,7 @@
 # pyre-strict

 from dataclasses import dataclass
-from typing import Callable, List, Optional, Set, Union
+from typing import Callable, List, Optional, Set, Type, Union

 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -32,33 +32,33 @@ class CadencePassAttribute:


 # A dictionary that maps an ExportPass to its attributes.
-ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
+ALL_CADENCE_PASSES: dict[Type[ExportPass], CadencePassAttribute] = {}


-def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
+def get_cadence_pass_attribute(p: Type[ExportPass]) -> CadencePassAttribute:
     return ALL_CADENCE_PASSES[p]


 # A decorator that registers a pass.
 def register_cadence_pass(
     pass_attribute: CadencePassAttribute,
-) -> Callable[[ExportPass], ExportPass]:
-    def wrapper(cls: ExportPass) -> ExportPass:
+) -> Callable[[Type[ExportPass]], Type[ExportPass]]:
+    def wrapper(cls: Type[ExportPass]) -> Type[ExportPass]:
         ALL_CADENCE_PASSES[cls] = pass_attribute
         return cls

     return wrapper


-def get_all_available_cadence_passes() -> Set[ExportPass]:
+def get_all_available_cadence_passes() -> Set[Type[ExportPass]]:
     return set(ALL_CADENCE_PASSES.keys())


 # Create a new filter to filter out relevant passes from all passes.
 def create_cadence_pass_filter(
     opt_level: int, debug: bool = False
-) -> Callable[[ExportPass], bool]:
-    def _filter(p: ExportPass) -> bool:
+) -> Callable[[Type[ExportPass]], bool]:
+    def _filter(p: Type[ExportPass]) -> bool:
         pass_attribute = get_cadence_pass_attribute(p)
         return (
             pass_attribute.opt_level is not None
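The net effect of these typing changes is that ALL_CADENCE_PASSES is keyed by pass classes (Type[ExportPass]) rather than pass instances. Below is a minimal sketch of how the registry and filter are meant to be used after this change; MyExamplePass is a hypothetical pass and the ExportPass import path is assumed, not taken from the diff.

from typing import List, Type

from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    create_cadence_pass_filter,
    register_cadence_pass,
)
from executorch.exir.pass_base import ExportPass  # assumed import path


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class MyExamplePass(ExportPass):  # hypothetical pass, for illustration only
    pass


# The decorator stores the *class* in ALL_CADENCE_PASSES, so the predicate
# returned by create_cadence_pass_filter receives Type[ExportPass] values.
keep_pass = create_cadence_pass_filter(opt_level=1)
selected: List[Type[ExportPass]] = [p for p in [MyExamplePass] if keep_pass(p)]
instances = [p() for p in selected]  # instantiate only the selected passes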

backends/cadence/aot/passes.py

Lines changed: 3 additions & 3 deletions

@@ -6,7 +6,7 @@

 # pyre-strict

-from typing import Any, List, Optional, Type
+from typing import Any, cast, List, Optional, Type

 import torch
 import torch.fx
@@ -95,9 +95,9 @@ def get_cadence_passes(
     passes = get_passes_in_default_order()
     pass_filter = create_cadence_pass_filter(opt_level)
     filtered_passes = [
-        # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
         filtered_pass()
         # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
         for filtered_pass in list(filter(pass_filter, passes))
     ]
-    return filtered_passes
+    # The type checker can't infer the proper type of the list comprehension.
+    return cast(List[Optional[PassResult]], filtered_passes)
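The explicit cast replaces a pyre-fixme suppression. typing.cast only affects static analysis and has no runtime behavior, as the small sketch below illustrates (the values are arbitrary).

from typing import List, Optional, cast

values = [1, 2, None]
typed = cast(List[Optional[int]], values)

# cast() returns its argument unchanged; it only tells the type checker
# what static type to assume for the expression.
assert typed is values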

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 3 deletions

@@ -1719,9 +1719,9 @@ def call_operator(self, op, args, kwargs, meta):
         )


-@register_cadence_pass(CadencePassAttribute(opt_level=0))(
-    ReplaceScalarWithTensorArgPass()
-)
+register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass)
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceScalarTensorWithFullPass(ExportPass):
     """

backends/transforms/targets.bzl

Lines changed: 14 additions & 0 deletions

@@ -201,6 +201,20 @@ def define_common_targets():
         ],
     )

+    runtime.python_library(
+        name = "replace_scalar_with_tensor",
+        srcs = [
+            "replace_scalar_with_tensor.py",
+        ],
+        visibility = [
+            "//executorch/backends/...",
+        ],
+        deps = [
+            "//caffe2:torch",
+            "//executorch/exir:pass_base",
+        ],
+    )
+
     runtime.python_test(
         name = "test_duplicate_dynamic_quant_chain",
         srcs = [

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 6 additions & 1 deletion

@@ -475,7 +475,12 @@ void add_conv1d_node(
     const ValueRef out,
     const bool clamp_out) {
   ValueRef arg_weight = prepack_standard(
-      graph, weight, graph.storage_type_of(out), utils::kChannelsPacked);
+      graph,
+      weight,
+      graph.storage_type_of(out),
+      utils::kChannelsPacked,
+      /* passthrough = */ false,
+      utils::kOptimizedAxisMap);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,

build/Utils.cmake

Lines changed: 1 addition & 1 deletion

@@ -357,7 +357,7 @@ function(add_torch_to_cmake_prefix_path)
   endif()
   execute_process(
     COMMAND "${PYTHON_EXECUTABLE}" -c
-      "import torch as _; print(_.__path__[0], end='')"
+      "import importlib.util; print(importlib.util.find_spec('torch').submodule_search_locations[0])"
     OUTPUT_VARIABLE _tmp_torch_path
     ERROR_VARIABLE _tmp_torch_path_error
     RESULT_VARIABLE _tmp_torch_path_result COMMAND_ECHO STDERR
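The replacement one-liner locates the installed torch package through importlib instead of importing torch itself, which presumably avoids paying for torch's import-time initialization just to discover its install path. Expanded into standalone Python for reference (standard-library behavior only, not code from the commit):

import importlib.util

# find_spec() consults the import machinery without executing torch/__init__.py.
spec = importlib.util.find_spec("torch")
if spec is not None and spec.submodule_search_locations:
    # The first entry is the package directory, e.g. .../site-packages/torch
    print(spec.submodule_search_locations[0])
else:
    raise SystemExit("torch is not installed")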
