Skip to content

Commit e36f064

Browse files
Songhao Jia and facebook-github-bot
authored and committed
b&c start ops | add dim order sanity check
Differential Revision: D59824515
1 parent e3fbd8d commit e36f064

13 files changed

+59
-0
lines changed

kernels/portable/cpu/op_bitwise_and.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ Tensor& bitwise_and_Tensor_out(
3232
InvalidArgument,
3333
out);
3434

35+
ET_KERNEL_CHECK(
36+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
37+
3538
ScalarType a_type = a.scalar_type();
3639
ScalarType b_type = b.scalar_type();
3740
ScalarType common_type = promoteTypes(a_type, b_type);
@@ -82,6 +85,9 @@ Tensor& bitwise_and_Scalar_out(
8285
out,
8386
"Failed to resize output tensor.");
8487

88+
ET_KERNEL_CHECK(
89+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
90+
8591
ScalarType a_type = a.scalar_type();
8692
ScalarType b_type = utils::get_scalar_dtype(b);
8793
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);

kernels/portable/cpu/op_bitwise_not.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ Tensor& bitwise_not_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
3333
"Failed to resize output tensor.");
3434

3535
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
36+
ET_KERNEL_CHECK(
37+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3638

3739
if (in.scalar_type() == exec_aten::ScalarType::Bool) {
3840
apply_unary_map_fn(

kernels/portable/cpu/op_bitwise_or.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ Tensor& bitwise_or_Tensor_out(
3232
InvalidArgument,
3333
out);
3434

35+
ET_KERNEL_CHECK(
36+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
37+
3538
ScalarType a_type = a.scalar_type();
3639
ScalarType b_type = b.scalar_type();
3740
ScalarType common_type = promoteTypes(a_type, b_type);
@@ -74,6 +77,9 @@ Tensor& bitwise_or_Scalar_out(
7477
Tensor& out) {
7578
(void)ctx;
7679

80+
ET_KERNEL_CHECK(
81+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
82+
7783
// Resize for dynamic shape
7884
ET_KERNEL_CHECK_MSG(
7985
ctx,

kernels/portable/cpu/op_bitwise_xor.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ Tensor& bitwise_xor_Tensor_out(
3232
InvalidArgument,
3333
out);
3434

35+
ET_KERNEL_CHECK(
36+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
37+
3538
ScalarType a_type = a.scalar_type();
3639
ScalarType b_type = b.scalar_type();
3740
ScalarType common_type = promoteTypes(a_type, b_type);
@@ -82,6 +85,9 @@ Tensor& bitwise_xor_Scalar_out(
8285
out,
8386
"Failed to resize output tensor.");
8487

88+
ET_KERNEL_CHECK(
89+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
90+
8591
ScalarType a_type = a.scalar_type();
8692
ScalarType b_type = utils::get_scalar_dtype(b);
8793
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);

kernels/portable/cpu/op_bmm.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ Tensor& bmm_out(
2323
Tensor& out) {
2424
ET_KERNEL_CHECK(ctx, check_bmm_args(in, mat2, out), InvalidArgument, out);
2525

26+
ET_KERNEL_CHECK(
27+
ctx, tensors_have_same_dim_order(in, mat2, out), InvalidArgument, out);
28+
29+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
30+
2631
size_t output_ndim = 0;
2732
exec_aten::SizesType output_sizes[kTensorDimensionLimit];
2833
get_bmm_out_target_size(in, mat2, output_sizes, &output_ndim);

kernels/portable/cpu/op_cdist_forward.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ Tensor& _cdist_forward_out(
124124
Tensor& out) {
125125
(void)ctx;
126126

127+
ET_KERNEL_CHECK(
128+
ctx, tensors_have_same_dim_order(x1, x2, out), InvalidArgument, out);
129+
130+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(x1), InvalidArgument, out);
131+
127132
ET_KERNEL_CHECK(
128133
ctx,
129134
check_cdist_args(x1, x2, p, compute_mode, out),

kernels/portable/cpu/op_clamp.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ Tensor& clamp_out(
8383
out,
8484
"Failed to resize output tensor.");
8585

86+
ET_KERNEL_CHECK(
87+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
88+
8689
ScalarType in_type = in.scalar_type();
8790
ScalarType min_type = in_type;
8891
ScalarType max_type = in_type;
@@ -182,6 +185,12 @@ Tensor& clamp_tensor_out(
182185
const Tensor& min = has_min ? min_opt.value() : in;
183186
const Tensor& max = has_max ? max_opt.value() : in;
184187

188+
ET_KERNEL_CHECK(
189+
ctx,
190+
tensors_have_same_dim_order(in, min, max, out),
191+
InvalidArgument,
192+
out);
193+
185194
ET_KERNEL_CHECK(
186195
ctx,
187196
resize_to_broadcast_target_size(in, min, max, out) == Error::Ok,

kernels/portable/cpu/op_clone.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ Tensor& clone_out(
3838
InvalidArgument,
3939
out);
4040

41+
ET_KERNEL_CHECK(
42+
context, tensors_have_same_dim_order(self, out), InvalidArgument, out);
43+
4144
// Right now we only focus on contiguous memory, memory_format shall always
4245
// either a nullopt or exec::aten::MemoryFormat::Contiguous
4346
ET_KERNEL_CHECK(

kernels/portable/cpu/op_constant_pad_nd.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ Tensor& constant_pad_nd_out(
170170
ET_KERNEL_CHECK(
171171
ctx, check_constant_pad_args(in, pad, value, out), InvalidArgument, out);
172172

173+
ET_KERNEL_CHECK(
174+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
175+
173176
// resize out tensor for dynamic shapes
174177
ET_KERNEL_CHECK_MSG(
175178
ctx,

kernels/portable/cpu/op_convolution.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,9 @@ Tensor& convolution_out(
365365
InvalidArgument,
366366
out);
367367

368+
ET_KERNEL_CHECK(
369+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
370+
368371
size_t output_ndim = 0;
369372
exec_aten::SizesType output_sizes[kTensorDimensionLimit];
370373
get_convolution_out_target_size(

kernels/portable/cpu/op_copy.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ Tensor& copy_out(
3939
ET_KERNEL_CHECK(
4040
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
4141

42+
ET_KERNEL_CHECK(
43+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
44+
4245
ScalarType in_type = in.scalar_type();
4346
ScalarType src_type = src.scalar_type();
4447

@@ -66,6 +69,9 @@ copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) {
6669
ET_KERNEL_CHECK(
6770
ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);
6871

72+
ET_KERNEL_CHECK(
73+
ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);
74+
6975
ScalarType in_type = in.scalar_type();
7076
ScalarType src_type = src.scalar_type();
7177

kernels/portable/cpu/op_cumsum.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ Tensor& cumsum_out(
9393
InvalidArgument,
9494
out);
9595

96+
ET_KERNEL_CHECK(
97+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
98+
9699
ET_KERNEL_CHECK(
97100
ctx, resize_tensor(out, self.sizes()) == Error::Ok, InvalidArgument, out);
98101

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ bool check_cat_args(
9595
ET_LOG_AND_RETURN_IF_FALSE(
9696
canCast(tensors[i].scalar_type(), out.scalar_type()));
9797

98+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dim_order(tensors[i], out));
99+
98100
// Empty tensors have no shape constraints.
99101
if (tensors[i].numel() == 0) {
100102
continue;

0 commit comments

Comments (0)