Skip to content

Commit 6ce9f52

Browse files
authored
t to z start ops | add dim order sanity check
Differential Revision: D59990127 Pull Request resolved: #4328
1 parent 28beeff commit 6ce9f52

File tree

9 files changed

+43
-0
lines changed

9 files changed

+43
-0
lines changed

kernels/portable/cpu/op_t_copy.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ Tensor& t_copy_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
4747
return out;
4848
}
4949

50+
ET_KERNEL_CHECK(
51+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
52+
53+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
54+
5055
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
5156
size_t expected_out_dim = 0;
5257
get_transpose_out_target_size(in, 1, 0, expected_out_size, &expected_out_dim);

kernels/portable/cpu/op_to_copy.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ Tensor& to_copy_out(
4646
InvalidArgument,
4747
out);
4848

49+
ET_KERNEL_CHECK(
50+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
51+
52+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
53+
4954
ET_SWITCH_REALHBBF16_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] {
5055
ET_SWITCH_REALHBBF16_TYPES(
5156
out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] {

kernels/portable/cpu/op_transpose_copy.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ Tensor& transpose_copy_int_out(
5757
InvalidArgument,
5858
out);
5959

60+
ET_KERNEL_CHECK(
61+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
62+
6063
ET_SWITCH_ALL_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
6164
transpose_tensors<CTYPE>(in, dim0, dim1, out);
6265
});

kernels/portable/cpu/op_tril.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ Tensor& tril_out(
145145
InvalidArgument,
146146
out);
147147

148+
ET_KERNEL_CHECK(
149+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
150+
151+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
152+
148153
if (self.numel() == 0) {
149154
return out;
150155
}

kernels/portable/cpu/op_unbind_copy.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ void unbind_copy_int_out(
3636
ET_KERNEL_CHECK(
3737
ctx, check_unbind_copy_args(input, dim, out), InvalidArgument, );
3838

39+
for (int i = 0; i < out.size(); ++i) {
40+
ET_KERNEL_CHECK(
41+
ctx, tensors_have_same_dim_order(input, out[i]), InvalidArgument, );
42+
}
43+
44+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(input), InvalidArgument, );
45+
3946
if (input.numel() == 0) {
4047
return;
4148
}

kernels/portable/cpu/op_unsqueeze_copy.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ Tensor& unsqueeze_copy_out(
3838
ET_KERNEL_CHECK(ctx, self.dim() + 1 == out.dim(), InvalidArgument, out);
3939
ET_KERNEL_CHECK(ctx, dim <= self.dim(), InvalidArgument, out);
4040

41+
ET_KERNEL_CHECK(
42+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
43+
44+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
45+
4146
for (size_t i = 0; i < out.dim(); ++i) {
4247
if (i < dim) {
4348
expected_output_size[i] = self.size(i);

kernels/portable/cpu/op_var.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ Tensor& var_out(
7474
ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out);
7575
ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
7676

77+
ET_KERNEL_CHECK(
78+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
79+
80+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
81+
7782
ET_KERNEL_CHECK(
7883
ctx,
7984
resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok,

kernels/portable/cpu/op_view_copy.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ Tensor& view_copy_out(
4444
out,
4545
"Failed to resize output tensor.");
4646

47+
ET_KERNEL_CHECK(
48+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
49+
50+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
51+
4752
ET_KERNEL_CHECK(
4853
ctx, check_view_copy_args(self, size_int64_t, out), InvalidArgument, out);
4954

kernels/portable/cpu/op_where.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ Tensor& where_out(
3535
InvalidArgument,
3636
out);
3737

38+
ET_KERNEL_CHECK(
39+
ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
40+
3841
constexpr auto name = "where.self_out";
3942

4043
ET_CHECK_MSG(

0 commit comments

Comments
 (0)