Skip to content

Commit ac98c33

Browse files
authored
[SYCL][COMPAT] Fix slow math_extend tests (#14599)
The `math_extend*cpp` tests were very slow despite not doing a huge amount of work. This PR accelerates these tests with the following changes:
- Split into several different kernels
- Replace `parallel_for` with `single_task`
- Avoid branching return statements
- Avoid storing function names as `char` arrays in kernel
- Avoid `sycl::stream` for output

Locally, the `math_extend_v_4.cpp` tests are about 5x faster. The other 2 are around 2x faster.

---------

Signed-off-by: Joe Todd <[email protected]>
1 parent b29e416 commit ac98c33

File tree

3 files changed

+246
-452
lines changed

3 files changed

+246
-452
lines changed

sycl/test-e2e/syclcompat/math/math_extend.cpp

Lines changed: 68 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,10 @@
4444

4545
#define CHECK(S, REF) \
4646
{ \
47+
++test_id; \
4748
auto ret = S; \
4849
if (ret != REF) { \
49-
return {#S, REF}; \
50+
errc = test_id; \
5051
} \
5152
}
5253

@@ -56,7 +57,9 @@ const auto UINT32MAX = std::numeric_limits<uint32_t>::max();
5657
const auto UINT32MIN = std::numeric_limits<uint32_t>::min();
5758
const int b = 4, c = 5, d = 6;
5859

59-
std::pair<const char *, int> vadd() {
60+
int vadd() {
61+
int errc{};
62+
int test_id{};
6063
CHECK(syclcompat::extend_add<int32_t>(3, 4), 7);
6164
CHECK(syclcompat::extend_add<uint32_t>(b, c), 9);
6265
CHECK(syclcompat::extend_add_sat<int32_t>(b, INT32MAX), INT32MAX);
@@ -65,10 +68,12 @@ std::pair<const char *, int> vadd() {
6568
CHECK(syclcompat::extend_add_sat<int32_t>(b, c, -20, sycl::minimum<>()), -20);
6669
CHECK(syclcompat::extend_add_sat<int32_t>(b, (-33), 9, sycl::maximum<>()), 9);
6770

68-
return {nullptr, 0};
71+
return errc;
6972
}
7073

71-
std::pair<const char *, int> vsub() {
74+
int vsub() {
75+
int errc{};
76+
int test_id{};
7277
CHECK(syclcompat::extend_sub<int32_t>(3, 4), -1);
7378
CHECK(syclcompat::extend_sub<uint32_t>(c, b), 1);
7479
CHECK(syclcompat::extend_sub_sat<int32_t>(10, INT32MIN), INT32MAX);
@@ -78,10 +83,12 @@ std::pair<const char *, int> vsub() {
7883
CHECK(syclcompat::extend_sub_sat<int32_t>(b, (-33), 9, sycl::maximum<>()),
7984
37);
8085

81-
return {nullptr, 0};
86+
return errc;
8287
}
8388

84-
std::pair<const char *, int> vabsdiff() {
89+
int vabsdiff() {
90+
int errc{};
91+
int test_id{};
8592
CHECK(syclcompat::extend_absdiff<int32_t>(3, 4), 1);
8693
CHECK(syclcompat::extend_absdiff<uint32_t>(c, b), 1);
8794
CHECK(syclcompat::extend_absdiff_sat<int32_t>(10, INT32MIN), INT32MAX);
@@ -92,10 +99,12 @@ std::pair<const char *, int> vabsdiff() {
9299
CHECK(syclcompat::extend_absdiff_sat<int32_t>(b, (-33), 9, sycl::maximum<>()),
93100
37);
94101

95-
return {nullptr, 0};
102+
return errc;
96103
}
97104

98-
std::pair<const char *, int> vmin() {
105+
int vmin() {
106+
int errc{};
107+
int test_id{};
99108
CHECK(syclcompat::extend_min<int32_t>(3, 4), 3);
100109
CHECK(syclcompat::extend_min<uint32_t>(c, b), 4);
101110
CHECK(syclcompat::extend_min_sat<int32_t>(UINT32MAX, 1), 1);
@@ -104,10 +113,12 @@ std::pair<const char *, int> vmin() {
104113
CHECK(syclcompat::extend_min_sat<int32_t>(b, c, -20, sycl::minimum<>()), -20);
105114
CHECK(syclcompat::extend_min_sat<int32_t>(b, (-33), 9, sycl::maximum<>()), 9);
106115

107-
return {nullptr, 0};
116+
return errc;
108117
}
109118

110-
std::pair<const char *, int> vmax() {
119+
int vmax() {
120+
int errc{};
121+
int test_id{};
111122
CHECK(syclcompat::extend_max<int32_t>(3, 4), 4);
112123
CHECK(syclcompat::extend_max<uint32_t>(c, b), 5);
113124
CHECK(syclcompat::extend_max_sat<int32_t>(UINT32MAX, 1), INT32MAX);
@@ -116,7 +127,7 @@ std::pair<const char *, int> vmax() {
116127
CHECK(syclcompat::extend_max_sat<int32_t>(b, c, -20, sycl::minimum<>()), -20);
117128
CHECK(syclcompat::extend_max_sat<int32_t>(b, (-33), 9, sycl::maximum<>()), 9);
118129

119-
return {nullptr, 0};
130+
return errc;
120131
}
121132

122133
template <typename Tp> struct scale {
@@ -127,7 +138,9 @@ template <typename Tp> struct noop {
127138
Tp operator()(Tp val, Tp /*scaler*/) { return val; }
128139
};
129140

130-
std::pair<const char *, int> shl_clamp() {
141+
int shl_clamp() {
142+
int errc{};
143+
int test_id{};
131144
CHECK(syclcompat::extend_shl_clamp<int32_t>(3, 4), 48);
132145
CHECK(syclcompat::extend_shl_clamp<int32_t>(6, 33), 0);
133146
CHECK(syclcompat::extend_shl_clamp<int32_t>(3, 4, 4, scale<int32_t>()), 192);
@@ -139,10 +152,12 @@ std::pair<const char *, int> shl_clamp() {
139152
CHECK(syclcompat::extend_shl_sat_clamp<int8_t>(9, 5, -1, noop<int8_t>()),
140153
127);
141154

142-
return {nullptr, 0};
155+
return errc;
143156
}
144157

145-
std::pair<const char *, int> shl_wrap() {
158+
int shl_wrap() {
159+
int errc{};
160+
int test_id{};
146161
CHECK(syclcompat::extend_shl_wrap<int32_t>(3, 4), 48);
147162
CHECK(syclcompat::extend_shl_wrap<int32_t>(6, 32), 6);
148163
CHECK(syclcompat::extend_shl_wrap<int32_t>(6, 33), 12);
@@ -155,10 +170,12 @@ std::pair<const char *, int> shl_wrap() {
155170
-127);
156171
CHECK(syclcompat::extend_shl_sat_wrap<int8_t>(9, 5, -1, noop<int8_t>()), 127);
157172

158-
return {nullptr, 0};
173+
return errc;
159174
}
160175

161-
std::pair<const char *, int> shr_clamp() {
176+
int shr_clamp() {
177+
int errc{};
178+
int test_id{};
162179
CHECK(syclcompat::extend_shr_clamp<int32_t>(128, 5), 4);
163180
CHECK(syclcompat::extend_shr_clamp<int32_t>(INT32MAX, 33), 0);
164181
CHECK(syclcompat::extend_shr_clamp<int32_t>(128, 5, 4, scale<int32_t>()), 16);
@@ -170,10 +187,12 @@ std::pair<const char *, int> shr_clamp() {
170187
CHECK(syclcompat::extend_shr_sat_clamp<int8_t>(512, 1, -1, noop<int8_t>()),
171188
127);
172189

173-
return {nullptr, 0};
190+
return errc;
174191
}
175192

176-
std::pair<const char *, int> shr_wrap() {
193+
int shr_wrap() {
194+
int errc{};
195+
int test_id{};
177196
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 5), 4);
178197
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 32), 128);
179198
CHECK(syclcompat::extend_shr_wrap<int32_t>(128, 33), 64);
@@ -187,106 +206,43 @@ std::pair<const char *, int> shr_wrap() {
187206
CHECK(syclcompat::extend_shr_sat_wrap<int8_t>(512, 1, -1, noop<int8_t>()),
188207
127);
189208

190-
return {nullptr, 0};
209+
return errc;
191210
}
192211

193-
void test(const sycl::stream &s, int *ec) {
194-
{
195-
auto res = vadd();
196-
if (res.first) {
197-
s << res.first << " = " << res.second << " check failed!\n";
198-
*ec = 1;
199-
return;
200-
}
201-
s << "vadd check passed!\n";
202-
}
203-
{
204-
auto res = vsub();
205-
if (res.first) {
206-
s << res.first << " = " << res.second << " check failed!\n";
207-
*ec = 2;
208-
return;
209-
}
210-
s << "vsub check passed!\n";
211-
}
212-
{
213-
auto res = vabsdiff();
214-
if (res.first) {
215-
s << res.first << " = " << res.second << " check failed!\n";
216-
*ec = 3;
217-
return;
218-
}
219-
s << "vabsdiff check passed!\n";
220-
}
221-
{
222-
auto res = vmin();
223-
if (res.first) {
224-
s << res.first << " = " << res.second << " check failed!\n";
225-
*ec = 4;
226-
return;
227-
}
228-
s << "vmin check passed!\n";
229-
}
230-
{
231-
auto res = vmax();
232-
if (res.first) {
233-
s << res.first << " = " << res.second << " check failed!\n";
234-
*ec = 5;
235-
return;
236-
}
237-
s << "vmax check passed!\n";
238-
}
239-
{
240-
auto res = shl_clamp();
241-
if (res.first) {
242-
s << res.first << " = " << res.second << " check failed!\n";
243-
*ec = 6;
244-
return;
245-
}
246-
s << "shl_clamp check passed!\n";
247-
}
248-
{
249-
auto res = shl_wrap();
250-
if (res.first) {
251-
s << res.first << " = " << res.second << " check failed!\n";
252-
*ec = 7;
253-
return;
254-
}
255-
s << "shl_wrap check passed!\n";
256-
}
257-
{
258-
auto res = shr_clamp();
259-
if (res.first) {
260-
s << res.first << " = " << res.second << " check failed!\n";
261-
*ec = 8;
262-
return;
263-
}
264-
s << "shr_clamp check passed!\n";
265-
}
266-
{
267-
auto res = shr_wrap();
268-
if (res.first) {
269-
s << res.first << " = " << res.second << " check failed!\n";
270-
*ec = 9;
271-
return;
272-
}
273-
s << "shr_wrap check passed!\n";
274-
}
275-
*ec = 0;
276-
}
212+
template <auto F> void test_fn(sycl::queue q, int *ec) {
213+
std::cout << __PRETTY_FUNCTION__ << std::endl;
277214

278-
int main() {
279-
sycl::queue q = syclcompat::get_default_queue();
280-
int *ec = syclcompat::malloc<int>(1);
281-
syclcompat::fill<int>(ec, 0, 1);
282215
q.submit([&](sycl::handler &cgh) {
283-
sycl::stream out(1024, 256, cgh);
284-
cgh.parallel_for(1, [=](sycl::item<1> it) { test(out, ec); });
216+
cgh.single_task([=]() {
217+
auto res = F();
218+
if(res != 0) *ec = res;
219+
});
285220
});
221+
int ec_h{};
222+
syclcompat::memcpy<int>(&ec_h, ec, 1, q);
286223
q.wait_and_throw();
287224

288-
int ec_h;
289-
syclcompat::memcpy<int>(&ec_h, ec, 1);
225+
if (ec_h != 0) {
226+
std::cout << "Test " << ec_h << " failed." << std::endl;
227+
syclcompat::free(ec, q);
228+
assert(false);
229+
}
230+
}
290231

291-
return ec_h;
232+
int main() {
233+
sycl::queue q = syclcompat::get_default_queue();
234+
int *ec = syclcompat::malloc<int>(1, q);
235+
syclcompat::fill<int>(ec, 0, 1, q);
236+
237+
test_fn<vadd>(q, ec);
238+
test_fn<vsub>(q, ec);
239+
test_fn<vabsdiff>(q, ec);
240+
test_fn<vmin>(q, ec);
241+
test_fn<vmax>(q, ec);
242+
test_fn<shl_clamp>(q, ec);
243+
test_fn<shl_wrap>(q, ec);
244+
test_fn<shr_clamp>(q, ec);
245+
test_fn<shr_wrap>(q, ec);
246+
247+
syclcompat::free(ec, q);
292248
}

0 commit comments

Comments
 (0)