Skip to content

Commit 5df4cef

Browse files
Songhao Jia authored and facebook-github-bot committed
d to g start ops | add dim order sanity check
Differential Revision: D59846689
1 parent 25937af commit 5df4cef

15 files changed

+77
-0
lines changed

kernels/portable/cpu/op_detach_copy.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ Tensor& detach_copy_out(RuntimeContext& ctx, const Tensor& self, Tensor& out) {
3333
out,
3434
"Failed to resize output tensor.");
3535

36+
ET_KERNEL_CHECK(
37+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
38+
3639
ET_KERNEL_CHECK(
3740
ctx, tensors_have_same_shape_and_dtype(self, out), InvalidArgument, out);
3841

kernels/portable/cpu/op_diagonal_copy.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ Tensor& diagonal_copy_out(
7373
ET_KERNEL_CHECK(
7474
ctx, check_diagonal_copy_args(in, dim1, dim2, out), InvalidArgument, out);
7575

76+
ET_KERNEL_CHECK(
77+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
78+
79+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
80+
7681
if (dim1 < 0) {
7782
dim1 += nonzero_dim(in);
7883
}

kernels/portable/cpu/op_div.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
4141
InvalidArgument,
4242
out);
4343

44+
ET_KERNEL_CHECK(
45+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
46+
4447
ScalarType a_type = a.scalar_type();
4548
ScalarType b_type = b.scalar_type();
4649

@@ -97,6 +100,9 @@ Tensor& div_out_mode(
97100
InvalidArgument,
98101
out);
99102

103+
ET_KERNEL_CHECK(
104+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
105+
100106
ScalarType a_type = a.scalar_type();
101107
ScalarType b_type = b.scalar_type();
102108
ScalarType common_type = get_compute_type(a_type, b_type);
@@ -159,6 +165,9 @@ Tensor& div_scalar_out(
159165
ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
160166
ScalarType out_type = out.scalar_type();
161167

168+
ET_KERNEL_CHECK(
169+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
170+
162171
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
163172

164173
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {

kernels/portable/cpu/op_embedding.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ Tensor& embedding_out(
102102
out.size(1),
103103
weight.size(1));
104104

105+
ET_KERNEL_CHECK(
106+
ctx,
107+
tensors_have_same_dim_order(weight, indices, out),
108+
InvalidArgument,
109+
out);
110+
111+
ET_KERNEL_CHECK(
112+
ctx, tensor_is_default_dim_order(weight), InvalidArgument, out);
113+
105114
ScalarType ix_type = indices.scalar_type();
106115
ET_CHECK_MSG(
107116
ix_type == ScalarType::Long || ix_type == ScalarType::Int,

kernels/portable/cpu/op_eq.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ Tensor& eq_tensor_out(
3434
ScalarType b_type = b.scalar_type();
3535
ScalarType out_type = out.scalar_type();
3636

37+
ET_KERNEL_CHECK(
38+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39+
3740
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
3841
ET_SWITCH_REAL_TYPES_AND(
3942
Bool, b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
@@ -80,6 +83,9 @@ Tensor& eq_scalar_out(
8083
ScalarType b_type = utils::get_scalar_dtype(b);
8184
ScalarType out_type = out.scalar_type();
8285

86+
ET_KERNEL_CHECK(
87+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
88+
8389
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
8490
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
8591
using CTYPE_IN =

kernels/portable/cpu/op_expand_copy.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ Tensor& expand_copy_out(
8585
InvalidArgument,
8686
out);
8787

88+
ET_KERNEL_CHECK(
89+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
90+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
91+
8892
// Holds the result of expand_sizes converted to repeat sizes
8993
int64_t repeats[kTensorDimensionLimit];
9094
const auto repeats_size{map_expand_to_repeats(

kernels/portable/cpu/op_fill.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ Tensor& fill_scalar_out(
3131

3232
ET_KERNEL_CHECK(ctx, a_type == out_type, InvalidArgument, out);
3333

34+
ET_KERNEL_CHECK(
35+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
36+
3437
// Resize for dynamic shape
3538
ET_KERNEL_CHECK_MSG(
3639
ctx,
@@ -67,6 +70,9 @@ Tensor& fill_tensor_out(
6770
// Assert `b` must be a scalar tensor.
6871
ET_KERNEL_CHECK(ctx, tensor_is_scalar(b), InvalidArgument, out);
6972

73+
ET_KERNEL_CHECK(
74+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
75+
7076
ScalarType a_type = a.scalar_type();
7177
ScalarType b_type = b.scalar_type();
7278
ScalarType out_type = out.scalar_type();

kernels/portable/cpu/op_flip.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ flip_out(RuntimeContext& ctx, const Tensor& in, IntArrayRef dims, Tensor& out) {
4545
ET_KERNEL_CHECK(
4646
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
4747

48+
ET_KERNEL_CHECK(
49+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
50+
4851
ET_KERNEL_CHECK(ctx, check_flip_args(in, dims, out), InvalidArgument, out);
4952

5053
bool flip_dim_data[kTensorDimensionLimit];

kernels/portable/cpu/op_floor_divide.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ Tensor& floor_divide_out(
8787

8888
ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
8989

90+
ET_KERNEL_CHECK(
91+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
92+
9093
ScalarType a_type = a.scalar_type();
9194
ScalarType b_type = b.scalar_type();
9295
ScalarType common_type = promoteTypes(a_type, b_type);

kernels/portable/cpu/op_fmod.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ Tensor& fmod_Tensor_out(
8585
InvalidArgument,
8686
out);
8787

88+
ET_KERNEL_CHECK(
89+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
90+
8891
ScalarType a_type = a.scalar_type();
8992
ScalarType b_type = b.scalar_type();
9093
ScalarType common_type = promoteTypes(a_type, b_type);
@@ -139,6 +142,9 @@ Tensor& fmod_Scalar_out(
139142
out,
140143
"Failed to resize output tensor.");
141144

145+
ET_KERNEL_CHECK(
146+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
147+
142148
ScalarType a_type = a.scalar_type();
143149
ScalarType b_type = utils::get_scalar_dtype(b);
144150
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);

kernels/portable/cpu/op_full_like.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ Tensor& full_like_out(
3434
"memory_format must be contiguous");
3535
}
3636

37+
ET_KERNEL_CHECK(
38+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
39+
40+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
41+
3742
// Resize for dynamic shape
3843
ET_KERNEL_CHECK_MSG(
3944
ctx,

kernels/portable/cpu/op_ge.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ Tensor& ge_tensor_out(
3131
InvalidArgument,
3232
out);
3333

34+
ET_KERNEL_CHECK(
35+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
36+
3437
ScalarType a_type = a.scalar_type();
3538
ScalarType b_type = b.scalar_type();
3639
ScalarType out_type = out.scalar_type();
@@ -77,6 +80,9 @@ Tensor& ge_scalar_out(
7780
out,
7881
"Failed to resize output tensor.");
7982

83+
ET_KERNEL_CHECK(
84+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
85+
8086
ScalarType a_type = a.scalar_type();
8187
ScalarType b_type = utils::get_scalar_dtype(b);
8288
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);

kernels/portable/cpu/op_gelu.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ Tensor& gelu_out(
3434
ET_KERNEL_CHECK(
3535
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
3636

37+
ET_KERNEL_CHECK(
38+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
39+
3740
ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
3841
if (approximate == "tanh") {
3942
apply_unary_map_fn(

kernels/portable/cpu/op_glu.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ glu_out(RuntimeContext& ctx, const Tensor& self, int64_t dim, Tensor& out) {
144144
ET_KERNEL_CHECK(
145145
ctx, resize_glu_out(self, dim, out) == Error::Ok, InvalidArgument, out);
146146

147+
ET_KERNEL_CHECK(
148+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
149+
147150
ET_KERNEL_CHECK(ctx, check_glu_args(self, dim, out), InvalidArgument, out);
148151

149152
const size_t non_negative_dim = dim < 0 ? dim + self.dim() : dim;

kernels/portable/cpu/op_gt.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ Tensor& gt_tensor_out(
3131
InvalidArgument,
3232
out);
3333

34+
ET_KERNEL_CHECK(
35+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
36+
3437
ScalarType a_type = a.scalar_type();
3538
ScalarType b_type = b.scalar_type();
3639
ScalarType out_type = out.scalar_type();
@@ -77,6 +80,9 @@ Tensor& gt_scalar_out(
7780
out,
7881
"Failed to resize output tensor.");
7982

83+
ET_KERNEL_CHECK(
84+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
85+
8086
ScalarType a_type = a.scalar_type();
8187
ScalarType b_type = utils::get_scalar_dtype(b);
8288
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);

0 commit comments

Comments (0)