Commit ed96fd7

Update base for Update on "[ET-VK] Making stride equals dilation the default mode for conv2d dw."
This diff makes stride == dilation the default mode for the conv2d dw output op, and adds a separate source file to handle the stride != dilation case. Differential Revision: [D67979760](https://our.internmc.facebook.com/intern/diff/D67979760/) [ghstack-poisoned]
2 parents fdc6bf1 + 94d83ad commit ed96fd7
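
For context on the title: the stack this commit rebases makes stride == dilation the default depthwise-conv2d mode and moves the general case into its own source file. A minimal, hypothetical sketch of the resulting dispatch, in Python for brevity — the function and variant names below are placeholders, not the actual ET-VK sources:

```python
from typing import Tuple

def pick_conv2d_dw_variant(stride: Tuple[int, int], dilation: Tuple[int, int]) -> str:
    """Hypothetical selection between the two depthwise-conv2d code paths."""
    if stride == dilation:
        # Default mode after this stack: stride equals dilation.
        return "conv2d_dw_default"  # placeholder name
    # General case, handled by the separate source file the message mentions.
    return "conv2d_dw_stride_ne_dilation"  # placeholder name
```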

35 files changed (+2372, −321 lines)

.gitmodules

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@
 	url = https://github.com/pybind/pybind11.git
 [submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"]
 	path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3
-	url = https://github.com/foss-xtensa/nnlib-FusionG3/
+	url = https://github.com/foss-xtensa/nnlib-FusionG3.git
 [submodule "third-party/ao"]
 	path = third-party/ao
 	url = https://github.com/pytorch/ao.git

backends/apple/mps/mps_preprocess.py

Lines changed: 6 additions & 0 deletions
@@ -32,6 +32,9 @@
     CompileSpec,
     PreprocessResult,
 )
+
+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+from executorch.exir.program._program import _transform
 from torch.export.exported_program import ExportedProgram

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +86,9 @@ def preprocess(
         # FlatBuffer graph, process the `output` nodes and add their id to
         # the `output_ids` array in the schema.

+        # TODO: Remove this once we have a better support for the dim-order ops.
+        edge_program = _transform(edge_program, DimOrderOpsRevertPass())
+
         mps_graph = MPSGraph(
             version="0",
             mps_nodes=[],
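
The two added lines in preprocess() are the whole mechanism. A self-contained sketch of the same call — the wrapper name revert_dim_order_ops is mine; everything else is exactly what the diff imports:

```python
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.program._program import _transform
from torch.export.exported_program import ExportedProgram


def revert_dim_order_ops(edge_program: ExportedProgram) -> ExportedProgram:
    # Rewrites dim_order_ops._to_dim_order_copy / _empty_dim_order back into
    # aten._to_copy / aten.empty.memory_format (per the visitor comments
    # below), so the MPS serializer keeps seeing plain ATen ops.
    return _transform(edge_program, DimOrderOpsRevertPass())
```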

backends/apple/mps/operators/constant_ops.py

Lines changed: 19 additions & 0 deletions
@@ -79,6 +79,25 @@ def define_node(
         )


+@register_node_visitor
+class ToDimOrderEmptyVisitor(NodeVisitor):
+    target = ["dim_order_ops._empty_dim_order.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op
+        # But if we do, we can't handle it ATM, so raise an exception
+        raise NotImplementedError(
+            "dim_order_ops._empty_dim_order.default is not supported yet"
+        )
+
+
 @register_node_visitor
 class FullLikeVisitor(NodeVisitor):
     target = "aten.full_like.default"

backends/apple/mps/operators/op_clone.py

Lines changed: 19 additions & 0 deletions
@@ -33,3 +33,22 @@ def define_node(
         )
         input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
         self.tensor_to_id[node] = input_id
+
+
+@register_node_visitor
+class ToDimOrderCopyVisitor(NodeVisitor):
+    target = ["dim_order_ops._to_dim_order_copy.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op
+        # But if we do, we can't handle it ATM, so raise an exception
+        raise NotImplementedError(
+            "dim_order_ops._to_dim_order_copy.default is not supported yet"
+        )
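
This visitor and ToDimOrderEmptyVisitor in constant_ops.py above follow the same guard pattern: register for the dim-order op, then fail fast if DimOrderOpsRevertPass did not run. A sketch of how such a registry plausibly works — an assumption about register_node_visitor's shape, not the MPS backend's actual implementation:

```python
from typing import Dict, List, Union

_node_visitors: Dict[str, type] = {}


def register_node_visitor(visitor_cls: type) -> type:
    # `target` may be a single op name ("aten.full_like.default") or a list
    # (["dim_order_ops._to_dim_order_copy.default"]); normalize to a list.
    targets: Union[str, List[str]] = visitor_cls.target
    if isinstance(targets, str):
        targets = [targets]
    for t in targets:
        _node_visitors[t] = visitor_cls
    return visitor_cls
```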

backends/apple/mps/test/test_mps.py

Lines changed: 15 additions & 0 deletions
@@ -1829,6 +1829,21 @@ def forward(self, x):
             Clone(), model_inputs, func_name=inspect.stack()[0].function[5:]
         )

+    def test_mps_backend_to_copy(self):
+        class Copy(torch.nn.Module):
+            def forward(self, x):
+                return (
+                    torch.ops.aten._to_copy.default(
+                        x + 2, memory_format=torch.contiguous_format
+                    )
+                    + x
+                )
+
+        model_inputs = (torch.randn(1, 3, 3),)
+        self.lower_and_test_with_partitioner(
+            Copy(), model_inputs, func_name=inspect.stack()[0].function[5:]
+        )
+
     def test_mps_backend_floor(self):
         class Floor(torch.nn.Module):
             def forward(self, x):
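
A minimal sketch of what this test exercises at export time, assuming the standard to_edge flow (this snippet is illustrative, not part of the diff):

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge


class Copy(torch.nn.Module):
    def forward(self, x):
        y = torch.ops.aten._to_copy.default(x + 2, memory_format=torch.contiguous_format)
        return y + x


ep = torch.export.export(Copy(), (torch.randn(1, 3, 3),))
# With _skip_dim_order removed (see test_mps_utils.py below), the edge graph
# carries dim_order_ops._to_dim_order_copy; DimOrderOpsRevertPass in
# preprocess() maps it back to aten._to_copy before MPS serialization.
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
print(edge.exported_program().graph_module.code)
```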

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 6 deletions
@@ -26,10 +26,7 @@

 # Config for Capturing the weights, will be moved in the future

-# TODO(T182928844): Delegate dim order op to backend.
-_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
-    _check_ir_validity=False, _skip_dim_order=True
-)
+_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False)


 class ansi_colors:
@@ -219,7 +216,6 @@ def lower_module_and_test_output(
        dynamic_shapes=dynamic_shapes,
        edge_compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
-           _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
        ),
    )

@@ -250,7 +246,6 @@ def lower_module_and_test_output(
        export(delegated_program, sample_inputs, strict=True),
        compile_config=exir.EdgeCompileConfig(
            _check_ir_validity=False,
-           _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
        ),
    ).to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 20 additions & 6 deletions
@@ -50,12 +50,12 @@
 - op: div.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out
+      kernel_name: cadence::impl::G3::div_out

 - op: div.out_mode
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out_mode
+      kernel_name: cadence::impl::G3::div_out_mode

 - op: embedding.out
   kernels:
@@ -71,7 +71,6 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::G3::mul_out
-
 - op: mul.Scalar_out
   kernels:
     - arg_meta: null
@@ -80,7 +79,7 @@
 - op: permute_copy.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::permute_copy_out
+      kernel_name: cadence::impl::G3::permute_copy_out

 - op: sigmoid.out
   kernels:
@@ -90,7 +89,7 @@
 - op: slice_copy.Tensor_out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::slice_copy_Tensor_out
+      kernel_name: cadence::impl::G3::slice_copy_Tensor_out

 - op: split_with_sizes_copy.out
   kernels:
@@ -100,7 +99,12 @@
 - op: sub.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::sub_out
+      kernel_name: cadence::impl::G3::sub_out
+
+- op: sub.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::sub_scalar_out

 - op: view_copy.out
   kernels:
@@ -117,6 +121,16 @@
     - arg_meta: null
       kernel_name: cadence::impl::G3::native_layer_norm_out

+- op: mean.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::mean_dim_out
+
+- op: exp.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::exp_out
+
 # custom ops
 - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function

backends/cadence/fusion_g3/operators/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,12 @@ set(_aten_ops__srcs
     "${CMAKE_CURRENT_SOURCE_DIR}/op_native_layer_norm.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_quantize.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_dequantize.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_sub.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_div.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_mean.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_slice_copy.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_permute_copy.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_exp.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
@@ -51,6 +57,7 @@ set(_aten_ops__srcs
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp"
 )
 add_library(aten_ops_cadence ${_aten_ops__srcs})
 target_link_libraries(aten_ops_cadence PUBLIC executorch)

backends/cadence/fusion_g3/operators/op_add.cpp

Lines changed: 4 additions & 2 deletions
@@ -39,6 +39,7 @@ Tensor& add_out(
   ScalarType common_type =
       executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());

+#ifdef OP_ARG_CHECK
   // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
@@ -62,12 +63,12 @@ Tensor& add_out(
       torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
+#endif

   // Compute Dtype
   ScalarType compute_type =
       torch::executor::native::utils::get_compute_type(common_type);

-  // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "add.out";

   int kTensorDimensionLimit = 5;
@@ -253,6 +254,7 @@ Tensor& add_scalar_out(
       torch::executor::native::utils::promote_type_with_scalar(
           a.scalar_type(), b);

+#ifdef OP_ARG_CHECK
   // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
@@ -276,7 +278,7 @@ Tensor& add_scalar_out(
       executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok,
       InvalidArgument,
       out);
-
+#endif
   // Compute Dtype
   ScalarType compute_type =
       torch::executor::native::utils::get_compute_type(common_type);
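
Presumably the new OP_ARG_CHECK guard lets Fusion G3 builds compile out the dtype and broadcast argument checks: with the macro undefined, add_out and add_scalar_out drop straight to the compute path (that reading is inferred; the diff itself only moves the existing checks behind the guard).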
