
Commit e3fbd8d

Songhao Jia authored and facebook-github-bot committed
a-start ops | add dim order regulation
Differential Revision: D59824508
1 parent 5fb15f7 commit e3fbd8d

20 files changed: +200 −7 lines

kernels/portable/cpu/op_add.cpp

Lines changed: 4 additions & 0 deletions
@@ -79,6 +79,8 @@ Tensor& add_out(
       out);
 
   ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
@@ -131,6 +133,8 @@ Tensor& add_scalar_out(
       "Failed to resize output tensor.");
 
   ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = utils::get_scalar_dtype(b);
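
The new guard runs before any arithmetic: if the operands disagree on dim order, the kernel bails out with InvalidArgument instead of silently mixing layouts. As a rough standalone sketch of the comparison (illustrative names, not the ExecuTorch API; the real ATen-mode helper, shown at the bottom of this page, additionally accepts a set of tensors only when all are contiguous or all are channels-last):

#include <cstddef>
#include <cstdint>

// Sketch: two dim-order arrays of equal rank "agree" when they are
// element-wise identical, e.g. contiguous {0, 1, 2, 3} vs. contiguous
// {0, 1, 2, 3} passes, while contiguous vs. channels-last {0, 2, 3, 1} fails.
bool dim_orders_agree_sketch(const uint8_t* a, const uint8_t* b, size_t dims) {
  for (size_t i = 0; i < dims; ++i) {
    if (a[i] != b[i]) {
      return false;
    }
  }
  return true;
}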

kernels/portable/cpu/op_addmm.cpp

Lines changed: 8 additions & 0 deletions
@@ -45,6 +45,14 @@ Tensor& addmm_out(
   ET_KERNEL_CHECK(
       ctx, tensor_is_broadcastable_to(in, out), InvalidArgument, out);
 
+  ET_KERNEL_CHECK(
+      ctx,
+      tensors_have_same_dim_order(in, mat1, mat2, out),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
   ScalarType alpha_dtype = utils::get_scalar_dtype(alpha);
   ScalarType beta_dtype = utils::get_scalar_dtype(beta);
   ET_SWITCH_REAL_TYPES_AND(

kernels/portable/cpu/op_alias_copy.cpp

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ Tensor& alias_copy_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
       "Failed to resize output tensor.");
 
   ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   if (in.nbytes() > 0) {
     // Note that this check is important. It's valid for a tensor with numel 0

kernels/portable/cpu/op_allclose.cpp

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ Tensor& allclose_out(
       out.scalar_type() == ScalarType::Bool,
       "Out tensor must be type Bool; saw type %" PRId8,
       static_cast<int8_t>(out.scalar_type()));
+  ET_CHECK_MSG(
+      tensors_have_same_dim_order(self, other, out),
+      "self, other and out tensors should have same dim order");
   ET_CHECK_MSG(
       out.numel() == 1,
       "Out tensor must be a single element; saw %zu elements",

kernels/portable/cpu/op_amax.cpp

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ Tensor& amax_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES_AND(
       Bool, in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
         CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

kernels/portable/cpu/op_amin.cpp

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ Tensor& amin_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES_AND(
       Bool, in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
         CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

kernels/portable/cpu/op_any.cpp

Lines changed: 9 additions & 0 deletions
@@ -22,6 +22,9 @@ Tensor& any_all_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.all_out";
@@ -68,6 +71,9 @@ Tensor& any_dims_out(
         out);
   }
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.dims_out";
@@ -122,6 +128,9 @@ Tensor& any_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.out";

kernels/portable/cpu/op_arange.cpp

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,8 @@ Tensor& arange_out(RuntimeContext& ctx, const Scalar& end, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, check_arange_args(0.0, end_val, 1.0, out), InvalidArgument, out);
 
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
+
   size_t size = static_cast<size_t>(std::ceil(end_val));
 
   Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
@@ -73,6 +75,8 @@ Tensor& arange_start_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
+
   double size_d = (d_end - d_start) / d_step;
   size_t size = static_cast<size_t>(std::ceil(size_d));
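
arange produces a fresh tensor rather than transforming an input, so there is no second tensor to compare against; the kernel instead requires out to have the default (contiguous) dim order, i.e. the identity permutation. A standalone sketch of that predicate's semantics (illustrative, not the library implementation):

#include <cstddef>
#include <cstdint>

// Sketch: the default dim order is the identity permutation
// {0, 1, ..., dims - 1}; anything else, e.g. channels-last {0, 2, 3, 1},
// fails the check.
bool is_default_dim_order_sketch(const uint8_t* dim_order, size_t dims) {
  for (size_t i = 0; i < dims; ++i) {
    if (static_cast<size_t>(dim_order[i]) != i) {
      return false;
    }
  }
  return true;
}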

kernels/portable/cpu/op_argmax.cpp

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ Tensor& argmax_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();

kernels/portable/cpu/op_argmin.cpp

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ Tensor& argmin_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();

kernels/portable/cpu/op_as_strided_copy.cpp

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,11 @@ Tensor& as_strided_copy_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
   if (in.numel() == 0) {
     return out;
   }

kernels/portable/cpu/op_atan2.cpp

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@ atan2_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();

kernels/portable/cpu/op_avg_pool2d.cpp

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,11 @@ Tensor& avg_pool2d_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
   size_t output_ndim = 0;
   exec_aten::SizesType output_sizes[kTensorDimensionLimit];
   get_avg_pool2d_out_target_size(

kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@ Tensor& binary_ufunc_realb_realb_to_realb_logical(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();

kernels/portable/cpu/pattern/unary_ufunc_realh.cpp

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ Tensor& unary_ufunc_realh(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
     apply_unary_map_fn(
         [fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); },

kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ Tensor& unary_ufunc_realhb_to_bool(
       "Expected out tensor to have dtype Bool, but got %" PRId8 " instead.",
       static_cast<int8_t>(out.scalar_type()));
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   const auto in_type = in.scalar_type();
 
   ET_SWITCH_REALHB_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {

kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ Tensor& unary_ufunc_realhb_to_floath(
       out,
       "Failed to resize output tensor.");
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
   const auto in_type = in.scalar_type();
   const auto out_type = out.scalar_type();

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 34 additions & 7 deletions
@@ -1154,14 +1154,30 @@ bool tensor_has_valid_dim_order(exec_aten::Tensor t);
  */
 bool tensor_is_default_or_channels_last_dim_order(exec_aten::Tensor t);
 
+/**
+ * Checks whether a tensor has the default dimension order.
+ * Logs an error message if the tensor does not meet the expected criteria.
+ *
+ * @param t The tensor to check the dimension order of.
+ * @return True if the tensor has the default dimension order, false otherwise.
+ */
+bool tensor_is_default_dim_order(exec_aten::Tensor t);
+
+/**
+ * Checks whether a tensor has the channels last dimension order.
+ * Logs an error message if the tensor does not meet the expected criteria.
+ *
+ * @param t The tensor to check the dimension order of.
+ * @return True if the tensor has the channels last dimension order, false
+ * otherwise.
+ */
+bool tensor_is_channels_last_dim_order(exec_aten::Tensor t);
+
 /**
  * Asserts that two tensors have the same dim_order
  *
  * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc. Also this macro does not support ATen mode since we do not
- * support dim order in ATen mode.
- *
- * TODO(T183094318): Add dim order and related function support for ATen mode.
+ * sizes, etc.
  */
 
 bool tensors_have_same_dim_order(
@@ -1172,17 +1188,28 @@ bool tensors_have_same_dim_order(
  * Asserts that three tensors have the same dim_order
  *
  * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc. Also this macro does not support ATen mode since we do not
- * support dim order in ATen mode.
+ * sizes, etc.
  *
- * TODO(T183094318): Add dim order and related function support for ATen mode.
  */
 
 bool tensors_have_same_dim_order(
     const exec_aten::Tensor& a,
     const exec_aten::Tensor& b,
     const exec_aten::Tensor& c);
 
+/**
+ * Asserts that four tensors have the same dim_order
+ *
+ * Note that this macro only tests dim order, but not others like actual data,
+ * sizes, etc.
+ */
+bool tensors_have_same_dim_order(
+    const exec_aten::Tensor& a,
+    const exec_aten::Tensor& b,
+    const exec_aten::Tensor& c,
+    const exec_aten::Tensor& d);
+
 /**
  * Given an n-dimensional coordinate array and an array of tensor strides,
  * calculates the linear index that can be used to retrieve the value at the
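
The header now distinguishes the two sanctioned layouts by name. For a 4-D tensor, channels-last corresponds to dim order {0, 2, 3, 1}: batch, height, and width first, with channels varying fastest in memory. A standalone sketch of the channels-last predicate under that 4-D NHWC convention (illustrative only; the library implementation may also handle other ranks):

#include <cstddef>
#include <cstdint>

// Sketch: 4-D channels-last means dim order {0, 2, 3, 1}, i.e. NCHW sizes
// laid out in NHWC memory order. Ranks other than 4 are rejected here for
// simplicity.
bool is_channels_last_dim_order_sketch(const uint8_t* dim_order, size_t dims) {
  if (dims != 4) {
    return false;
  }
  const uint8_t nhwc[4] = {0, 2, 3, 1};
  for (size_t i = 0; i < 4; ++i) {
    if (dim_order[i] != nhwc[i]) {
      return false;
    }
  }
  return true;
}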

runtime/core/exec_aten/util/tensor_util_aten.cpp

Lines changed: 39 additions & 0 deletions
@@ -135,6 +135,45 @@ bool tensors_have_same_dim_order(
   return true;
 }
 
+bool tensors_have_same_dim_order(
+    const exec_aten::Tensor& a,
+    const exec_aten::Tensor& b,
+    const exec_aten::Tensor& c,
+    const exec_aten::Tensor& d) {
+  exec_aten::DimOrderType a_dim_order[kTensorDimensionLimit];
+  exec_aten::DimOrderType b_dim_order[kTensorDimensionLimit];
+  exec_aten::DimOrderType c_dim_order[kTensorDimensionLimit];
+  exec_aten::DimOrderType d_dim_order[kTensorDimensionLimit];
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      get_dim_order(a, a_dim_order, a.dim()) == Error::Ok,
+      "Failed to retrieve dim order from first input tensor!");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      get_dim_order(b, b_dim_order, b.dim()) == Error::Ok,
+      "Failed to retrieve dim order from second input tensor!");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      get_dim_order(c, c_dim_order, c.dim()) == Error::Ok,
+      "Failed to retrieve dim order from third input tensor!");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      get_dim_order(d, d_dim_order, d.dim()) == Error::Ok,
+      "Failed to retrieve dim order from fourth input tensor!");
+
+  bool all_contiguous = is_contiguous_dim_order(a_dim_order, a.dim()) &&
+      is_contiguous_dim_order(b_dim_order, b.dim()) &&
+      is_contiguous_dim_order(c_dim_order, c.dim()) &&
+      is_contiguous_dim_order(d_dim_order, d.dim());
+
+  bool all_channels_last = is_channels_last_dim_order(a_dim_order, a.dim()) &&
+      is_channels_last_dim_order(b_dim_order, b.dim()) &&
+      is_channels_last_dim_order(c_dim_order, c.dim()) &&
+      is_channels_last_dim_order(d_dim_order, d.dim());
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      all_contiguous || all_channels_last,
+      "Four input tensors have different dim orders");
+  return true;
+}
+
 namespace internal {
 
 Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
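
Putting the pieces together, the n-ary overloads all implement the same rule: retrieve each tensor's dim order, then accept the set only if every tensor is contiguous or every tensor is channels-last. A standalone sketch of that combination step, reusing the toy predicates sketched earlier on this page (all names here are illustrative, not the ExecuTorch API):

#include <cstddef>
#include <cstdint>

// Toy predicates sketched earlier on this page.
bool is_default_dim_order_sketch(const uint8_t* dim_order, size_t dims);
bool is_channels_last_dim_order_sketch(const uint8_t* dim_order, size_t dims);

// Sketch: a set of dim orders matches when all are contiguous or all are
// channels-last; one contiguous tensor mixed with one channels-last tensor
// fails, mirroring the "all_contiguous || all_channels_last" rule above.
bool dim_orders_match_sketch(
    const uint8_t* const orders[], const size_t dims[], size_t count) {
  bool all_contiguous = true;
  bool all_channels_last = true;
  for (size_t i = 0; i < count; ++i) {
    all_contiguous =
        all_contiguous && is_default_dim_order_sketch(orders[i], dims[i]);
    all_channels_last =
        all_channels_last && is_channels_last_dim_order_sketch(orders[i], dims[i]);
  }
  return all_contiguous || all_channels_last;
}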
