
Commit a0a8b64

Gasoonjia authored and facebook-github-bot committed
a-start ops | add dim order regulation (#4330)
Summary: Pull Request resolved: #4330

This diff updates the sanity checks on operators whose names start with "a", adding dim order validation per the dim order regulation. Tracking table: https://docs.google.com/spreadsheets/d/1Gttxkur8H6QnNfiCGfSAKwtBqdL6MSxn9eJ62bVYS_w/edit?gid=0#gid=0

Reviewed By: larryliu0820

Differential Revision: D59824508
1 parent bc56a97 · commit a0a8b64

20 files changed: +207 −84 lines
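
For context: a tensor's dim order lists its dimensions from outermost to innermost as laid out in memory. A contiguous NCHW tensor has dim order (0, 1, 2, 3); a channels-last one has (0, 2, 3, 1). The portable kernels below index raw data buffers directly, so an input and output with different dim orders would produce scrambled results. Below is a minimal standalone sketch of the comparison these new checks perform (illustrative only; `same_dim_order` is a hypothetical stand-in, not the ExecuTorch API):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for tensors_have_same_dim_order: two tensors agree
// only if their dimension-to-memory permutations are identical.
template <std::size_t N>
bool same_dim_order(const std::array<int, N>& a, const std::array<int, N>& b) {
  return a == b;  // element-wise comparison of the permutations
}

int main() {
  std::array<int, 4> contiguous = {0, 1, 2, 3};     // default (NCHW) order
  std::array<int, 4> channels_last = {0, 2, 3, 1};  // NHWC memory layout
  // Prints 0: a contiguous input paired with a channels-last output is
  // exactly the mismatch the new ET_KERNEL_CHECK calls reject.
  std::printf("%d\n", same_dim_order(contiguous, channels_last));
}
```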

kernels/portable/cpu/op_add.cpp

Lines changed: 4 additions & 0 deletions
@@ -79,6 +79,8 @@ Tensor& add_out(
       out);
 
   ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
@@ -131,6 +133,8 @@ Tensor& add_scalar_out(
       "Failed to resize output tensor.");
 
   ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = utils::get_scalar_dtype(b);

kernels/portable/cpu/op_addmm.cpp

Lines changed: 8 additions & 0 deletions
@@ -45,6 +45,14 @@ Tensor& addmm_out(
   ET_KERNEL_CHECK(
       ctx, tensor_is_broadcastable_to(in, out), InvalidArgument, out);
 
+  ET_KERNEL_CHECK(
+      ctx,
+      tensors_have_same_dim_order(in, mat1, mat2, out),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
   ScalarType alpha_dtype = utils::get_scalar_dtype(alpha);
   ScalarType beta_dtype = utils::get_scalar_dtype(beta);
   ET_SWITCH_REAL_TYPES_AND(
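
Why addmm also pins `in` to the default dim order: matmul-style kernels compute flat offsets assuming row-major (default) storage, so even matching-but-permuted layouts would break them. A standalone sketch of that assumption (illustrative; not the kernel's actual loop):

```cpp
#include <cstdio>

// Row-major flat indexing: correct only under the default dim order, where
// the last dimension is innermost in memory. With a permuted layout,
// row * cols + col would address the wrong element.
float at(const float* data, int cols, int row, int col) {
  return data[row * cols + col];
}

int main() {
  float m[2 * 3] = {1, 2, 3, 4, 5, 6};  // 2x3 matrix, row-major
  std::printf("%g\n", at(m, 3, 1, 2));  // prints 6 (row 1, col 2)
}
```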

kernels/portable/cpu/op_alias_copy.cpp

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ Tensor& alias_copy_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
       "Failed to resize output tensor.");
 
   ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   if (in.nbytes() > 0) {
     // Note that this check is important. It's valid for a tensor with numel 0

kernels/portable/cpu/op_allclose.cpp

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ Tensor& allclose_out(
       out.scalar_type() == ScalarType::Bool,
       "Out tensor must be type Bool; saw type %" PRId8,
       static_cast<int8_t>(out.scalar_type()));
+  ET_CHECK_MSG(
+      tensors_have_same_dim_order(self, other, out),
+      "self, other and out tensors should have same dim order");
   ET_CHECK_MSG(
       out.numel() == 1,
       "Out tensor must be a single element; saw %zu elements",

kernels/portable/cpu/op_amax.cpp

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ Tensor& amax_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES_AND(
       Bool, in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
         CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

kernels/portable/cpu/op_amin.cpp

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ Tensor& amin_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES_AND(
       Bool, in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
         CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

kernels/portable/cpu/op_any.cpp

Lines changed: 9 additions & 0 deletions
@@ -22,6 +22,9 @@ Tensor& any_all_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.all_out";
@@ -68,6 +71,9 @@ Tensor& any_dims_out(
       out);
   }
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.dims_out";
@@ -122,6 +128,9 @@ Tensor& any_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.out";

kernels/portable/cpu/op_arange.cpp

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,8 @@ Tensor& arange_out(RuntimeContext& ctx, const Scalar& end, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, check_arange_args(0.0, end_val, 1.0, out), InvalidArgument, out);
 
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
+
   size_t size = static_cast<size_t>(std::ceil(end_val));
 
   Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
@@ -73,6 +75,8 @@ Tensor& arange_start_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
+
   double size_d = (d_end - d_start) / d_step;
   size_t size = static_cast<size_t>(std::ceil(size_d));
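
arange has no tensor inputs, so the only layout check possible is that `out` itself uses the default dim order. The sizing math visible in the second hunk is ceil((end - start) / step); a quick standalone check of that arithmetic (hypothetical values):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  double start = 0.5, end = 5.0, step = 1.5;
  // Mirrors arange_start_out's sizing: ceil((end - start) / step).
  std::size_t size = static_cast<std::size_t>(std::ceil((end - start) / step));
  std::printf("%zu\n", size);  // 3 elements: 0.5, 2.0, 3.5
}
```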

kernels/portable/cpu/op_argmax.cpp

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ Tensor& argmax_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();

kernels/portable/cpu/op_argmin.cpp

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ Tensor& argmin_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();

kernels/portable/cpu/op_as_strided_copy.cpp

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,11 @@ Tensor& as_strided_copy_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
   if (in.numel() == 0) {
     return out;
   }

kernels/portable/cpu/op_atan2.cpp

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@ atan2_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();

kernels/portable/cpu/op_avg_pool2d.cpp

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,11 @@ Tensor& avg_pool2d_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
   size_t output_ndim = 0;
   exec_aten::SizesType output_sizes[kTensorDimensionLimit];
   get_avg_pool2d_out_target_size(

kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@ Tensor& binary_ufunc_realb_realb_to_realb_logical(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();

kernels/portable/cpu/pattern/unary_ufunc_realh.cpp

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ Tensor& unary_ufunc_realh(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
     apply_unary_map_fn(
         [fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); },

kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ Tensor& unary_ufunc_realhb_to_bool(
       "Expected out tensor to have dtype Bool, but got %" PRId8 " instead.",
       static_cast<int8_t>(out.scalar_type()));
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   const auto in_type = in.scalar_type();
 
   ET_SWITCH_REALHB_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {

kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ Tensor& unary_ufunc_realhb_to_floath(
       out,
       "Failed to resize output tensor.");
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   const auto in_type = in.scalar_type();
   const auto out_type = out.scalar_type();

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 59 additions & 11 deletions
@@ -1153,33 +1153,80 @@ bool tensor_has_valid_dim_order(exec_aten::Tensor t);
 bool tensor_is_default_or_channels_last_dim_order(exec_aten::Tensor t);
 
 /**
- * Asserts that two tensors have the same dim_order
+ * Checks whether a tensor has the default dimension order.
+ * Logs an error message if the tensor does not meet the expected criteria.
  *
- * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc. Also this macro does not support ATen mode since we do not
- * support dim order in ATen mode.
+ * @param t The tensor to check the dimension order of.
+ * @return True if the tensor has the default dimension order, false otherwise.
+ */
+bool tensor_is_default_dim_order(exec_aten::Tensor t);
+
+/**
+ * Checks whether a tensor has the channels last dimension order.
+ * Logs an error message if the tensor does not meet the expected criteria.
  *
- * TODO(T183094318): Add dim order and related function support for ATen mode.
+ * @param t The tensor to check the dimension order of.
+ * @return True if the tensor has the channels last dimension order, false
+ * otherwise.
  */
+bool tensor_is_channels_last_dim_order(exec_aten::Tensor t);
 
+/**
+ * Asserts that all tensors in a list have the same dim_order
+ *
+ * Note that this macro only tests dim order, but not others like actual data,
+ * sizes, etc.
+ *
+ */
 bool tensors_have_same_dim_order(
+    const exec_aten::ArrayRef<exec_aten::Tensor> tensor_list);
+
+/**
+ * Asserts that two tensors have the same dim_order
+ *
+ * Note that this macro only tests dim order, but not others like actual data,
+ * sizes, etc.
+ */
+
+inline bool tensors_have_same_dim_order(
     const exec_aten::Tensor& a,
-    const exec_aten::Tensor& b);
+    const exec_aten::Tensor& b) {
+  exec_aten::Tensor tensor_list[2] = {a, b};
+  return tensors_have_same_dim_order(tensor_list);
+}
 
 /**
  * Asserts that three tensors have the same dim_order
  *
  * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc. Also this macro does not support ATen mode since we do not
- * support dim order in ATen mode.
+ * sizes, etc.
  *
- * TODO(T183094318): Add dim order and related function support for ATen mode.
  */
 
-bool tensors_have_same_dim_order(
+inline bool tensors_have_same_dim_order(
     const exec_aten::Tensor& a,
     const exec_aten::Tensor& b,
-    const exec_aten::Tensor& c);
+    const exec_aten::Tensor& c) {
+  exec_aten::Tensor tensor_list[3] = {a, b, c};
+  return tensors_have_same_dim_order(tensor_list);
+}
+
+/**
+ * Asserts that four tensors have the same dim_order
+ *
+ * Note that this macro only tests dim order, but not others like actual data,
+ * sizes, etc.
+ *
+ */
+
+inline bool tensors_have_same_dim_order(
+    const exec_aten::Tensor& a,
+    const exec_aten::Tensor& b,
+    const exec_aten::Tensor& c,
+    const exec_aten::Tensor& d) {
+  exec_aten::Tensor tensor_list[4] = {a, b, c, d};
+  return tensors_have_same_dim_order(tensor_list);
+}
 
 /**
  * Given an n-dimensional coordinate array and an array of tensor strides,
@@ -1232,6 +1279,7 @@ using ::executorch::runtime::tensor_is_bits_type;
 using ::executorch::runtime::tensor_is_bool_type;
 using ::executorch::runtime::tensor_is_complex_type;
 using ::executorch::runtime::tensor_is_contiguous;
+using ::executorch::runtime::tensor_is_default_dim_order;
 using ::executorch::runtime::tensor_is_default_or_channels_last_dim_order;
 using ::executorch::runtime::tensor_is_floating_type;
 using ::executorch::runtime::tensor_is_integral_type;
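
The fixed-arity overloads above simply pack their arguments into a stack array and forward to the single ArrayRef-based implementation, which keeps kernel call sites terse without heap allocation. A standalone sketch of that forwarding pattern (names are illustrative, not the ExecuTorch API):

```cpp
#include <cstddef>
#include <cstdio>

struct DimOrder { int perm[4]; };

// List-based implementation: all layouts must match the first one.
bool have_same_dim_order(const DimOrder* list, std::size_t n) {
  for (std::size_t i = 1; i < n; ++i)
    for (int d = 0; d < 4; ++d)
      if (list[i].perm[d] != list[0].perm[d]) return false;
  return true;
}

// Fixed-arity convenience overload, mirroring the new inline functions:
// pack into a stack array, then defer to the list-based version.
inline bool have_same_dim_order(const DimOrder& a, const DimOrder& b) {
  DimOrder list[2] = {a, b};
  return have_same_dim_order(list, 2);
}

int main() {
  DimOrder nchw = {{0, 1, 2, 3}};
  DimOrder nhwc = {{0, 2, 3, 1}};
  std::printf("%d %d\n", have_same_dim_order(nchw, nchw),
              have_same_dim_order(nchw, nhwc));  // prints: 1 0
}
```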
