Skip to content

Commit 7293dca

Browse files
committed
[SYCL][COMPAT] match_*_over_sub_group tests. Fixed match_all to match the documented behavior.
1 parent 32bc5bb commit 7293dca

File tree

3 files changed

+198
-1
lines changed

3 files changed

+198
-1
lines changed

sycl/include/syclcompat/util.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ unsigned int match_all_over_sub_group(sycl::sub_group g, unsigned member_mask,
372372
sycl::plus<>());
373373
bool all_equal = (reduce_result == member_mask);
374374
*pred = is_participate & all_equal;
375-
return all_equal * member_mask;
375+
return (is_participate & all_equal) * member_mask;
376376
}
377377

378378
namespace experimental {
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/***************************************************************************
2+
*
3+
* Copyright (C) Codeplay Software Ltd.
4+
*
5+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
6+
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*
15+
* SYCLcompat API
16+
*
17+
* util_match_all_over_group.cpp
18+
*
19+
* Description:
20+
* Tests for syclcompat::match_all_over_sub_group
21+
**************************************************************************/
22+
23+
// The original source was under the license below:
24+
// ====------ UtilSelectFromSubGroup.cpp---------- -*- C++ -* ----===////
25+
//
26+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27+
// See https://llvm.org/LICENSE.txt for license information.
28+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
29+
//
30+
//
31+
// ===----------------------------------------------------------------------===//
32+
33+
// RUN: %clangxx -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out
34+
// RUN: %{run} %t.out
35+
36+
#include <sycl/sycl.hpp>
37+
#include <syclcompat.hpp>
38+
39+
// Number of independent test rows packed into the input/expected tables.
constexpr unsigned int NUM_TESTS = 3;
40+
// Sub-group size requested for the kernel; also the length of one test row.
constexpr unsigned int SUBGROUP_SIZE = 16;
41+
// Total number of elements across all test rows.
constexpr unsigned int DATA_SIZE = NUM_TESTS * SUBGROUP_SIZE;
42+
43+
// Exercises syclcompat::match_all_over_sub_group on NUM_TESTS rows of
// SUBGROUP_SIZE values each, then compares both the returned masks and the
// written predicates against hard-coded tables.
// NOTE(review): the function name mentions select_from_sub_group, but the
// kernel below calls match_all_over_sub_group.
void test_select_from_sub_group() {
44+
// NOTE(review): std::cout and assert are used in this function while the file
// only includes <sycl/sycl.hpp> and <syclcompat.hpp>; presumably they are
// pulled in transitively -- confirm.
std::cout << __PRETTY_FUNCTION__ << std::endl;
45+
46+
// NOTE(review): `grid` is not referenced again in this function.
constexpr syclcompat::dim3 grid{1};
47+
constexpr syclcompat::dim3 threads{SUBGROUP_SIZE};
48+
49+
// Host input: one row of SUBGROUP_SIZE values per test case (#1..#3).
unsigned int input[DATA_SIZE] = {
50+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, // #1
51+
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, // #2
52+
0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}; // #3
53+
// Host-side result buffers, filled from the device after the kernel finishes.
unsigned int output[DATA_SIZE];
54+
int pred[DATA_SIZE];
55+
// Device-side buffers for the input values, returned masks and predicates.
unsigned int *d_input = syclcompat::malloc<unsigned int>(DATA_SIZE);
56+
unsigned int *d_output = syclcompat::malloc<unsigned int>(DATA_SIZE);
57+
int *d_pred = syclcompat::malloc<int>(DATA_SIZE);
58+
59+
// Low 8 bits set; presumably selects the first 8 sub-group lanes as
// participants -- this matches the tables below, which are non-zero only in the
// first 8 slots of a row.
unsigned int member_mask = 0x00FF;
60+
// Expected return values: member_mask for the first 8 elements of rows #1 and
// #2 (their first 8 inputs are all equal), 0 for every other element. Row #3
// mixes 0 and 2 within its first 8 inputs, so it expects 0 throughout.
unsigned int expected[DATA_SIZE] = {
61+
0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF,
62+
0, 0, 0, 0, 0, 0, 0, 0, // #1
63+
0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF,
64+
0, 0, 0, 0, 0, 0, 0, 0, // #2
65+
0, 0, 0, 0, 0, 0, 0, 0,
66+
0, 0, 0, 0, 0, 0, 0, 0, // #3
67+
};
68+
// Expected predicate values: 1 exactly where `expected` is non-zero.
unsigned int expected_pred[DATA_SIZE] = {
69+
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, // #1
70+
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, // #2
71+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // #3
72+
};
73+
74+
// Upload the input and pre-fill the result buffers.
syclcompat::memcpy<unsigned int>(d_input, input, DATA_SIZE);
75+
syclcompat::memset(d_output, 0, DATA_SIZE * sizeof(unsigned int));
76+
// Pre-fills d_pred with a non-zero pattern (presumably byte-wise, like
// std::memset -- confirm) so that a predicate the kernel never writes would not
// accidentally equal an expected 0.
syclcompat::memset(d_pred, 1, DATA_SIZE * sizeof(int));
77+
78+
sycl::queue q = syclcompat::get_default_queue();
79+
// Global and local range are both threads.size(), i.e. a single work-group.
// Each work-item starts at its global linear id and advances by SUBGROUP_SIZE,
// so every loop iteration handles one element of the next test row.
q.parallel_for(
80+
sycl::nd_range<1>(threads.size(), threads.size()),
81+
[=](sycl::nd_item<1> item) [[intel::reqd_sub_group_size(SUBGROUP_SIZE)]] {
82+
for (auto id = item.get_global_linear_id(); id < DATA_SIZE;
83+
id += SUBGROUP_SIZE)
84+
d_output[id] = syclcompat::match_all_over_sub_group(
85+
item.get_sub_group(), member_mask, d_input[id], &d_pred[id]);
86+
});
87+
// Wait for the kernel before reading the results back to the host.
q.wait_and_throw();
88+
syclcompat::memcpy<unsigned int>(output, d_output, DATA_SIZE);
89+
syclcompat::memcpy<int>(pred, d_pred, DATA_SIZE);
90+
91+
// Element-wise verification of masks and predicates.
// NOTE(review): `i` is int while DATA_SIZE is unsigned, and pred[i] (int) is
// compared with expected_pred[i] (unsigned int); both comparisons mix
// signedness.
for (int i = 0; i < DATA_SIZE; ++i) {
92+
assert(output[i] == expected[i]);
93+
assert(pred[i] == expected_pred[i]);
94+
}
95+
96+
// Release the device allocations.
syclcompat::free(d_input);
97+
syclcompat::free(d_output);
98+
syclcompat::free(d_pred);
99+
}
100+
101+
// Entry point: runs the single test function and returns 0. The test reports
// mismatches through assert rather than through the exit code.
int main() {
102+
test_select_from_sub_group();
103+
104+
return 0;
105+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/***************************************************************************
2+
*
3+
* Copyright (C) Codeplay Software Ltd.
4+
*
5+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
6+
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*
15+
* SYCLcompat API
16+
*
17+
* util_match_any_over_group.cpp
18+
*
19+
* Description:
20+
* Tests for syclcompat::match_any_over_sub_group
21+
**************************************************************************/
22+
23+
// The original source was under the license below:
24+
// ====------ UtilSelectFromSubGroup.cpp---------- -*- C++ -* ----===////
25+
//
26+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27+
// See https://llvm.org/LICENSE.txt for license information.
28+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
29+
//
30+
//
31+
// ===----------------------------------------------------------------------===//
32+
33+
// RUN: %clangxx -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out
34+
// RUN: %{run} %t.out
35+
36+
#include <sycl/sycl.hpp>
37+
#include <syclcompat.hpp>
38+
39+
// Total number of input/output elements; also used as the work-group size.
#define DATA_SIZE 64
40+
// Sub-group size requested for the kernel via reqd_sub_group_size.
#define SUBGROUP_SIZE 16
41+
42+
// Exercises syclcompat::match_any_over_sub_group on DATA_SIZE values and
// compares the returned masks against a hard-coded table.
// NOTE(review): the function name mentions select_from_sub_group, but the
// kernel below calls match_any_over_sub_group.
void test_select_from_sub_group() {
43+
// NOTE(review): std::cout and assert are used in this function while the file
// only includes <sycl/sycl.hpp> and <syclcompat.hpp>; presumably they are
// pulled in transitively -- confirm.
std::cout << __PRETTY_FUNCTION__ << std::endl;
44+
45+
constexpr syclcompat::dim3 grid{1};
46+
constexpr syclcompat::dim3 threads{DATA_SIZE};
47+
48+
// Host input: the 16-value pattern 0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3 repeated
// four times.
unsigned int input[DATA_SIZE] = {
49+
0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 1,
50+
1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
51+
3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
52+
// Host-side result buffer, filled from the device after the kernel finishes.
unsigned int output[DATA_SIZE];
53+
// Device-side buffers for the input values and the returned masks.
unsigned int *d_input = syclcompat::malloc<unsigned int>(DATA_SIZE);
54+
unsigned int *d_output = syclcompat::malloc<unsigned int>(DATA_SIZE);
55+
56+
// Low 12 bits set; presumably selects the first 12 sub-group lanes as
// participants -- this matches the table below, where the last 4 slots of each
// group of 16 are 0.
unsigned int member_mask = 0x0FFF;
57+
// Expected return values per group of 16 elements: 0x000F, 0x00F0 and 0x0F00
// for the three runs of four equal inputs, and 0 for the last four elements.
unsigned int expected[DATA_SIZE] = {
58+
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
59+
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
60+
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
61+
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
62+
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
63+
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
64+
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
65+
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
66+
};
67+
68+
// Upload the input; d_output is not pre-initialised because the kernel writes
// d_output[id] for every global id in the launch range.
syclcompat::memcpy<unsigned int>(d_input, input, DATA_SIZE);
69+
sycl::queue q = syclcompat::get_default_queue();
70+
// Global range is grid.size() * threads.size() and local range is
// threads.size(); each work-item handles the element at its global linear id.
q.parallel_for(
71+
sycl::nd_range<1>(grid.size() * threads.size(), threads.size()),
72+
[=](sycl::nd_item<1> item) [[intel::reqd_sub_group_size(SUBGROUP_SIZE)]] {
73+
auto id = item.get_global_linear_id();
74+
d_output[id] = syclcompat::match_any_over_sub_group(
75+
item.get_sub_group(), member_mask, d_input[id]);
76+
});
77+
// Wait for the kernel before reading the results back to the host.
q.wait_and_throw();
78+
syclcompat::memcpy<unsigned int>(output, d_output, DATA_SIZE);
79+
80+
// Element-wise verification of the returned masks.
for (int i = 0; i < DATA_SIZE; ++i) {
81+
assert(output[i] == expected[i]);
82+
}
83+
84+
// Release the device allocations.
syclcompat::free(d_input);
85+
syclcompat::free(d_output);
86+
}
87+
88+
// Entry point: runs the single test function and returns 0. The test reports
// mismatches through assert rather than through the exit code.
int main() {
89+
test_select_from_sub_group();
90+
91+
return 0;
92+
}

0 commit comments

Comments
 (0)