Commit 5048523

Update on "[ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue."
This diff keeps the current conv2d depthwise (dw) implementation in the ExecuTorch Vulkan backend as a special case that is used only when stride equals dilation, since that is the only configuration in which this kind of caching is possible. When stride does not equal dilation, the old implementation is used instead. Additional test cases are added to verify that the computation is correct when stride != dilation.

Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/)

[ghstack-poisoned]
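As a rough illustration of the dispatch described in the commit message, here is a minimal C++ sketch; the struct, function, and shader names are hypothetical and are not the actual ExecuTorch Vulkan identifiers.

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical convolution parameters relevant to the dispatch decision.
struct Conv2dParams {
  int64_t stride_h, stride_w;
  int64_t dilation_h, dilation_w;
};

// Picks between a cached "output tile" depthwise shader and the general one.
// The caching scheme is assumed to be valid only when stride == dilation,
// mirroring the special-casing described in the commit message.
std::string pick_conv2d_dw_shader(const Conv2dParams& p) {
  const bool stride_equals_dilation =
      p.stride_h == p.dilation_h && p.stride_w == p.dilation_w;
  if (stride_equals_dilation) {
    return "conv2d_dw_output_tile";  // specialized, cached path
  }
  return "conv2d_dw";  // general fallback path
}

int main() {
  Conv2dParams equal{2, 2, 2, 2};
  Conv2dParams unequal{2, 2, 1, 1};
  std::cout << pick_conv2d_dw_shader(equal) << "\n";    // conv2d_dw_output_tile
  std::cout << pick_conv2d_dw_shader(unequal) << "\n";  // conv2d_dw
  return 0;
}
```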
2 parents de44eda + 0cd0fb9 commit 5048523

35 files changed, +2372 −321 lines

.gitmodules

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@
 	url = https://github.com/pybind/pybind11.git
 [submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"]
 	path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3
-	url = https://github.com/foss-xtensa/nnlib-FusionG3/
+	url = https://github.com/foss-xtensa/nnlib-FusionG3.git
 [submodule "third-party/ao"]
 	path = third-party/ao
 	url = https://github.com/pytorch/ao.git

backends/apple/mps/mps_preprocess.py

Lines changed: 6 additions & 0 deletions

@@ -32,6 +32,9 @@
     CompileSpec,
     PreprocessResult,
 )
+
+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+from executorch.exir.program._program import _transform
 from torch.export.exported_program import ExportedProgram

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +86,9 @@ def preprocess(
         # FlatBuffer graph, process the `output` nodes and add their id to
         # the `output_ids` array in the schema.

+        # TODO: Remove this once we have a better support for the dim-order ops.
+        edge_program = _transform(edge_program, DimOrderOpsRevertPass())
+
         mps_graph = MPSGraph(
             version="0",
             mps_nodes=[],

backends/apple/mps/operators/constant_ops.py

Lines changed: 19 additions & 0 deletions

@@ -79,6 +79,25 @@ def define_node(
         )


+@register_node_visitor
+class ToDimOrderEmptyVisitor(NodeVisitor):
+    target = ["dim_order_ops._empty_dim_order.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op
+        # But if we do, we can't handle it ATM, so raise an exception
+        raise NotImplementedError(
+            "dim_order_ops._empty_dim_order.default is not supported yet"
+        )
+
+
 @register_node_visitor
 class FullLikeVisitor(NodeVisitor):
     target = "aten.full_like.default"

backends/apple/mps/operators/op_clone.py

Lines changed: 19 additions & 0 deletions

@@ -33,3 +33,22 @@ def define_node(
     )
         input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
         self.tensor_to_id[node] = input_id
+
+
+@register_node_visitor
+class ToDimOrderCopyVisitor(NodeVisitor):
+    target = ["dim_order_ops._to_dim_order_copy.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op
+        # But if we do, we can't handle it ATM, so raise an exception
+        raise NotImplementedError(
+            "dim_order_ops._to_dim_order_copy.default is not supported yet"
+        )

backends/apple/mps/test/test_mps.py

Lines changed: 15 additions & 0 deletions

@@ -1829,6 +1829,21 @@ def forward(self, x):
             Clone(), model_inputs, func_name=inspect.stack()[0].function[5:]
         )

+    def test_mps_backend_to_copy(self):
+        class Copy(torch.nn.Module):
+            def forward(self, x):
+                return (
+                    torch.ops.aten._to_copy.default(
+                        x + 2, memory_format=torch.contiguous_format
+                    )
+                    + x
+                )
+
+        model_inputs = (torch.randn(1, 3, 3),)
+        self.lower_and_test_with_partitioner(
+            Copy(), model_inputs, func_name=inspect.stack()[0].function[5:]
+        )
+
     def test_mps_backend_floor(self):
         class Floor(torch.nn.Module):
             def forward(self, x):

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 6 deletions

@@ -26,10 +26,7 @@

 # Config for Capturing the weights, will be moved in the future

-# TODO(T182928844): Delegate dim order op to backend.
-_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
-    _check_ir_validity=False, _skip_dim_order=True
-)
+_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False)


 class ansi_colors:
@@ -219,7 +216,6 @@ def lower_module_and_test_output(
             dynamic_shapes=dynamic_shapes,
             edge_compile_config=EdgeCompileConfig(
                 _check_ir_validity=False,
-                _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
             ),
         )

@@ -250,7 +246,6 @@ def lower_module_and_test_output(
             export(delegated_program, sample_inputs, strict=True),
             compile_config=exir.EdgeCompileConfig(
                 _check_ir_validity=False,
-                _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
             ),
         ).to_executorch(
             config=ExecutorchBackendConfig(extract_delegate_segments=False)

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 20 additions & 6 deletions

@@ -50,12 +50,12 @@
 - op: div.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out
+      kernel_name: cadence::impl::G3::div_out

 - op: div.out_mode
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out_mode
+      kernel_name: cadence::impl::G3::div_out_mode

 - op: embedding.out
   kernels:
@@ -71,7 +71,6 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::G3::mul_out
-
 - op: mul.Scalar_out
   kernels:
     - arg_meta: null
@@ -80,7 +79,7 @@
 - op: permute_copy.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::permute_copy_out
+      kernel_name: cadence::impl::G3::permute_copy_out

 - op: sigmoid.out
   kernels:
@@ -90,7 +89,7 @@
 - op: slice_copy.Tensor_out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::slice_copy_Tensor_out
+      kernel_name: cadence::impl::G3::slice_copy_Tensor_out

 - op: split_with_sizes_copy.out
   kernels:
@@ -100,7 +99,12 @@
 - op: sub.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::sub_out
+      kernel_name: cadence::impl::G3::sub_out
+
+- op: sub.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::sub_scalar_out

 - op: view_copy.out
   kernels:
@@ -117,6 +121,16 @@
     - arg_meta: null
       kernel_name: cadence::impl::G3::native_layer_norm_out

+- op: mean.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::mean_dim_out
+
+- op: exp.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::exp_out
+
 # custom ops
 - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
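For orientation, each `kernel_name` entry above is resolved against a C++ out-variant kernel at codegen time. The sketch below shows the rough shape of such a declaration for the newly registered `exp.out` entry; the include paths, namespace nesting, and type aliases are assumptions for illustration, not copied from this diff.

```cpp
// Illustrative declaration matching `kernel_name: cadence::impl::G3::exp_out`.
// Out-variant kernels in ExecuTorch take a runtime context first, the inputs
// next, and the pre-allocated output tensor last, returning a reference to it.
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace cadence {
namespace impl {
namespace G3 {

// exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
::executorch::aten::Tensor& exp_out(
    ::executorch::runtime::KernelRuntimeContext& ctx,
    const ::executorch::aten::Tensor& in,
    ::executorch::aten::Tensor& out);

} // namespace G3
} // namespace impl
} // namespace cadence
```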

backends/cadence/fusion_g3/operators/CMakeLists.txt

Lines changed: 7 additions & 0 deletions

@@ -36,6 +36,12 @@ set(_aten_ops__srcs
     "${CMAKE_CURRENT_SOURCE_DIR}/op_native_layer_norm.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_quantize.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_dequantize.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_sub.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_div.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_mean.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_slice_copy.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_permute_copy.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_exp.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
@@ -51,6 +57,7 @@ set(_aten_ops__srcs
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp"
 )
 add_library(aten_ops_cadence ${_aten_ops__srcs})
 target_link_libraries(aten_ops_cadence PUBLIC executorch)

backends/cadence/fusion_g3/operators/op_add.cpp

Lines changed: 4 additions & 2 deletions

@@ -39,6 +39,7 @@ Tensor& add_out(
   ScalarType common_type =
       executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());

+#ifdef OP_ARG_CHECK
   // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
@@ -62,12 +63,12 @@
       torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
+#endif

   // Compute Dtype
   ScalarType compute_type =
       torch::executor::native::utils::get_compute_type(common_type);

-  // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "add.out";

   int kTensorDimensionLimit = 5;
@@ -253,6 +254,7 @@ Tensor& add_scalar_out(
       torch::executor::native::utils::promote_type_with_scalar(
           a.scalar_type(), b);

+#ifdef OP_ARG_CHECK
   // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
@@ -276,7 +278,7 @@
       executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok,
       InvalidArgument,
       out);
-
+#endif
   // Compute Dtype
   ScalarType compute_type =
       torch::executor::native::utils::get_compute_type(common_type);
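The hunks above wrap argument validation (dtype and broadcast/resize checks) in an `OP_ARG_CHECK` guard so it can be compiled out. Below is a minimal, self-contained sketch of that pattern, assuming `OP_ARG_CHECK` is meant to be supplied as a compile definition; the function and checks are illustrative, not the ExecuTorch kernel.

```cpp
// Toy example of compile-time toggled argument checking: the validation code
// only exists in builds compiled with -DOP_ARG_CHECK, so other builds pay
// nothing for it.
#include <cstdio>
#include <stdexcept>

int scaled_add(int a, int b, int scale) {
#ifdef OP_ARG_CHECK
  // Present only when OP_ARG_CHECK is defined at compile time.
  if (scale == 0) {
    throw std::invalid_argument("scale must be non-zero");
  }
#endif
  return (a + b) * scale;
}

int main() {
  std::printf("%d\n", scaled_add(2, 3, 4));  // prints 20 in either build mode
  return 0;
}
```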
