Skip to content

Commit 7b27714

Browse files
Songhao Jia authored and facebook-github-bot committed
m to p
Differential Revision: D59984020
1 parent 53436a1 commit 7b27714

16 files changed

+157
-0
lines changed

kernels/portable/cpu/op_masked_fill.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ Tensor& masked_fill_scalar_out(
3939
InvalidArgument,
4040
out);
4141

42+
ET_KERNEL_CHECK(
43+
ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out);
44+
4245
ET_SWITCH_REAL_TYPES_AND(
4346
Bool, in_type, ctx, "masked_fill.Scalar_out", CTYPE, [&]() {
4447
ET_SWITCH_REAL_TYPES_AND(

kernels/portable/cpu/op_max.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ std::tuple<Tensor&, Tensor&> max_out(
4949
InvalidArgument,
5050
(std::tuple<Tensor&, Tensor&>({max, max_indices})));
5151

52+
ET_KERNEL_CHECK(
53+
ctx,
54+
tensors_have_same_dim_order(in, max),
55+
InvalidArgument,
56+
(std::tuple<Tensor&, Tensor&>({max, max_indices})));
57+
58+
ET_KERNEL_CHECK(
59+
ctx,
60+
tensor_is_default_dim_order(max_indices),
61+
InvalidArgument,
62+
(std::tuple<Tensor&, Tensor&>({max, max_indices})));
63+
64+
ET_KERNEL_CHECK(
65+
ctx,
66+
tensor_is_default_dim_order(in),
67+
InvalidArgument,
68+
(std::tuple<Tensor&, Tensor&>({max, max_indices})));
69+
5270
dim = dim < 0 ? dim + in.dim() : dim;
5371

5472
ET_SWITCH_REAL_TYPES_AND(

kernels/portable/cpu/op_maximum.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ Tensor& maximum_out(
7575
InvalidArgument,
7676
out);
7777

78+
ET_KERNEL_CHECK(
79+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
80+
7881
ScalarType a_type = a.scalar_type();
7982
ScalarType b_type = b.scalar_type();
8083
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);

kernels/portable/cpu/op_mean.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ Tensor& mean_dim_out(
3333
InvalidArgument,
3434
out);
3535

36+
ET_KERNEL_CHECK(
37+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
38+
39+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
40+
3641
ET_KERNEL_CHECK(
3742
ctx,
3843
resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok,

kernels/portable/cpu/op_min.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ std::tuple<Tensor&, Tensor&> min_out(
4949
InvalidArgument,
5050
(std::tuple<Tensor&, Tensor&>({min, min_indices})));
5151

52+
ET_KERNEL_CHECK(
53+
ctx,
54+
tensors_have_same_dim_order(in, min),
55+
InvalidArgument,
56+
(std::tuple<Tensor&, Tensor&>({min, min_indices})));
57+
58+
ET_KERNEL_CHECK(
59+
ctx,
60+
tensor_is_default_dim_order(min_indices),
61+
InvalidArgument,
62+
(std::tuple<Tensor&, Tensor&>({min, min_indices})));
63+
64+
ET_KERNEL_CHECK(
65+
ctx,
66+
tensor_is_default_dim_order(in),
67+
InvalidArgument,
68+
(std::tuple<Tensor&, Tensor&>({min, min_indices})));
69+
5270
dim = dim < 0 ? dim + in.dim() : dim;
5371

5472
ET_SWITCH_REAL_TYPES_AND(

kernels/portable/cpu/op_minimum.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ Tensor& minimum_out(
7575
InvalidArgument,
7676
out);
7777

78+
ET_KERNEL_CHECK(
79+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
80+
7881
ScalarType a_type = a.scalar_type();
7982
ScalarType b_type = b.scalar_type();
8083
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);

kernels/portable/cpu/op_mm.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) {
2929
InvalidArgument,
3030
out);
3131

32+
ET_KERNEL_CHECK(
33+
ctx, tensors_have_same_dim_order(in, mat2, out), InvalidArgument, out);
34+
35+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
36+
3237
ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
3338
size_t m = in.size(0);
3439
size_t n = in.size(1);

kernels/portable/cpu/op_mul.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
7272

7373
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
7474

75+
ET_KERNEL_CHECK(
76+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
77+
7578
ScalarType a_type = a.scalar_type();
7679
ScalarType b_type = b.scalar_type();
7780
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
@@ -113,6 +116,9 @@ Tensor& mul_scalar_out(
113116
out,
114117
"Failed to resize output tensor.");
115118

119+
ET_KERNEL_CHECK(
120+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
121+
116122
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
117123

118124
ScalarType a_type = a.scalar_type();

kernels/portable/cpu/op_native_batch_norm.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,28 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
7373
InvalidArgument,
7474
ret_val);
7575

76+
ET_KERNEL_CHECK(
77+
ctx,
78+
tensors_have_same_dim_order(in, out, mean_out, invstd_out),
79+
InvalidArgument,
80+
ret_val);
81+
82+
if (weight.has_value()) {
83+
ET_KERNEL_CHECK(
84+
ctx,
85+
tensors_have_same_dim_order(in, weight.value()),
86+
InvalidArgument,
87+
ret_val);
88+
}
89+
90+
if (bias.has_value()) {
91+
ET_KERNEL_CHECK(
92+
ctx,
93+
tensors_have_same_dim_order(in, bias.value()),
94+
InvalidArgument,
95+
ret_val);
96+
}
97+
7698
size_t C_dim = in.dim() >= 1 ? 1 : 0;
7799
size_t C = in.size(C_dim);
78100
size_t outer = getLeadingDims(in, C_dim);

kernels/portable/cpu/op_native_group_norm.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,31 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_group_norm_out(
158158
InvalidArgument,
159159
ret_val);
160160

161+
ET_KERNEL_CHECK(
162+
ctx, tensor_is_default_dim_order(input), InvalidArgument, ret_val);
163+
164+
ET_KERNEL_CHECK(
165+
ctx,
166+
tensors_have_same_dim_order(input, out, mean_out, rstd_out),
167+
InvalidArgument,
168+
ret_val);
169+
170+
if (weight.has_value()) {
171+
ET_KERNEL_CHECK(
172+
ctx,
173+
tensors_have_same_dim_order(input, weight.value()),
174+
InvalidArgument,
175+
ret_val);
176+
}
177+
178+
if (bias.has_value()) {
179+
ET_KERNEL_CHECK(
180+
ctx,
181+
tensors_have_same_dim_order(input, bias.value()),
182+
InvalidArgument,
183+
ret_val);
184+
}
185+
161186
constexpr auto name = "native_group_norm.out";
162187

163188
ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {

kernels/portable/cpu/op_native_layer_norm.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,33 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
117117
InvalidArgument,
118118
ret_val);
119119

120+
// Only support default dim order for now.
121+
// TODO: Support other dim orders.
122+
ET_KERNEL_CHECK(
123+
ctx, tensor_is_default_dim_order(input), InvalidArgument, ret_val);
124+
125+
ET_KERNEL_CHECK(
126+
ctx,
127+
tensors_have_same_dim_order(input, out, mean_out, rstd_out),
128+
InvalidArgument,
129+
ret_val);
130+
131+
if (weight.has_value()) {
132+
ET_KERNEL_CHECK(
133+
ctx,
134+
tensors_have_same_dim_order(input, weight.value()),
135+
InvalidArgument,
136+
ret_val);
137+
}
138+
139+
if (bias.has_value()) {
140+
ET_KERNEL_CHECK(
141+
ctx,
142+
tensors_have_same_dim_order(input, bias.value()),
143+
InvalidArgument,
144+
ret_val);
145+
}
146+
120147
Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
121148
size_t mean_rstd_ndim = 0;
122149
get_layer_norm_out_target_size(

kernels/portable/cpu/op_ne.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ Tensor& ne_tensor_out(
3030
InvalidArgument,
3131
out);
3232

33+
ET_KERNEL_CHECK(
34+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
35+
3336
ScalarType a_type = a.scalar_type();
3437
ScalarType b_type = b.scalar_type();
3538
ScalarType out_type = out.scalar_type();
@@ -75,6 +78,9 @@ Tensor& ne_scalar_out(
7578
out,
7679
"Failed to resize output tensor.");
7780

81+
ET_KERNEL_CHECK(
82+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
83+
7884
ScalarType a_type = a.scalar_type();
7985
ScalarType b_type = utils::get_scalar_dtype(b);
8086
ScalarType out_type = out.scalar_type();

kernels/portable/cpu/op_neg.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ Tensor& neg_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
3030
ET_KERNEL_CHECK(
3131
ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
3232

33+
ET_KERNEL_CHECK(
34+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
35+
3336
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
3437
apply_unary_map_fn(
3538
[](const CTYPE val_in) { return static_cast<CTYPE>(-val_in); },

kernels/portable/cpu/op_pdist_forward.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ Tensor& _pdist_forward_out(
2424

2525
ET_KERNEL_CHECK(ctx, check_pdist_args(in, p, out), InvalidArgument, out);
2626

27+
ET_KERNEL_CHECK(
28+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
29+
30+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
31+
2732
Tensor::SizesType target_sizes[kTensorDimensionLimit];
2833
size_t target_ndim = 0;
2934
get_pdist_out_target_size(in, target_sizes, &target_ndim);

kernels/portable/cpu/op_permute_copy.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ Tensor& permute_copy_out(
4646
ET_KERNEL_CHECK(
4747
ctx, check_permute_copy_args(in, dims, out), InvalidArgument, out);
4848

49+
ET_KERNEL_CHECK(
50+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
51+
4952
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
5053
size_t expected_out_dim = 0;
5154
get_permute_copy_out_target_size(

kernels/portable/cpu/op_pixel_shuffle.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ Tensor& pixel_shuffle_out(
2929
InvalidArgument,
3030
out);
3131

32+
ET_KERNEL_CHECK(
33+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
34+
35+
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
36+
3237
const Tensor::SizesType leading_dims = getLeadingDims(in, in.dim() - 3);
3338
const Tensor::SizesType channels = in.size(in.dim() - 3);
3439
const Tensor::SizesType height = in.size(in.dim() - 2);

0 commit comments

Comments (0)