
Commit 566209d

kirklandsign authored and facebook-github-bot committed
Implementation thread parallel with threadpool (#2173)
Summary:
Pull Request resolved: #2173

Use the ET threadpool (with underlying pthreadpool) to provide `parallel_for` functionality.

Reviewed By: kimishpatel

Differential Revision: D54335940

fbshipit-source-id: 0865d0c76d1f16c325da8c13656fa955d6a48ade
1 parent 568673e commit 566209d
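
For orientation before the diff: the new API is torch::executor::parallel_for(begin, end, grain_size, f), declared in extension/parallel/thread_parallel.h (added below). A minimal caller-side sketch follows; the buffer, sizes, and function name are purely illustrative and not part of this commit.

#include <array>
#include <cstdint>

#include <executorch/extension/parallel/thread_parallel.h>

namespace {

// Illustration only: square each element of a fixed-size buffer in parallel.
void square_all(std::array<int64_t, 1024>& data) {
  torch::executor::parallel_for(
      0,
      static_cast<int64_t>(data.size()),
      /*grain_size=*/64,
      [&data](int64_t begin, int64_t end) {
        // Each task receives a disjoint half-open slice [begin, end), so no
        // locking is needed as long as it only touches its own slice.
        for (int64_t i = begin; i < end; ++i) {
          data[i] = data[i] * data[i];
        }
      });
}

} // namespace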

File tree

7 files changed: +363 additions, -0 deletions


extension/parallel/TARGETS

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

extension/parallel/targets.bzl

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    for aten_mode in (True, False):
        aten_suffix = ("_aten" if aten_mode else "")

        runtime.cxx_library(
            name = "thread_parallel" + aten_suffix,
            srcs = [
                "thread_parallel.cpp",
            ],
            exported_headers = [
                "thread_parallel.h",
            ],
            visibility = [
                "//executorch/...",
            ],
            deps = [
                "//executorch/backends/xnnpack/threadpool:threadpool",
                "//executorch/runtime/core:core",
            ],
        )

extension/parallel/test/TARGETS

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

extension/parallel/test/targets.bzl

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    runtime.cxx_test(
        name = "thread_parallel_test",
        srcs = [
            "thread_parallel_test.cpp",
        ],
        deps = [
            "//executorch/extension/parallel:thread_parallel",
        ],
    )

extension/parallel/test/thread_parallel_test.cpp

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <gtest/gtest.h>

#include <array>
#include <mutex>

#include <executorch/extension/parallel/thread_parallel.h>
#include <executorch/test/utils/DeathTest.h>

using namespace ::testing;

namespace torch::executor {

class ParallelTest : public ::testing::Test {
 protected:
  void SetUp() override {
    data_.fill(0);
    sum_of_all_elements_ = 0;
  }

  void RunTask(int64_t begin, int64_t end) {
    for (int64_t j = begin; j < end; ++j) {
      // Check that we haven't written to this index before
      EXPECT_EQ(data_[j], 0);
      data_[j] = j;
    }
  }

  void RunExclusiveTask(int64_t begin, int64_t end) {
    for (int64_t j = begin; j < end; ++j) {
      // Check that we haven't written to this index before
      EXPECT_EQ(data_[j], 0);
      std::lock_guard<std::mutex> lock(mutex_);
      data_[j] = j;
      sum_of_all_elements_ += data_[j];
    }
  }

  std::array<int, 10> data_;
  std::mutex mutex_;
  int sum_of_all_elements_;
};

TEST_F(ParallelTest, TestAllInvoked) {
  parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}

TEST_F(ParallelTest, TestAllInvokedWithMutex) {
  parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
    this->RunExclusiveTask(begin, end);
  });

  int expected_sum = 0;
  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], i);
    expected_sum += i;
  }
  EXPECT_EQ(sum_of_all_elements_, expected_sum);
}

TEST_F(ParallelTest, TestInvalidRange) {
  ET_EXPECT_DEATH(
      {
        parallel_for(10, 0, 1, [this](int64_t begin, int64_t end) {
          this->RunExclusiveTask(begin, end);
        });
      },
      "");

  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
  EXPECT_EQ(sum_of_all_elements_, 0);
}

TEST_F(ParallelTest, TestInvalidRange2) {
  ET_EXPECT_DEATH(
      {
        parallel_for(6, 5, 1, [this](int64_t begin, int64_t end) {
          this->RunExclusiveTask(begin, end);
        });
      },
      "");

  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
  EXPECT_EQ(sum_of_all_elements_, 0);
}

TEST_F(ParallelTest, TestInvokePartialFromBeginning) {
  parallel_for(0, 5, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 5; ++i) {
    EXPECT_EQ(data_[i], i);
  }
  for (int64_t i = 5; i < 10; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
}

TEST_F(ParallelTest, TestInvokePartialToEnd) {
  parallel_for(5, 10, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 5; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
  for (int64_t i = 5; i < 10; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}

TEST_F(ParallelTest, TestInvokePartialMiddle) {
  parallel_for(2, 8, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 2; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
  for (int64_t i = 2; i < 8; ++i) {
    EXPECT_EQ(data_[i], i);
  }
  for (int64_t i = 8; i < 10; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
}

TEST_F(ParallelTest, TestChunkSize2) {
  parallel_for(0, 10, 2, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}

TEST_F(ParallelTest, TestChunkSize2Middle) {
  parallel_for(3, 8, 2, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
  for (int64_t i = 3; i < 8; ++i) {
    EXPECT_EQ(data_[i], i);
  }
  for (int64_t i = 8; i < 10; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
}

TEST_F(ParallelTest, TestChunkSize3) {
  parallel_for(0, 10, 3, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}

TEST_F(ParallelTest, TestChunkSize6) {
  parallel_for(0, 10, 6, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}

TEST_F(ParallelTest, TestChunkSizeTooLarge) {
  parallel_for(0, 10, 11, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  });

  for (int64_t i = 0; i < 10; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}

} // namespace torch::executor

extension/parallel/thread_parallel.cpp

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <tuple>

#include <executorch/backends/xnnpack/threadpool/threadpool.h>
#include <executorch/extension/parallel/thread_parallel.h>
#include <executorch/runtime/platform/assert.h>

namespace torch::executor {

using namespace torch::executorch::threadpool;

inline int64_t divup(int64_t x, int64_t y) {
  return (x + y - 1) / y;
}

inline std::tuple<int64_t, int64_t>
calc_num_tasks_and_chunk_size(int64_t begin, int64_t end, int64_t grain_size) {
  if ((end - begin) < grain_size) {
    return std::make_tuple(1, std::max((int64_t)0, end - begin));
  }
  // Choose number of tasks based on grain size and number of threads.
  int64_t chunk_size =
      divup((end - begin), get_threadpool()->get_thread_count());
  // Make sure each task is at least grain_size size.
  chunk_size = std::max(grain_size, chunk_size);
  int64_t num_tasks = divup((end - begin), chunk_size);
  return std::make_tuple(num_tasks, chunk_size);
}

void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const std::function<void(int64_t, int64_t)>& f) {
  ET_CHECK_MSG(begin >= 0 && end >= 0, "Begin and end should be non-negative");
  ET_CHECK_MSG(end >= begin, "end should be greater than or equal to begin");
  ET_CHECK_MSG(grain_size > 0, "grain_size should be positive");
  int64_t num_tasks = 0, chunk_size = 0;
  std::tie(num_tasks, chunk_size) =
      calc_num_tasks_and_chunk_size(begin, end, grain_size);

  auto task = [f, begin, end, chunk_size](size_t task_id) {
    int64_t local_start = begin + static_cast<int64_t>(task_id) * chunk_size;
    if (local_start < end) {
      int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
      f(local_start, local_end);
    }
  };

  // Per protocol from threadpool (pthreadpool), when this returns, all tasks
  // are executed, so this is synchronous.
  get_threadpool()->run(task, num_tasks);
}

} // namespace torch::executor
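
To make the chunking arithmetic in calc_num_tasks_and_chunk_size concrete, here is a small standalone sketch that re-implements the same math purely for illustration; the 4-thread pool size is an assumption, since in the real code it comes from get_threadpool()->get_thread_count().

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Re-implementation of the helper above, for illustration only.
static int64_t divup(int64_t x, int64_t y) {
  return (x + y - 1) / y;
}

int main() {
  // Assumed thread count; the library queries the threadpool for this.
  const int64_t threads = 4;
  const int64_t begin = 0, end = 100, grain_size = 8;

  // Same logic as the non-degenerate branch of calc_num_tasks_and_chunk_size.
  int64_t chunk_size = divup(end - begin, threads); // divup(100, 4) == 25
  chunk_size = std::max(grain_size, chunk_size); // max(8, 25) == 25
  const int64_t num_tasks = divup(end - begin, chunk_size); // divup(100, 25) == 4

  // Each task covers [begin + i * chunk_size, min(end, begin + (i + 1) * chunk_size)).
  for (int64_t i = 0; i < num_tasks; ++i) {
    const int64_t local_start = begin + i * chunk_size;
    const int64_t local_end = std::min(end, local_start + chunk_size);
    std::printf(
        "task %lld: [%lld, %lld)\n",
        (long long)i,
        (long long)local_start,
        (long long)local_end);
  }
  return 0;
}

With begin=0, end=100, grain_size=8 and 4 threads, chunk_size becomes max(8, 25) = 25 and num_tasks becomes 4, so the tasks cover [0, 25), [25, 50), [50, 75), [75, 100).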

extension/parallel/thread_parallel.h

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>
// @nolint PATTERNLINT Ok to use stdlib for this optional library
#include <functional>

namespace torch::executor {

/**
 * A helper to run a function in parallel.
 *
 * begin, end: describe the extent of the work items via the first and
 * one-past-the-last work item to be processed
 * grain_size: number of work items processed per chunk by the user callback
 * described below
 * f: user function applied in parallel to the chunks, signature:
 *   void f(int64_t begin, int64_t end)
 *
 * Warning: parallel_for does NOT copy thread-local state from the current
 * thread to the worker threads. Users need to protect access to captured
 * data if they mutate it in f.
 */
void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const std::function<void(int64_t, int64_t)>& f);

} // namespace torch::executor
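
The warning above about captured data is worth a concrete sketch: the [begin, end) slices handed to f are disjoint, but a shared accumulator still needs synchronization because different slices may run on different worker threads. The example below is illustrative only (the function name, element type, and grain size are made up); it mirrors the mutex pattern used by RunExclusiveTask in the tests.

#include <cstdint>
#include <mutex>
#include <vector>

#include <executorch/extension/parallel/thread_parallel.h>

// Illustration only: sum a vector in parallel. Each task accumulates a local
// sum over its own disjoint [begin, end) slice; only the final addition into
// the shared total needs the lock.
int64_t parallel_sum(const std::vector<int64_t>& in) {
  std::mutex m;
  int64_t total = 0;
  torch::executor::parallel_for(
      0,
      static_cast<int64_t>(in.size()),
      /*grain_size=*/32,
      [&](int64_t begin, int64_t end) {
        int64_t local = 0;
        for (int64_t i = begin; i < end; ++i) {
          local += in[i];
        }
        std::lock_guard<std::mutex> lock(m);
        total += local;
      });
  return total;
}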
