Skip to content

Commit a5c1890

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
parallel_for should return true if precondition fails (#2240)
Summary: Pull Request resolved: #2240 Don't crash, but log error and return false. Reviewed By: SS-JIA Differential Revision: D54505636 fbshipit-source-id: 1cdd2861fbae5bb355ef7bd61eb34c03f35418c5
1 parent 69bf18b commit a5c1890

File tree

5 files changed

+39
-40
lines changed

5 files changed

+39
-40
lines changed

extension/parallel/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ def define_common_targets():
2424
deps = [
2525
"//executorch/backends/xnnpack/threadpool:threadpool",
2626
"//executorch/runtime/core:core",
27+
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
2728
],
2829
)

extension/parallel/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ def define_common_targets():
1414
],
1515
deps = [
1616
"//executorch/extension/parallel:thread_parallel",
17+
"//executorch/runtime/platform:platform",
1718
],
1819
)

extension/parallel/test/thread_parallel_test.cpp

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <mutex>
1313

1414
#include <executorch/extension/parallel/thread_parallel.h>
15-
#include <executorch/test/utils/DeathTest.h>
15+
#include <executorch/runtime/platform/platform.h>
1616

1717
using namespace ::testing;
1818

@@ -49,19 +49,19 @@ class ParallelTest : public ::testing::Test {
4949
};
5050

5151
TEST_F(ParallelTest, TestAllInvoked) {
52-
parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
52+
EXPECT_TRUE(parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
5353
this->RunTask(begin, end);
54-
});
54+
}));
5555

5656
for (int64_t i = 0; i < 10; ++i) {
5757
EXPECT_EQ(data_[i], i);
5858
}
5959
}
6060

6161
TEST_F(ParallelTest, TestAllInvokedWithMutex) {
62-
parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
62+
EXPECT_TRUE(parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
6363
this->RunExclusiveTask(begin, end);
64-
});
64+
}));
6565

6666
int expected_sum = 0;
6767
for (int64_t i = 0; i < 10; ++i) {
@@ -72,13 +72,10 @@ TEST_F(ParallelTest, TestAllInvokedWithMutex) {
7272
}
7373

7474
TEST_F(ParallelTest, TestInvalidRange) {
75-
ET_EXPECT_DEATH(
76-
{
77-
parallel_for(10, 0, 1, [this](int64_t begin, int64_t end) {
78-
this->RunExclusiveTask(begin, end);
79-
});
80-
},
81-
"");
75+
et_pal_init();
76+
EXPECT_FALSE(parallel_for(10, 0, 1, [this](int64_t begin, int64_t end) {
77+
this->RunExclusiveTask(begin, end);
78+
}));
8279

8380
for (int64_t i = 0; i < 10; ++i) {
8481
EXPECT_EQ(data_[i], 0);
@@ -87,13 +84,10 @@ TEST_F(ParallelTest, TestInvalidRange) {
8784
}
8885

8986
TEST_F(ParallelTest, TestInvalidRange2) {
90-
ET_EXPECT_DEATH(
91-
{
92-
parallel_for(6, 5, 1, [this](int64_t begin, int64_t end) {
93-
this->RunExclusiveTask(begin, end);
94-
});
95-
},
96-
"");
87+
et_pal_init();
88+
EXPECT_FALSE(parallel_for(6, 5, 1, [this](int64_t begin, int64_t end) {
89+
this->RunExclusiveTask(begin, end);
90+
}));
9791

9892
for (int64_t i = 0; i < 10; ++i) {
9993
EXPECT_EQ(data_[i], 0);
@@ -102,9 +96,9 @@ TEST_F(ParallelTest, TestInvalidRange2) {
10296
}
10397

10498
TEST_F(ParallelTest, TestInvokePartialFromBeginning) {
105-
parallel_for(0, 5, 1, [this](int64_t begin, int64_t end) {
99+
EXPECT_TRUE(parallel_for(0, 5, 1, [this](int64_t begin, int64_t end) {
106100
this->RunTask(begin, end);
107-
});
101+
}));
108102

109103
for (int64_t i = 0; i < 5; ++i) {
110104
EXPECT_EQ(data_[i], i);
@@ -115,9 +109,9 @@ TEST_F(ParallelTest, TestInvokePartialFromBeginning) {
115109
}
116110

117111
TEST_F(ParallelTest, TestInvokePartialToEnd) {
118-
parallel_for(5, 10, 1, [this](int64_t begin, int64_t end) {
112+
EXPECT_TRUE(parallel_for(5, 10, 1, [this](int64_t begin, int64_t end) {
119113
this->RunTask(begin, end);
120-
});
114+
}));
121115

122116
for (int64_t i = 0; i < 5; ++i) {
123117
EXPECT_EQ(data_[i], 0);
@@ -128,9 +122,9 @@ TEST_F(ParallelTest, TestInvokePartialToEnd) {
128122
}
129123

130124
TEST_F(ParallelTest, TestInvokePartialMiddle) {
131-
parallel_for(2, 8, 1, [this](int64_t begin, int64_t end) {
125+
EXPECT_TRUE(parallel_for(2, 8, 1, [this](int64_t begin, int64_t end) {
132126
this->RunTask(begin, end);
133-
});
127+
}));
134128

135129
for (int64_t i = 0; i < 2; ++i) {
136130
EXPECT_EQ(data_[i], 0);
@@ -144,19 +138,19 @@ TEST_F(ParallelTest, TestInvokePartialMiddle) {
144138
}
145139

146140
TEST_F(ParallelTest, TestChunkSize2) {
147-
parallel_for(0, 10, 2, [this](int64_t begin, int64_t end) {
141+
EXPECT_TRUE(parallel_for(0, 10, 2, [this](int64_t begin, int64_t end) {
148142
this->RunTask(begin, end);
149-
});
143+
}));
150144

151145
for (int64_t i = 0; i < 10; ++i) {
152146
EXPECT_EQ(data_[i], i);
153147
}
154148
}
155149

156150
TEST_F(ParallelTest, TestChunkSize2Middle) {
157-
parallel_for(3, 8, 2, [this](int64_t begin, int64_t end) {
151+
EXPECT_TRUE(parallel_for(3, 8, 2, [this](int64_t begin, int64_t end) {
158152
this->RunTask(begin, end);
159-
});
153+
}));
160154

161155
for (int64_t i = 0; i < 3; ++i) {
162156
EXPECT_EQ(data_[i], 0);
@@ -170,29 +164,29 @@ TEST_F(ParallelTest, TestChunkSize2Middle) {
170164
}
171165

172166
TEST_F(ParallelTest, TestChunkSize3) {
173-
parallel_for(0, 10, 3, [this](int64_t begin, int64_t end) {
167+
EXPECT_TRUE(parallel_for(0, 10, 3, [this](int64_t begin, int64_t end) {
174168
this->RunTask(begin, end);
175-
});
169+
}));
176170

177171
for (int64_t i = 0; i < 10; ++i) {
178172
EXPECT_EQ(data_[i], i);
179173
}
180174
}
181175

182176
TEST_F(ParallelTest, TestChunkSize6) {
183-
parallel_for(0, 10, 6, [this](int64_t begin, int64_t end) {
177+
EXPECT_TRUE(parallel_for(0, 10, 6, [this](int64_t begin, int64_t end) {
184178
this->RunTask(begin, end);
185-
});
179+
}));
186180

187181
for (int64_t i = 0; i < 10; ++i) {
188182
EXPECT_EQ(data_[i], i);
189183
}
190184
}
191185

192186
TEST_F(ParallelTest, TestChunkSizeTooLarge) {
193-
parallel_for(0, 10, 11, [this](int64_t begin, int64_t end) {
187+
EXPECT_TRUE(parallel_for(0, 10, 11, [this](int64_t begin, int64_t end) {
194188
this->RunTask(begin, end);
195-
});
189+
}));
196190

197191
for (int64_t i = 0; i < 10; ++i) {
198192
EXPECT_EQ(data_[i], i);

extension/parallel/thread_parallel.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
1212
#include <executorch/extension/parallel/thread_parallel.h>
13+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1314
#include <executorch/runtime/platform/assert.h>
1415

1516
namespace torch::executor {
@@ -34,14 +35,14 @@ calc_num_tasks_and_chunk_size(int64_t begin, int64_t end, int64_t grain_size) {
3435
return std::make_tuple(num_tasks, chunk_size);
3536
}
3637

37-
void parallel_for(
38+
bool parallel_for(
3839
const int64_t begin,
3940
const int64_t end,
4041
const int64_t grain_size,
4142
const std::function<void(int64_t, int64_t)>& f) {
42-
ET_CHECK_MSG(begin >= 0 && end >= 0, "Begin and end should be non-negative");
43-
ET_CHECK_MSG(end >= begin, "end should be greater than or equal to begin");
44-
ET_CHECK_MSG(grain_size > 0, "grain_size should be positive");
43+
ET_LOG_AND_RETURN_IF_FALSE(begin >= 0 && end >= 0);
44+
ET_LOG_AND_RETURN_IF_FALSE(end >= begin);
45+
ET_LOG_AND_RETURN_IF_FALSE(grain_size > 0);
4546
int64_t num_tasks = 0, chunk_size = 0;
4647
std::tie(num_tasks, chunk_size) =
4748
calc_num_tasks_and_chunk_size(begin, end, grain_size);
@@ -57,6 +58,7 @@ void parallel_for(
5758
// Per protocol from threadpool (pthreadpool), when this returns, all tasks
5859
// are executed, so this is synchronous.
5960
get_threadpool()->run(task, num_tasks);
61+
return true;
6062
}
6163

6264
} // namespace torch::executor

extension/parallel/thread_parallel.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ namespace torch::executor {
2323
* described below
2424
* f: user function applied in parallel to the chunks, signature:
2525
* void f(int64_t begin, int64_t end)
26+
* Returns true if all work items are processed successfully, false otherwise
2627
*
2728
* Warning: parallel_for does NOT copy thread local states from the current
2829
* thread to the worker threads. Users need to protect the access to captured
2930
* data if they mutate them in f.
3031
*/
31-
void parallel_for(
32+
bool parallel_for(
3233
const int64_t begin,
3334
const int64_t end,
3435
const int64_t grain_size,

0 commit comments

Comments
 (0)