Skip to content

Commit b2b0b6a

Browse files
[ET][Portable] Fix op permute_copy
Pull Request resolved: #726 Fix the check for duplicates in dimension list ghstack-source-id: 203418805 @exported-using-ghexport Differential Revision: [D50089231](https://our.internmc.facebook.com/intern/diff/D50089231/)
1 parent 95c7fbf commit b2b0b6a

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,20 +152,25 @@ bool check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {
152152
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
153153

154154
// Make sure no dimensions are duplicated and all in the range [-in.dim(),
155-
// in.dim() - 1]. Use gaussian sum to check this.
156-
size_t expected_sum = (dims.size() * (dims.size() + 1)) / 2;
157-
size_t gauss_sum = 0;
155+
// in.dim() - 1].
156+
bool dim_exist[kTensorDimensionLimit];
157+
memset(dim_exist, false, sizeof(dim_exist));
158+
158159
for (int i = 0; i < dims.size(); i++) {
159-
// Convert dimension to a non-negative number. dim_base is in the range
160+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dims[i]));
161+
// Convert dimension to a non-negative number in the range
160162
// [0 .. in.dim() - 1].
161-
size_t dim = dims[i] > -1 ? dims[i] : in.dim() + dims[i];
162-
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
163-
gauss_sum += dim + 1;
164-
}
163+
size_t dim = dims[i] >= 0 ? dims[i] : in.dim() + dims[i];
165164

166-
ET_LOG_MSG_AND_RETURN_IF_FALSE(
167-
gauss_sum == expected_sum,
168-
"The dims passed to permute_copy must contain one of each dim!");
165+
// Internal check, since we have already validated this
166+
ET_CHECK(dim < kTensorDimensionLimit && dim >= 0);
167+
168+
// Check that the dimension hasn't been seen previously.
169+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
170+
dim_exist[dim] == false, "duplicate dims are not allowed.");
171+
172+
dim_exist[dim] = true;
173+
}
169174

170175
return true;
171176
}

kernels/test/op_permute_copy_test.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,20 @@ TEST(OpPermuteCopyKernelTest, DupeDimensionPos) {
327327
t_int, ArrayRef<int64_t>(new_dim.data(), new_dim.size()), out));
328328
}
329329

330+
TEST(OpPermuteCopyKernelTest, DupeDimensionPos2) {
331+
TensorFactory<ScalarType::Int> tf;
332+
333+
const std::vector<int64_t> new_dim = {1, 1, 1};
334+
335+
const std::vector<int32_t> sizes = {1, 1, 1};
336+
Tensor t_int = tf.make(sizes, {1});
337+
338+
Tensor out = tf.zeros(sizes);
339+
340+
ET_EXPECT_KERNEL_FAILURE(op_permute_copy_out(
341+
t_int, ArrayRef<int64_t>(new_dim.data(), new_dim.size()), out));
342+
}
343+
330344
TEST(OpPermuteCopyKernelTest, DupeDimensionNeg) {
331345
TensorFactory<ScalarType::Int> tf;
332346

@@ -341,6 +355,20 @@ TEST(OpPermuteCopyKernelTest, DupeDimensionNeg) {
341355
t_int, ArrayRef<int64_t>(new_dim.data(), new_dim.size()), out));
342356
}
343357

358+
TEST(OpPermuteCopyKernelTest, DupeDimensionNeg2) {
359+
TensorFactory<ScalarType::Int> tf;
360+
361+
const std::vector<int64_t> new_dim = {0, 1, -5};
362+
363+
const std::vector<int32_t> sizes = {1, 1, 1};
364+
Tensor t_int = tf.make(sizes, {1});
365+
366+
Tensor out = tf.zeros(sizes);
367+
368+
ET_EXPECT_KERNEL_FAILURE(op_permute_copy_out(
369+
t_int, ArrayRef<int64_t>(new_dim.data(), new_dim.size()), out));
370+
}
371+
344372
TEST(OpPermuteCopyKernelTest, MismatchDim) {
345373
TensorFactory<ScalarType::Int> tf;
346374

0 commit comments

Comments
 (0)