Skip to content

Commit 5f2c55e

Browse files
Songhao Jia and facebook-github-bot
authored and committed
introduce dim order tests to op test (#2637)
Summary: Pull Request resolved: #2637 This diff introduces dim order sanity check utils, as well as dim-order related test to operator tests, to help our system maintain its correctness when introducing new dim order ([0, 2, 3, 1]) which we never support before. The goal is checking whether or not every operator support its input's memory format, and using related tests for regular tests. The high levels of sanity check and test will be: 1. the dim order of input and output should be same. 2. the dim order of all input tensors should be same, unless operaotr-specific requirement for some input (e.g. some operator may request some input have to be contiguous, although I haven't found the actual example yet.) 3. make the operator support as much dim order as possible (e,g, if a operator can support both contiguous and channels last, then the sanity check has to make the both input valid.) I also updated `op_abs` in this diff to demonstrate how the sanity check as well as tests will be inserted. Differential Revision: https://internalfb.com/D55227304
1 parent 6fce77f commit 5f2c55e

File tree

10 files changed

+441
-64
lines changed

10 files changed

+441
-64
lines changed

kernels/portable/cpu/op_abs.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ Tensor& abs_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
2828
"Failed to resize output tensor.");
2929

3030
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
31+
ET_KERNEL_CHECK(
32+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3133

3234
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
3335
apply_unary_map_fn(

kernels/test/TestUtil.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,22 @@
3030
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \
3131
EXPECT_ANY_THROW(_statement)
3232

33+
#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \
34+
tf, op, input_contiguous, expected_contiguous, channels_last_support) \
35+
Tensor input_channels_last = tf.channels_last_like(input_contiguous); \
36+
Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \
37+
\
38+
Tensor output_contiguous = tf.zeros_like(expected_contiguous); \
39+
Tensor output_channels_last = tf.channels_last_like(output_contiguous); \
40+
\
41+
Tensor ret = op(input_channels_last, output_channels_last); \
42+
if (channels_last_support) { \
43+
EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \
44+
} else { \
45+
EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \
46+
} \
47+
EXPECT_TENSOR_EQ(output_channels_last, ret);
48+
3349
#else
3450

3551
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
@@ -52,6 +68,26 @@
5268
} \
5369
} while (false)
5470

71+
#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \
72+
tf, op, input_contiguous, expected_contiguous, channels_last_support) \
73+
Tensor input_channels_last = tf.channels_last_like(input_contiguous); \
74+
Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \
75+
\
76+
Tensor output_contiguous = tf.zeros_like(expected_contiguous); \
77+
Tensor output_channels_last = tf.channels_last_like(output_contiguous); \
78+
\
79+
Tensor ret = op(input_channels_last, output_channels_last); \
80+
if (channels_last_support) { \
81+
EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \
82+
} else { \
83+
EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \
84+
} \
85+
EXPECT_TENSOR_EQ(output_channels_last, ret); \
86+
ET_EXPECT_KERNEL_FAILURE( \
87+
context_, op(input_channels_last, output_contiguous)); \
88+
ET_EXPECT_KERNEL_FAILURE( \
89+
context_, op(input_contiguous, output_channels_last));
90+
5591
#endif // USE_ATEN_LIB
5692

5793
/*

kernels/test/op_abs_test.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,28 @@ TEST_F(OpAbsTest, SanityCheck) {
3838
EXPECT_TENSOR_EQ(out, ret);
3939
EXPECT_TENSOR_EQ(out, expected);
4040
}
41+
42+
TEST_F(OpAbsTest, MemoryFormatCheck) {
43+
TensorFactory<ScalarType::Float> tf;
44+
45+
std::vector<int32_t> sizes = {2, 3, 1, 5};
46+
47+
Tensor input_contiguous =
48+
tf.make(sizes, {0.8737, 0.5359, 0.3743, -0.3040, -0.7800, -0.2306,
49+
-0.7684, -0.5364, 0.3478, -0.3289, 0.0829, 0.2939,
50+
-0.8211, 0.8572, -0.0802, 0.9252, -0.2093, 0.9013,
51+
-0.4197, 0.3987, -0.5291, -0.5567, 0.2691, 0.7819,
52+
-0.8009, -0.4286, -0.9299, 0.2143, 0.2565, -0.5701});
53+
Tensor expected_contiguous = tf.make(
54+
sizes, {0.8737, 0.5359, 0.3743, 0.3040, 0.7800, 0.2306, 0.7684, 0.5364,
55+
0.3478, 0.3289, 0.0829, 0.2939, 0.8211, 0.8572, 0.0802, 0.9252,
56+
0.2093, 0.9013, 0.4197, 0.3987, 0.5291, 0.5567, 0.2691, 0.7819,
57+
0.8009, 0.4286, 0.9299, 0.2143, 0.2565, 0.5701});
58+
59+
ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(
60+
tf,
61+
op_abs_out,
62+
input_contiguous,
63+
expected_contiguous,
64+
/*channels_last_support=*/true);
65+
}

runtime/core/exec_aten/testing_util/tensor_factory.h

Lines changed: 144 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
#pragma once
44

55
#include <algorithm>
6+
#include <cstdint>
67

78
#include <executorch/runtime/core/exec_aten/exec_aten.h>
9+
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
810
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
911
#include <executorch/runtime/core/tensor_shape_dynamism.h>
1012
#include <executorch/runtime/platform/assert.h>
@@ -54,7 +56,7 @@ inline size_t sizes_to_numel(const std::vector<int32_t>& sizes) {
5456

5557
inline bool check_strides(
5658
const std::vector<int32_t> sizes,
57-
const std::vector<int32_t> strides) {
59+
const std::vector<exec_aten::StridesType> strides) {
5860
if (sizes.size() != strides.size()) {
5961
// The length of stride vector shall equal to size vector.
6062
return false;
@@ -147,14 +149,14 @@ inline bool check_dim_order(
147149
return true;
148150
}
149151

150-
inline std::vector<int32_t> strides_from_dim_order(
152+
inline std::vector<exec_aten::StridesType> strides_from_dim_order(
151153
const std::vector<int32_t>& sizes,
152154
const std::vector<uint8_t>& dim_order) {
153155
bool legal = check_dim_order(sizes, dim_order);
154156
ET_CHECK_MSG(legal, "The input dim_order variable is illegal.");
155157

156158
size_t ndim = sizes.size();
157-
std::vector<int32_t> strides(ndim);
159+
std::vector<exec_aten::StridesType> strides(ndim);
158160
strides[dim_order[ndim - 1]] = 1;
159161
for (int i = ndim - 2; i >= 0; --i) {
160162
uint8_t cur_dim = dim_order[i];
@@ -258,7 +260,7 @@ class TensorFactory {
258260
at::Tensor make(
259261
const std::vector<int32_t>& sizes,
260262
const std::vector<ctype>& data,
261-
const std::vector<int32_t> strides = {},
263+
const std::vector<exec_aten::StridesType> strides = {},
262264
__ET_UNUSED TensorShapeDynamism dynamism =
263265
TensorShapeDynamism::DYNAMIC_UNBOUND) {
264266
auto expected_numel = internal::sizes_to_numel(sizes);
@@ -344,6 +346,71 @@ class TensorFactory {
344346
sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism);
345347
}
346348

349+
/**
350+
* Given data in contiguous memory format, returns a new Tensor with the
351+
* specified shape and the same data but in channels last memory format.
352+
*
353+
* @param[in] sizes The sizes of the dimensions of the Tensor.
354+
* @param[in] data The data in contiguous memory format that the Tensor should
355+
* be initialized with. The size of this vector must be equal to the product
356+
* of the elements of `sizes`.
357+
*
358+
* @return A new Tensor with the specified shape and data in channels last
359+
* memory format.
360+
*/
361+
at::Tensor channels_last_like(
362+
const Tensor& input,
363+
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
364+
ET_CHECK_MSG(input.sizes().size() == 4, "Only 4D tensors can be channels last");
365+
366+
const std::vector<int32_t> sizes(
367+
input.sizes().begin(), input.sizes().end());
368+
369+
std::vector<uint8_t> contiguous_dim_order(sizes.size());
370+
for (uint8_t i = 0; i < sizes.size(); i++) {
371+
contiguous_dim_order[i] = i;
372+
}
373+
std::vector<exec_aten::StridesType> contiguous_strides =
374+
internal::strides_from_dim_order(sizes, contiguous_dim_order);
375+
376+
for (int32_t i = 0; i < input.dim(); i++) {
377+
ET_CHECK_MSG(
378+
input.strides()[i] == contiguous_strides[i],
379+
"Input tensor is not contiguous");
380+
}
381+
382+
int32_t N = sizes[0];
383+
int32_t C = sizes[1];
384+
int32_t H = sizes[2];
385+
int32_t W = sizes[3];
386+
387+
std::vector<ctype> contiguous_data(
388+
input.data_ptr<ctype>(), input.data_ptr<ctype>() + input.numel());
389+
std::vector<ctype> channels_last_data(
390+
N * C * H * W); // Create a new blob with the same total size to contain
391+
// channels_last data
392+
for (int32_t n = 0; n < N; ++n) {
393+
for (int32_t c = 0; c < C; ++c) {
394+
for (int32_t h = 0; h < H; ++h) {
395+
for (int32_t w = 0; w < W; ++w) {
396+
// Calculate the index in the original blob
397+
int32_t old_index = ((n * C + c) * H + h) * W + w;
398+
// Calculate the index in the new blob
399+
int32_t new_index = ((n * H + h) * W + w) * C + c;
400+
// Copy the data
401+
channels_last_data[new_index] = contiguous_data[old_index];
402+
}
403+
}
404+
}
405+
}
406+
407+
return make_with_dimorder(
408+
sizes,
409+
channels_last_data,
410+
internal::channels_last_dim_order(sizes.size()),
411+
dynamism);
412+
}
413+
347414
/**
348415
* Returns a new Tensor with the specified shape, containing contiguous
349416
* data will all elements set to `value`.
@@ -459,14 +526,13 @@ class TensorFactory {
459526
*/
460527
at::Tensor empty_strided(
461528
const std::vector<int32_t>& sizes,
462-
const std::vector<int32_t>& strides,
529+
const std::vector<exec_aten::StridesType>& strides,
463530
__ET_UNUSED TensorShapeDynamism dynamism =
464531
TensorShapeDynamism::DYNAMIC_UNBOUND) {
465532
auto sizes64 = vec_32_to_64(sizes);
466-
auto strides64 = vec_32_to_64(strides);
467533
return at::empty_strided(
468534
sizes64,
469-
strides64,
535+
strides,
470536
DTYPE,
471537
/*layout_opt=*/at::Layout::Strided,
472538
/*device_opt=*/at::Device(at::DeviceType::CPU),
@@ -665,7 +731,7 @@ class TensorFactory {
665731
torch::executor::Tensor make(
666732
const std::vector<int32_t>& sizes,
667733
const std::vector<ctype>& data,
668-
const std::vector<int32_t> strides = {},
734+
const std::vector<exec_aten::StridesType> strides = {},
669735
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
670736
std::vector<int32_t> default_strides;
671737
// Generate strides from the tensor dimensions, assuming contiguous data if
@@ -745,7 +811,7 @@ class TensorFactory {
745811

746812
/**
747813
* Returns a new Tensor with the specified shape and data in channels last
748-
* memory layout.
814+
* memory format.
749815
*
750816
* @param[in] sizes The sizes of the dimensions of the Tensor.
751817
* @param[in] data The data that the Tensor should be initialized with. The
@@ -763,6 +829,60 @@ class TensorFactory {
763829
sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism);
764830
}
765831

832+
/**
833+
* Given data in contiguous memory format, returns a new Tensor with the
834+
* specified shape and the same data but in channels last memory format.
835+
*
836+
* @param[in] sizes The sizes of the dimensions of the Tensor.
837+
* @param[in] data The data in contiguous memory format that the Tensor should
838+
* be initialized with. The size of this vector must be equal to the product
839+
* of the elements of `sizes`.
840+
*
841+
* @return A new Tensor with the specified shape and data in channls last
842+
* memory format.
843+
*/
844+
torch::executor::Tensor channels_last_like(
845+
const Tensor& input,
846+
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
847+
const std::vector<int32_t> sizes(
848+
input.sizes().begin(), input.sizes().end());
849+
850+
ET_CHECK_MSG(sizes.size() == 4, "Only 4D tensors can be channels last");
851+
ET_CHECK_MSG(
852+
is_contiguous_dim_order(input.dim_order().data(), input.dim()) == true,
853+
"Input tensor is not contiguous");
854+
int32_t N = sizes[0];
855+
int32_t C = sizes[1];
856+
int32_t H = sizes[2];
857+
int32_t W = sizes[3];
858+
859+
std::vector<ctype> contiguous_data(
860+
input.data_ptr<ctype>(), input.data_ptr<ctype>() + input.numel());
861+
std::vector<ctype> channels_last_data(
862+
N * C * H * W); // Create a new blob with the same total size to contain
863+
// channels_last data
864+
for (int32_t n = 0; n < N; ++n) {
865+
for (int32_t c = 0; c < C; ++c) {
866+
for (int32_t h = 0; h < H; ++h) {
867+
for (int32_t w = 0; w < W; ++w) {
868+
// Calculate the index in the original blob
869+
int32_t old_index = ((n * C + c) * H + h) * W + w;
870+
// Calculate the index in the new blob
871+
int32_t new_index = ((n * H + h) * W + w) * C + c;
872+
// Copy the data
873+
channels_last_data[new_index] = contiguous_data[old_index];
874+
}
875+
}
876+
}
877+
}
878+
879+
return make_with_dimorder(
880+
sizes,
881+
channels_last_data,
882+
internal::channels_last_dim_order(sizes.size()),
883+
dynamism);
884+
}
885+
766886
/**
767887
* Returns a new Tensor with the specified shape, containing contiguous data
768888
* will all elements set to `value`.
@@ -798,7 +918,20 @@ class TensorFactory {
798918

799919
/**
800920
* Returns a new Tensor with the specified shape, containing contiguous data
801-
* with all `0` elements.
921+
* in channels last memory format with all `0` elements.
922+
*
923+
* @param[in] sizes The sizes of the dimensions of the Tensor.
924+
* @return A new Tensor with the specified shape.
925+
*/
926+
torch::executor::Tensor zeros_channels_last(
927+
const std::vector<int32_t>& sizes,
928+
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
929+
return full_channels_last(sizes, 0, dynamism);
930+
}
931+
932+
/**
933+
* Returns a new Tensor with the specified shape, containing contiguous data
934+
* in contiguous memory format with all `0` elements.
802935
*
803936
* @param[in] sizes The sizes of the dimensions of the Tensor.
804937
* @return A new Tensor with the specified shape.
@@ -877,7 +1010,7 @@ class TensorFactory {
8771010
std::vector<int32_t> sizes_;
8781011
std::vector<ctype> data_;
8791012
std::vector<uint8_t> dim_order_;
880-
std::vector<int32_t> strides_;
1013+
std::vector<exec_aten::StridesType> strides_;
8811014
TensorImpl impl_;
8821015
};
8831016

runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ TEST_F(TensorFactoryTest, MakeStridedDataIsCopied) {
448448

449449
// Create two tensors using the same input data and strided vector.
450450
std::vector<int32_t> data = {1, 2, 3, 4};
451-
std::vector<int32_t> strides = {1, 2};
451+
std::vector<exec_aten::StridesType> strides = {1, 2};
452452
Tensor t1 = tf.make(/*sizes=*/{2, 2}, data, strides);
453453
Tensor t2 = tf.make(/*sizes=*/{2, 2}, data, strides);
454454

0 commit comments

Comments
 (0)