
Commit 94d83ad

Cadence fusiong3 operators m2
Differential Revision: D67870337
Pull Request resolved: #7490
1 parent: ce3f4f6

19 files changed: +2184 −308 lines

.gitmodules

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@
     url = https://github.com/pybind/pybind11.git
 [submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"]
     path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3
-    url = https://github.com/foss-xtensa/nnlib-FusionG3/
+    url = https://github.com/foss-xtensa/nnlib-FusionG3.git
 [submodule "third-party/ao"]
     path = third-party/ao
     url = https://github.com/pytorch/ao.git

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 20 additions & 6 deletions
@@ -50,12 +50,12 @@
 - op: div.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out
+      kernel_name: cadence::impl::G3::div_out

 - op: div.out_mode
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out_mode
+      kernel_name: cadence::impl::G3::div_out_mode

 - op: embedding.out
   kernels:
@@ -71,7 +71,6 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::G3::mul_out
-
 - op: mul.Scalar_out
   kernels:
     - arg_meta: null
@@ -80,7 +79,7 @@
 - op: permute_copy.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::permute_copy_out
+      kernel_name: cadence::impl::G3::permute_copy_out

 - op: sigmoid.out
   kernels:
@@ -90,7 +89,7 @@
 - op: slice_copy.Tensor_out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::slice_copy_Tensor_out
+      kernel_name: cadence::impl::G3::slice_copy_Tensor_out

 - op: split_with_sizes_copy.out
   kernels:
@@ -100,7 +99,12 @@
 - op: sub.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::sub_out
+      kernel_name: cadence::impl::G3::sub_out
+
+- op: sub.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::sub_scalar_out

 - op: view_copy.out
   kernels:
@@ -117,6 +121,16 @@
     - arg_meta: null
       kernel_name: cadence::impl::G3::native_layer_norm_out

+- op: mean.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::mean_dim_out
+
+- op: exp.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::exp_out
+
 # custom ops
 - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
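
Each entry in this YAML binds an ATen out-variant operator to the kernel that implements it at runtime. The commit moves div, permute_copy, slice_copy, and sub off the generic torch::executor portable kernels onto NNLib-backed cadence::impl::G3 implementations, and wires up three newly added kernels (sub.Scalar, mean, exp). Kernels named here are expected to follow ExecuTorch's out-variant convention (runtime context first, output tensor last); the sketch below shows plausible declarations for two of the new bindings. The exact parameter lists are assumptions, since this diff does not show the kernel headers:

// Plausible declarations for two of the new G3 bindings, following the
// usual ExecuTorch out-variant signature. Parameter lists are assumptions.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace impl {
namespace G3 {
namespace native {

using ::executorch::aten::Scalar;
using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

// exp.out: elementwise exponential into a preallocated output tensor.
Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out);

// sub.Scalar_out: out = a - alpha * b, with b and alpha as scalars.
Tensor& sub_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out);

} // namespace native
} // namespace G3
} // namespace impl
} // namespace cadence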

backends/cadence/fusion_g3/operators/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,12 @@ set(_aten_ops__srcs
   "${CMAKE_CURRENT_SOURCE_DIR}/op_native_layer_norm.cpp"
   "${CMAKE_CURRENT_SOURCE_DIR}/op_quantize.cpp"
   "${CMAKE_CURRENT_SOURCE_DIR}/op_dequantize.cpp"
+  "${CMAKE_CURRENT_SOURCE_DIR}/op_sub.cpp"
+  "${CMAKE_CURRENT_SOURCE_DIR}/op_div.cpp"
+  "${CMAKE_CURRENT_SOURCE_DIR}/op_mean.cpp"
+  "${CMAKE_CURRENT_SOURCE_DIR}/op_slice_copy.cpp"
+  "${CMAKE_CURRENT_SOURCE_DIR}/op_permute_copy.cpp"
+  "${CMAKE_CURRENT_SOURCE_DIR}/op_exp.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
@@ -51,6 +57,7 @@ set(_aten_ops__srcs
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp"
+  "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp"
 )
 add_library(aten_ops_cadence ${_aten_ops__srcs})
 target_link_libraries(aten_ops_cadence PUBLIC executorch)
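
The new op_*.cpp sources compile into the aten_ops_cadence library alongside the portable kernels the G3 backend still reuses. The extra pattern source, unary_ufunc_realhbbf16_to_floathbf16.cpp, is most plausibly needed because the new exp kernel delegates its elementwise dispatch to that shared portable pattern instead of reimplementing it. A minimal sketch of such a delegation, giving a possible body to the exp_out declared above; the helper's exact signature is an assumption based on its upstream usage, not something this commit shows:

// Hypothetical sketch: an exp kernel reusing the portable unary-ufunc
// pattern whose source file was just added to the build.
#include <cmath>

#include <executorch/kernels/portable/cpu/pattern/pattern.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace impl {
namespace G3 {
namespace native {

::executorch::aten::Tensor& exp_out(
    ::executorch::runtime::KernelRuntimeContext& ctx,
    const ::executorch::aten::Tensor& in,
    ::executorch::aten::Tensor& out) {
  // The pattern helper handles dtype dispatch and element iteration;
  // std::exp is applied per element.
  return torch::executor::native::internal::
      unary_ufunc_realhbbf16_to_floathbf16(std::exp, ctx, in, out);
}

} // namespace native
} // namespace G3
} // namespace impl
} // namespace cadence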

backends/cadence/fusion_g3/operators/op_add.cpp

Lines changed: 4 additions & 2 deletions
@@ -39,6 +39,7 @@ Tensor& add_out(
   ScalarType common_type =
       executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());

+#ifdef OP_ARG_CHECK
   // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
@@ -62,12 +63,12 @@ Tensor& add_out(
       torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
+#endif

   // Compute Dtype
   ScalarType compute_type =
       torch::executor::native::utils::get_compute_type(common_type);

-  // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "add.out";

   int kTensorDimensionLimit = 5;
@@ -253,6 +254,7 @@ Tensor& add_scalar_out(
       torch::executor::native::utils::promote_type_with_scalar(
           a.scalar_type(), b);

+#ifdef OP_ARG_CHECK
   // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
@@ -276,7 +278,7 @@ Tensor& add_scalar_out(
       executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok,
       InvalidArgument,
       out);
-
+#endif
   // Compute Dtype
   ScalarType compute_type =
       torch::executor::native::utils::get_compute_type(common_type);
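
The add kernels now compile their dtype, dim-order, and broadcast-resize checks only when OP_ARG_CHECK is defined, so size- and cycle-sensitive DSP builds can drop validation entirely while debug builds keep it. A standalone illustration of the idiom follows; it is not the ExecuTorch code, and the container type and check macro are stand-ins:

// Standalone illustration of the OP_ARG_CHECK gating idiom.
// Build with -DOP_ARG_CHECK to compile the validation in.
#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for ET_KERNEL_CHECK's log-and-early-return behavior.
#define CHECK_OR_RETURN(cond, retval)                    \
  do {                                                   \
    if (!(cond)) {                                       \
      std::fprintf(stderr, "check failed: %s\n", #cond); \
      return (retval);                                   \
    }                                                    \
  } while (0)

std::vector<float>& add_out(
    const std::vector<float>& a,
    const std::vector<float>& b,
    std::vector<float>& out) {
#ifdef OP_ARG_CHECK
  // Validation exists only in builds that define OP_ARG_CHECK.
  CHECK_OR_RETURN(a.size() == b.size(), out);
  CHECK_OR_RETURN(out.size() == a.size(), out);
#endif
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] + b[i];
  }
  return out;
}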

backends/cadence/fusion_g3/operators/op_cat.cpp

Lines changed: 36 additions & 72 deletions
@@ -6,13 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */

+#include <executorch/backends/cadence/fusion_g3/operators/operators.h>
+#include <executorch/backends/cadence/fusion_g3/operators/xt_utils.h>
+
 #include <cstring>

 #include <xa_nnlib_kernels_api.h>

+#include <executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

+using ::executorch::aten::ArrayRef;
 using ::executorch::aten::ScalarType;
 using ::executorch::aten::Tensor;
 using ::executorch::runtime::Error;
@@ -23,7 +28,6 @@ using ::executorch::runtime::KernelRuntimeContext;
  * updated to have support for below data types, these can be removed and
  * operator need to be updated accordingly
  */
-enum datatype { Ushort = 20, Uint = 23 };

 namespace cadence {
 namespace impl {
@@ -32,20 +36,22 @@ namespace native {

 Tensor& cat_out(
     KernelRuntimeContext& ctx,
-    exec_aten::ArrayRef<Tensor> tensors,
+    ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
   if (dim < 0) {
     dim += out.dim();
   }

+  int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
+
+#ifdef OP_ARG_CHECK
   ET_KERNEL_CHECK(
       ctx,
       torch::executor::check_cat_args(tensors, dim, out),
       InvalidArgument,
       out);

-  int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
   size_t expected_out_dim = 0;
   torch::executor::get_cat_out_target_size(
@@ -57,14 +63,28 @@ Tensor& cat_out(
           out, {expected_out_size, expected_out_dim}) == Error::Ok,
       InvalidArgument,
       out);
+#endif
+  // Special handling when all inputs are 1D-empty tensors for aten
+  // consistency In that case, just return an 1D-empty tensor without checking
+  // dim
+  bool all_1d_empty = true;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
+      all_1d_empty = false;
+      break;
+    }
+  }
+  if (all_1d_empty) {
+    return out;
+  }

   const signed char* inp_tensors[tensors.size()];
   const int* inp_tensors_shapes[tensors.size()];

   int inp_shapes_size[tensors.size()];

   int temp_sizes[tensors.size()][kTensorDimensionLimit];
-  exec_aten::ArrayRef<Tensor::SizesType> temp_size;
+  ArrayRef<Tensor::SizesType> temp_size;

   for (int i = 0; i < tensors.size(); i++) {
     inp_tensors[i] = tensors[i].const_data_ptr<signed char>();
@@ -79,88 +99,32 @@ Tensor& cat_out(

   signed char* out_data = out.mutable_data_ptr<signed char>();

-  const exec_aten::ArrayRef<Tensor::SizesType> out_size = out.sizes();
+  const ArrayRef<Tensor::SizesType> out_size = out.sizes();
   int out_shapes[kTensorDimensionLimit];
   for (int i = 0; i < out_size.size(); i++) // output shapes
   {
     out_shapes[i] = out_size[i];
   }

-  if (out.scalar_type() == ScalarType::Int) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(int));
-  } else if (out.scalar_type() == ScalarType::Short) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(short));
-  } else if (out.scalar_type() == ScalarType::Char) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(char));
-  } else if (out.scalar_type() == (ScalarType)Uint) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(int));
-  } else if (out.scalar_type() == (ScalarType)Ushort) {
-    xa_nn_cat(
+  if ((out.scalar_type() == ScalarType::Int) ||
+      (out.scalar_type() == ScalarType::Short) ||
+      (out.scalar_type() == ScalarType::Char) ||
+      (out.scalar_type() == ScalarType::UInt32) ||
+      (out.scalar_type() == ScalarType::UInt16) ||
+      (out.scalar_type() == ScalarType::Byte)) {
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_cat,
         out_data,
         out_shapes,
         inp_tensors,
         inp_tensors_shapes,
         inp_shapes_size[0],
         tensors.size(),
         (int)dim,
-        sizeof(short));
-  } else if (out.scalar_type() == ScalarType::Byte) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(char));
-
+        get_element_size(out.scalar_type()));
   } else {
-    // Special handling when all inputs are 1D-empty tensors for aten
-    // consistency In that case, just return an 1D-empty tensor without checking
-    // dim
-    bool all_1d_empty = true;
-    for (size_t i = 0; i < tensors.size(); ++i) {
-      if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
-        all_1d_empty = false;
-        break;
-      }
-    }
-    if (all_1d_empty) {
-      return out;
-    }
     const size_t outer = executorch::runtime::getLeadingDims(out, dim);
     const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
     const size_t ninputs = tensors.size();