
Commit a02217e

Gasoonjia authored and facebook-github-bot committed
rename default dim order as contiguous dim order (#2157)
Summary:
bypass-github-export-checks

Reviewed By: digantdesai

Differential Revision: D54285070
1 parent 9763bfc commit a02217e
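Background for the rename: an ExecuTorch dim order lists a tensor's dimensions from outermost to innermost in memory, so the order {0, 1, 2, ...} is exactly row-major layout (what PyTorch calls torch.contiguous_format); "contiguous" names that property directly, where "default" did not. A minimal standalone sketch of how a dim order maps to strides (illustrative only, not part of this commit; strides_for_dim_order is a hypothetical helper):

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical illustration: the strides implied by a dim order. For the
// contiguous dim order {0, 1, 2, ...} this reduces to plain row-major
// strides, which is what the renamed is_contiguous_dim_order() checks for.
std::vector<size_t> strides_for_dim_order(
    const std::vector<int32_t>& sizes,
    const std::vector<uint8_t>& dim_order) {
  std::vector<size_t> strides(sizes.size());
  size_t running = 1;
  // Walk dims from innermost (last in dim_order) to outermost.
  for (size_t i = dim_order.size(); i > 0; --i) {
    const uint8_t d = dim_order[i - 1];
    strides[d] = running;
    running *= sizes[d];
  }
  return strides;
}
// sizes {2, 3, 4}, dim order {0, 1, 2}   -> strides {12, 4, 1} (row-major)
// sizes {1, 3, 4, 4}, order {0, 2, 3, 1} -> strides {48, 1, 12, 3} (channels-last)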

8 files changed (+29, -29 lines)

examples/models/llama2/custom_ops/op_sdpa.cpp (+15, -15)

@@ -523,22 +523,22 @@ bool validate_flash_attention_args(
       "Attention mask must be a 2D tensor");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      is_default_dim_order(query.dim_order().data(), query.dim()),
-      "key cache must be in default dim order");
+      is_contiguous_dim_order(query.dim_order().data(), query.dim()),
+      "key cache must be in contiguous dim order");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      is_default_dim_order(key.dim_order().data(), key.dim()),
-      "value cache must be in default dim order");
+      is_contiguous_dim_order(key.dim_order().data(), key.dim()),
+      "value cache must be in contiguous dim order");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      is_default_dim_order(value.dim_order().data(), value.dim()),
-      "value cache must be in default dim order");
+      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
+      "value cache must be in contiguous dim order");
 
   if (attn_mask.has_value()) {
     ET_LOG_MSG_AND_RETURN_IF_FALSE(
-        is_default_dim_order(
+        is_contiguous_dim_order(
             attn_mask.value().dim_order().data(), attn_mask.value().dim()),
-        "value cache must be in default dim order");
+        "value cache must be in contiguous dim order");
   }
 
   return true;
@@ -590,14 +590,14 @@ bool validate_cache_params(
       seq_length,
       v_cache.size(2));
 
-  // Make sure they are in default dim order
+  // Make sure they are in contiguous dim order
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      is_default_dim_order(k_cache.dim_order().data(), k_cache.dim()),
-      "key cache must be in default dim order");
+      is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
+      "key cache must be in contiguous dim order");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      is_default_dim_order(v_cache.dim_order().data(), v_cache.dim()),
-      "value cache must be in default dim order");
+      is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
+      "value cache must be in contiguous dim order");
 
   return true;
 }
@@ -615,9 +615,9 @@ void update_cache(
       "projected_value must have batch size of 1");
   ET_CHECK_MSG(cache.size(1) == 1, "cache must have batch size of 1");
   ET_CHECK_MSG(
-      is_default_dim_order(
+      is_contiguous_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),
-      "projected value must be in default dim order");
+      "projected value must be in contiguous dim order");
   const void* projected_value_data = projected_value.const_data_ptr();
   void* cache_data = cache.mutable_data_ptr();
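All three hunks apply the same guard: the SDPA kernel only operates on tensors whose memory is in contiguous dim order, and bails out via ET_LOG_MSG_AND_RETURN_IF_FALSE otherwise. In ATen mode a caller can satisfy the check by normalizing layout before dispatch; a hedged sketch using the standard at::Tensor API (not part of this diff):

#include <ATen/ATen.h>

// Sketch: ensure an input is row-major before invoking a kernel that
// requires contiguous dim order. contiguous() only copies when the
// tensor is not already in contiguous_format layout.
at::Tensor ensure_contiguous(const at::Tensor& t) {
  return t.is_contiguous() ? t : t.contiguous();
}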

kernels/portable/cpu/op_native_batch_norm.cpp (+2, -2)

@@ -66,10 +66,10 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
       InvalidArgument,
       ret_val);
 
-  // For now, only support the default dim order
+  // For now, only support the contiguous dim order
   ET_KERNEL_CHECK(
       ctx,
-      is_default_dim_order(in.dim_order().data(), in.dim_order().size()),
+      is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size()),
       InvalidArgument,
       ret_val);
 

runtime/core/exec_aten/testing_util/tensor_factory.h (+2, -2)

@@ -292,7 +292,7 @@ class TensorFactory {
    * size of this vector must be equal to the product of the elements of
    * `sizes`.
    * @param[in] dim_order The dim order describing how tensor memory is laid
-   * out. If empty or not specificed, the function will use a default dim order
+   * out. If empty or not specificed, the function will use a contiguous dim order
    * of {0, 1, 2, 3, ...}
    *
    * @return A new Tensor with the specified shape and data.
@@ -706,7 +706,7 @@ class TensorFactory {
    * size of this vector must be equal to the product of the elements of
    * `sizes`.
    * @param[in] dim_order The dim order describing how tensor memory is laid
-   * out. If empty or not specificed, the function will use a default dim order
+   * out. If empty or not specificed, the function will use a contiguous dim order
    * of {0, 1, 2, 3, ...}
    *
    * @return A new Tensor with the specified shape and data.

runtime/core/exec_aten/util/dim_order_util.h (+2, -2)

@@ -29,14 +29,14 @@ bool validate_dim_order(const DimOrderType* dim_order, const size_t dims) {
 } // namespace
 
 /**
- * Check if a given dim_order array is equivalent to the default dim order of
+ * Check if a given dim_order array is equivalent to the contiguous dim order of
  * {0, 1, 2, 3, ...}
  *
  * @param[in] dim_order pointer to dim_order array
 * @param[in] dims length of the dim_order array
 */
 template <typename DimOrderType>
-inline bool is_default_dim_order(
+inline bool is_contiguous_dim_order(
     const DimOrderType* dim_order,
     const size_t dims) {
   for (int i = 0; i < dims; ++i) {
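The hunk cuts off at the loop header, but the documented contract ({0, 1, 2, 3, ...}) pins down the rest: the function is an element-by-element comparison of dim_order against the identity permutation. An equivalent standalone sketch (reconstructed from that contract, not copied from the file):

#include <cstddef>

// Returns true iff dim_order is exactly {0, 1, 2, ..., dims - 1},
// i.e. the tensor is laid out in row-major (contiguous) order.
template <typename DimOrderType>
bool is_contiguous_dim_order_sketch(const DimOrderType* dim_order, size_t dims) {
  for (size_t i = 0; i < dims; ++i) {
    if (dim_order[i] != static_cast<DimOrderType>(i)) {
      return false;
    }
  }
  return true;
}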

runtime/core/exec_aten/util/tensor_util.h (+1, -1)

@@ -315,7 +315,7 @@
 #define ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(t__)          \
   ({                                                            \
     ET_CHECK_MSG(                                               \
-        is_default_dim_order(                                   \
+        is_contiguous_dim_order(                                \
             t__.dim_order().data(), t__.dim_order().size()) ||  \
         is_channels_last_dim_order(                             \
             t__.dim_order().data(), t__.dim_order().size()),    \
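The macro passes if the tensor is in either of the two layouts the portable kernels support. For a 4-D tensor the two accepted dim orders look like this (illustrative constants, not taken from the diff):

#include <cstdint>

// Contiguous (row-major, NCHW-style) dim order for a 4-D tensor:
constexpr uint8_t kContiguous4d[] = {0, 1, 2, 3};
// Channels-last (NHWC-style) dim order, channel dim innermost:
constexpr uint8_t kChannelsLast4d[] = {0, 2, 3, 1};

// is_contiguous_dim_order(kContiguous4d, 4)      -> true
// is_channels_last_dim_order(kChannelsLast4d, 4) -> true
// Any other permutation fails both branches and trips ET_CHECK_MSG.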

runtime/core/exec_aten/util/tensor_util_aten.cpp (+1, -1)

@@ -59,7 +59,7 @@ inline bool tensor_is_default_or_channels_last_dim_order(at::Tensor t) {
       get_dim_order(t, dim_order, t.dim()) == Error::Ok,
       "Failed to retrieve dim order from tensor!");
 
-  bool ret_val = is_default_dim_order(dim_order, t.dim()) ||
+  bool ret_val = is_contiguous_dim_order(dim_order, t.dim()) ||
       is_channels_last_dim_order(dim_order, t.dim());
 
   if (!ret_val) {

runtime/core/exec_aten/util/tensor_util_portable.cpp (+1, -1)

@@ -55,7 +55,7 @@ bool tensor_has_valid_dim_order(torch::executor::Tensor t) {
 
 bool tensor_is_default_or_channels_last_dim_order(torch::executor::Tensor t) {
   bool ret_val =
-      is_default_dim_order(t.dim_order().data(), t.dim_order().size()) ||
+      is_contiguous_dim_order(t.dim_order().data(), t.dim_order().size()) ||
       is_channels_last_dim_order(t.dim_order().data(), t.dim_order().size());
 
   if (!ret_val) {

runtime/core/exec_aten/util/test/dim_order_util_test.cpp (+5, -5)

@@ -236,7 +236,7 @@ TEST(TensorUtilTest, IsDefaultDimOrderTest) {
     std::vector<exec_aten::DimOrderType> dim_order(i);
     std::iota(dim_order.begin(), dim_order.end(), 0);
 
-    EXPECT_TRUE(torch::executor::is_default_dim_order(
+    EXPECT_TRUE(torch::executor::is_contiguous_dim_order(
         dim_order.data(), dim_order.size()));
 
     // As a bonus, check that is_channels_last returns false
@@ -252,7 +252,7 @@ TEST(TensorUtilTest, IsDefaultDimOrderFailCasesTest) {
     std::iota(dim_order.begin(), dim_order.end(), 0);
     std::swap(dim_order[0], dim_order[1]);
 
-    EXPECT_FALSE(torch::executor::is_default_dim_order(
+    EXPECT_FALSE(torch::executor::is_contiguous_dim_order(
         dim_order.data(), dim_order.size()));
   }
 
@@ -263,7 +263,7 @@ TEST(TensorUtilTest, IsDefaultDimOrderFailCasesTest) {
       dim_order[d] = (d + 1) % i;
     }
 
-    EXPECT_FALSE(torch::executor::is_default_dim_order(
+    EXPECT_FALSE(torch::executor::is_contiguous_dim_order(
         dim_order.data(), dim_order.size()));
   }
 }
@@ -276,8 +276,8 @@ TEST(TensorUtilTest, IsChannelsLastDimOrderTest) {
   EXPECT_TRUE(torch::executor::is_channels_last_dim_order(dim_order_5d, 5));
 
   // As a bonus, check that is_default returns false
-  EXPECT_FALSE(torch::executor::is_default_dim_order(dim_order_4d, 4));
-  EXPECT_FALSE(torch::executor::is_default_dim_order(dim_order_5d, 5));
+  EXPECT_FALSE(torch::executor::is_contiguous_dim_order(dim_order_4d, 4));
+  EXPECT_FALSE(torch::executor::is_contiguous_dim_order(dim_order_5d, 5));
 }
 
 TEST(TensorUtilTest, IsChannelsLastDimOrderFailCasesTest) {
