Commit 5d395b2

roll out s/IF_FALSE/UNLESS/ for ET_LOG macros
Separating the big find/replace from the previous diff to ease review; this is just the result of

```
fastmod "(ET_LOG_(MSG_)?AND_RETURN_)IF_FALSE" "\${1}UNLESS"
```

with a manual revert in tensor_util.h for the two backward-compatibility defines.

ghstack-source-id: a4a76d6
ghstack-comment-id: 2644056878
Pull Request resolved: #8318
1 parent: 51d6db2 · commit: 5d395b2

37 files changed: +522 additions, -526 deletions
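For readers skimming the diffs below: every change is the same mechanical rename. As a reference, here is a minimal sketch of what the renamed macro family and the two tensor_util.h backward-compatibility defines might look like; this is a hypothetical simplification for illustration only (the real definitions log through ExecuTorch's ET_LOG machinery rather than fprintf):

```cpp
#include <cstdio>

// Hypothetical simplification: log and bail out of the enclosing
// bool-returning function unless `cond` holds.
#define ET_LOG_AND_RETURN_UNLESS(cond)                   \
  do {                                                   \
    if (!(cond)) {                                       \
      std::fprintf(stderr, "Check failed: %s\n", #cond); \
      return false;                                      \
    }                                                    \
  } while (0)

// Variant with a printf-style message, as used throughout the diffs below.
#define ET_LOG_MSG_AND_RETURN_UNLESS(cond, fmt, ...)   \
  do {                                                 \
    if (!(cond)) {                                     \
      std::fprintf(stderr, fmt "\n", ##__VA_ARGS__);   \
      return false;                                    \
    }                                                  \
  } while (0)

// Sketch of the two backward-compatibility defines kept in tensor_util.h:
// the old IF_FALSE names stay valid by forwarding to the UNLESS names.
#define ET_LOG_AND_RETURN_IF_FALSE(cond) ET_LOG_AND_RETURN_UNLESS(cond)
#define ET_LOG_MSG_AND_RETURN_IF_FALSE(cond, fmt, ...) \
  ET_LOG_MSG_AND_RETURN_UNLESS(cond, fmt, ##__VA_ARGS__)
```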

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 20 additions & 20 deletions
```diff
@@ -594,46 +594,46 @@ bool validate_flash_attention_args(
     const Tensor& key,
     const Tensor& value,
     const optional<Tensor>& attn_mask) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(query.dim() == 4, "query must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(key.dim() == 4, "key must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(value.dim() == 4, "value must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(query.dim() == 4, "query must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(key.dim() == 4, "key must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(value.dim() == 4, "value must be a 4D tensor");
 
   // Sizes
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
       "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.scalar_type() == ScalarType::Float), "Query must be Float type");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.scalar_type() == key.scalar_type()) &&
           (query.scalar_type() == value.scalar_type()),
       "Key and Value must have the same data type as Query");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       !attn_mask.has_value() || attn_mask.value().dim() == 2,
       "Attention mask must be a 2D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       !attn_mask.has_value() ||
           attn_mask.value().scalar_type() == query.scalar_type(),
       "Attention mask must be a 2D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
       "key cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(key.dim_order().data(), key.dim()),
       "value cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(value.dim_order().data(), value.dim()),
       "value cache must be in contiguous dim order");
 
   if (attn_mask.has_value()) {
-    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+    ET_LOG_MSG_AND_RETURN_UNLESS(
         is_contiguous_dim_order(
             attn_mask.value().dim_order().data(), attn_mask.value().dim()),
         "value cache must be in contiguous dim order");
@@ -647,21 +647,21 @@ bool validate_cache_params(
     const Tensor& v_cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       k_cache.dim() == 4, "kcache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       v_cache.dim() == 4, "v_cache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       start_pos < k_cache.size(1),
       "start_pos must be less than key cache at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       start_pos < v_cache.size(1),
       "start_pos must be less than value cache at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (start_pos + seq_length) <= k_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
@@ -671,7 +671,7 @@ bool validate_cache_params(
       seq_length,
       k_cache.size(1));
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (start_pos + seq_length) <= v_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
@@ -682,11 +682,11 @@ bool validate_cache_params(
       v_cache.size(1));
 
   // Make sure they are in contiguous dim order
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
       "key cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
       "value cache must be in contiguous dim order");
 
```

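Each of these validators logs and returns false at the first failed check, so callers can reject bad inputs before touching any data. A hypothetical call-site sketch (the actual kernel wiring in op_sdpa.cpp may differ):

```cpp
// Hypothetical sketch: reject invalid arguments at the top of the kernel.
// `ctx` is assumed to be the kernel's runtime context and `out` its output.
if (!validate_flash_attention_args(query, key, value, attn_mask)) {
  ctx.fail(executorch::runtime::Error::InvalidArgument);
  return out;
}
```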
extension/llm/custom_ops/op_tile_crop.cpp

Lines changed: 6 additions & 6 deletions
```diff
@@ -19,12 +19,12 @@ bool check_tile_crop_out_args(
     const Tensor& in,
     int64_t tile_size,
     Tensor& out) {
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
-  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 3));
-  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 4));
-  ET_LOG_AND_RETURN_IF_FALSE(tile_size > 0);
-  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 1) % tile_size == 0);
-  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 2) % tile_size == 0);
+  ET_LOG_AND_RETURN_UNLESS(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_UNLESS(tensor_is_rank(in, 3));
+  ET_LOG_AND_RETURN_UNLESS(tensor_is_rank(out, 4));
+  ET_LOG_AND_RETURN_UNLESS(tile_size > 0);
+  ET_LOG_AND_RETURN_UNLESS(in.size(in.dim() - 1) % tile_size == 0);
+  ET_LOG_AND_RETURN_UNLESS(in.size(in.dim() - 2) % tile_size == 0);
   return true;
 }
 
```

extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 6 additions & 6 deletions
```diff
@@ -25,17 +25,17 @@ bool validate_cache_params(
     const Tensor& quantized_cache,
     int64_t start_pos,
    int64_t seq_length) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       start_pos < quantized_cache.size(1),
       "start_pos must be less than cache size at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (start_pos + seq_length) <= quantized_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
@@ -46,12 +46,12 @@ bool validate_cache_params(
       quantized_cache.size(1));
 
   // Make sure they are in contiguous dim order
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(
           quantized_cache.dim_order().data(), quantized_cache.dim()),
       "quantized cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(
           quantized_value.dim_order().data(), quantized_value.dim()),
       "quantized value must be in contiguous dim order");
```

extension/parallel/thread_parallel.cpp

Lines changed: 3 additions & 3 deletions
```diff
@@ -53,9 +53,9 @@ bool parallel_for(
     const int64_t end,
     const int64_t grain_size,
     const std::function<void(int64_t, int64_t)>& f) {
-  ET_LOG_AND_RETURN_IF_FALSE(begin >= 0 && end >= 0);
-  ET_LOG_AND_RETURN_IF_FALSE(end >= begin);
-  ET_LOG_AND_RETURN_IF_FALSE(grain_size > 0);
+  ET_LOG_AND_RETURN_UNLESS(begin >= 0 && end >= 0);
+  ET_LOG_AND_RETURN_UNLESS(end >= begin);
+  ET_LOG_AND_RETURN_UNLESS(grain_size > 0);
   int64_t num_tasks = 0, chunk_size = 0;
   std::tie(num_tasks, chunk_size) =
       calc_num_tasks_and_chunk_size(begin, end, grain_size);
```
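A usage sketch for this API (the namespace is an assumption; the signature matches the diff above). With the renamed checks, an invalid range still logs and makes the call return false rather than crash:

```cpp
#include <cstdint>
#include <vector>

void double_all(std::vector<float>& data) {
  // Each task receives one contiguous [begin, end) chunk of the range.
  const bool ok = executorch::extension::parallel_for(
      /*begin=*/0,
      /*end=*/static_cast<int64_t>(data.size()),
      /*grain_size=*/64,
      [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; ++i) {
          data[i] *= 2.0f;
        }
      });
  // ok is false (after a log line) if begin/end/grain_size failed the
  // ET_LOG_AND_RETURN_UNLESS preconditions, e.g. parallel_for(-1, 10, 64, f).
  (void)ok;
}
```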

kernels/aten/cpu/op__empty_dim_order.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -44,13 +44,13 @@ inline bool _check__empty_out_dim_order(
   }
 
   // dim order size shall equal to input dim
-  ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == out.dim());
+  ET_LOG_AND_RETURN_UNLESS(dim_order_ref.size() == out.dim());
 
-  ET_LOG_AND_RETURN_IF_FALSE(
+  ET_LOG_AND_RETURN_UNLESS(
       is_channels_last_dim_order(dim_order_ref.data(), dim_order_ref.size()) ||
       is_contiguous_dim_order(dim_order_ref.data(), dim_order_ref.size()));
 
-  ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
+  ET_LOG_AND_RETURN_UNLESS(kMaxNumOfDimensions >= out.dim());
   executorch::aten::StridesType target_strides[kMaxNumOfDimensions];
   dim_order_to_stride_nocheck(
       out.sizes().data(),
@@ -59,7 +59,7 @@ inline bool _check__empty_out_dim_order(
       target_strides);
 
   for (size_t i = 0; i < dim_order_ref.size(); i++) {
-    ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
+    ET_LOG_AND_RETURN_UNLESS(target_strides[i] == out.strides()[i]);
   }
 
   return true;
```

kernels/aten/cpu/op__to_dim_order_copy.cpp

Lines changed: 8 additions & 8 deletions
```diff
@@ -47,43 +47,43 @@ bool check__to_dim_order_copy_args(
     executorch::aten::OptionalArrayRef<int64_t> dim_order,
     Tensor& out) {
   // Right now we only support blocking data transfer
-  ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);
+  ET_LOG_AND_RETURN_UNLESS(non_blocking == false);
 
   // dim_order is set, the target dim_order will be either contiguous or
   // channels_last memory format
   if (dim_order.has_value()) {
     executorch::aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();
 
     // dim order size shall equal to input dim
-    ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());
+    ET_LOG_AND_RETURN_UNLESS(dim_order_ref.size() == input.dim());
 
-    ET_LOG_AND_RETURN_IF_FALSE(
+    ET_LOG_AND_RETURN_UNLESS(
         is_channels_last_dim_order(
             dim_order.value().data(), dim_order.value().size()) ||
         is_contiguous_dim_order(
             dim_order.value().data(), dim_order.value().size()));
 
     // Out Aten tensor shall have same memory format stride as dim_order
     const size_t kMaxNumOfDimensions = 16;
-    ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
+    ET_LOG_AND_RETURN_UNLESS(kMaxNumOfDimensions >= out.dim());
     executorch::aten::StridesType target_strides[kMaxNumOfDimensions];
     dim_order_to_stride_nocheck(
         out.sizes().data(),
         dim_order_ref.data(),
         dim_order_ref.size(),
         target_strides);
-    ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
+    ET_LOG_AND_RETURN_UNLESS(out.dim() == dim_order_ref.size());
     for (size_t i = 0; i < dim_order_ref.size(); i++) {
-      ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
+      ET_LOG_AND_RETURN_UNLESS(target_strides[i] == out.strides()[i]);
     }
 
   } else { // dim_order is not set, preserve the dim order of input
 
     auto out_strides = out.strides();
     auto input_strides = input.strides();
-    ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
+    ET_LOG_AND_RETURN_UNLESS(input_strides.size() == out_strides.size());
     for (size_t i = 0; i < input_strides.size(); i++) {
-      ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
+      ET_LOG_AND_RETURN_UNLESS(input_strides[i] == out_strides[i]);
     }
   }
   return true;
```
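The stride comparison above is easier to follow with concrete numbers. Here is a self-contained sketch of the dim-order-to-strides computation it relies on (the helper name is hypothetical; it is assumed to mirror what dim_order_to_stride_nocheck produces):

```cpp
#include <cinttypes>
#include <cstdint>
#include <cstdio>

// The last entry of dim_order names the fastest-moving dimension, so
// strides accumulate from right to left over dim_order.
void strides_from_dim_order(
    const int64_t* sizes,
    const int64_t* dim_order,
    size_t ndim,
    int64_t* strides) {
  int64_t running = 1;
  for (size_t i = ndim; i-- > 0;) {
    strides[dim_order[i]] = running;
    running *= sizes[dim_order[i]];
  }
}

int main() {
  const int64_t sizes[4] = {2, 3, 4, 5};         // NCHW sizes
  const int64_t contiguous[4] = {0, 1, 2, 3};    // contiguous dim order
  const int64_t channels_last[4] = {0, 2, 3, 1}; // channels-last dim order
  int64_t s[4];

  strides_from_dim_order(sizes, contiguous, 4, s);
  std::printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n",
              s[0], s[1], s[2], s[3]);  // 60 20 5 1

  strides_from_dim_order(sizes, channels_last, 4, s);
  std::printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n",
              s[0], s[1], s[2], s[3]);  // 60 1 15 3 (NHWC in NCHW index order)
  return 0;
}
```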

kernels/aten/cpu/util/copy_ops_util.cpp

Lines changed: 8 additions & 8 deletions
```diff
@@ -22,43 +22,43 @@ bool check__to_dim_order_copy_args(
     executorch::aten::OptionalArrayRef<int64_t> dim_order,
     Tensor& out) {
   // Right now we only support blocking data transfer
-  ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);
+  ET_LOG_AND_RETURN_UNLESS(non_blocking == false);
 
   // dim_order is set, the target dim_order will be either contiguous or
   // channels_last memory format
   if (dim_order.has_value()) {
     executorch::aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();
 
     // dim order size shall equal to input dim
-    ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());
+    ET_LOG_AND_RETURN_UNLESS(dim_order_ref.size() == input.dim());
 
-    ET_LOG_AND_RETURN_IF_FALSE(
+    ET_LOG_AND_RETURN_UNLESS(
         is_channels_last_dim_order(
             dim_order.value().data(), dim_order.value().size()) ||
         is_contiguous_dim_order(
             dim_order.value().data(), dim_order.value().size()));
 
     // Out Aten tensor shall have same memory format stride as dim_order
     const size_t kMaxNumOfDimensions = 16;
-    ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
+    ET_LOG_AND_RETURN_UNLESS(kMaxNumOfDimensions >= out.dim());
     executorch::aten::StridesType target_strides[kMaxNumOfDimensions];
     dim_order_to_stride_nocheck(
         out.sizes().data(),
         dim_order_ref.data(),
         dim_order_ref.size(),
         target_strides);
-    ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
+    ET_LOG_AND_RETURN_UNLESS(out.dim() == dim_order_ref.size());
     for (size_t i = 0; i < dim_order_ref.size(); i++) {
-      ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
+      ET_LOG_AND_RETURN_UNLESS(target_strides[i] == out.strides()[i]);
     }
 
   } else { // dim_order is not set, preserve the dim order of input
 
     auto out_strides = out.strides();
     auto input_strides = input.strides();
-    ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
+    ET_LOG_AND_RETURN_UNLESS(input_strides.size() == out_strides.size());
     for (size_t i = 0; i < input_strides.size(); i++) {
-      ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
+      ET_LOG_AND_RETURN_UNLESS(input_strides[i] == out_strides[i]);
     }
   }
   return true;
```

kernels/optimized/cpu/op_bmm.cpp

Lines changed: 9 additions & 9 deletions
```diff
@@ -31,46 +31,46 @@ namespace {
 // Verifies that the parameters are valid.
 bool check_bmm_out_args(const Tensor& self, const Tensor& mat2, Tensor& out) {
   // Ensure dimensions is 3 for all input and out
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       self.dim() == mat2.dim(),
       "self.dim() %zd != mat2.dim() %zd",
       self.dim(),
       mat2.dim());
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       self.dim() == out.dim(),
       "self.dim() %zd != out.dim() %zd",
       self.dim(),
       out.dim());
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       self.dim() == 3, "self.dim() %zd != 3", self.dim());
   // Ensure batch larger than or equals to 0
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       self.size(0) >= 0, "self.size(0) %zd < 0", self.size(0));
   // Ensure batches are the same
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       self.size(0) == mat2.size(0),
       "self.size(0) %zd != mat2.size(0) %zd",
       self.size(0),
       mat2.size(0));
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       self.size(0) == out.size(0),
       "self.size(0) %zd != out.size(0) %zd",
       self.size(0),
       out.size(0));
   // Ensure the out size is compatible with input tensors
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       mat2.size(2) == out.size(2),
       "mat2.size(2) %zd != out.size(2) %zd",
       mat2.size(2),
       out.size(2));
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       self.size(1) == out.size(1),
       "self.size(1) %zd != out.size(1) %zd",
       self.size(1),
       out.size(1));
 
   // Ensure that all tensors share a dtype
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, mat2, out));
+  ET_LOG_AND_RETURN_UNLESS(tensors_have_same_dtype(self, mat2, out));
 
   return true;
 }
```

kernels/portable/cpu/op__empty_dim_order.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -30,20 +30,20 @@ bool _check__empty_out_dim_order(OptionalIntArrayRef dim_order, Tensor& out) {
     // out tensor's dim order shall equal to input dim order
     IntArrayRef dim_order_ref = dim_order.value();
 
-    ET_LOG_AND_RETURN_IF_FALSE(
+    ET_LOG_AND_RETURN_UNLESS(
        is_channels_last_dim_order(
            dim_order.value().data(), dim_order.value().size()) ||
        is_contiguous_dim_order(
            dim_order.value().data(), dim_order.value().size()));
 
     // Out tensor shall have same dim order as dim_order
-    ET_LOG_AND_RETURN_IF_FALSE(out_dim_order.size() == dim_order_ref.size());
+    ET_LOG_AND_RETURN_UNLESS(out_dim_order.size() == dim_order_ref.size());
     for (size_t i = 0; i < dim_order_ref.size(); i++) {
-      ET_LOG_AND_RETURN_IF_FALSE(out_dim_order[i] == dim_order_ref[i]);
+      ET_LOG_AND_RETURN_UNLESS(out_dim_order[i] == dim_order_ref[i]);
     }
   } else { // dim_order is not set, out tensor should be contiguous memory
            // format
-    ET_LOG_AND_RETURN_IF_FALSE(
+    ET_LOG_AND_RETURN_UNLESS(
         is_contiguous_dim_order(out_dim_order.data(), out_dim_order.size()));
   }
   return true;
```
