Skip to content

Commit 0680e5c

Browse files
authored
[LIBCLC][PTX] Add group_ballot intrinsic (#5020)
This patch implements the `group_ballot` intrinsic for NVIDIA, it is currently only implemented for subgroups.
1 parent d3ab145 commit 0680e5c

File tree

4 files changed

+56
-0
lines changed

4 files changed

+56
-0
lines changed

libclc/ptx-nvidiacl/libspirv/SOURCES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ images/image_helpers.ll
8585
images/image.cl
8686
group/collectives_helpers.ll
8787
group/collectives.cl
88+
group/group_ballot.cl
8889
atomic/atomic_add.cl
8990
atomic/atomic_and.cl
9091
atomic/atomic_cmpxchg.cl

libclc/ptx-nvidiacl/libspirv/group/collectives.cl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "membermask.h"
10+
911
#include <spirv/spirv.h>
1012
#include <spirv/spirv_types.h>
1113

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "membermask.h"
10+
11+
#include <spirv/spirv.h>
12+
#include <spirv/spirv_types.h>
13+
14+
_CLC_DEF _CLC_CONVERGENT __clc_vec4_uint32_t
15+
_Z29__spirv_GroupNonUniformBallotjb(unsigned flag, bool predicate) {
16+
// only support subgroup for now
17+
if (flag != Subgroup) {
18+
__builtin_trap();
19+
__builtin_unreachable();
20+
}
21+
22+
// prepare result, we only support the ballot operation on 32 threads maximum
23+
// so we only need the first element to represent the final mask
24+
__clc_vec4_uint32_t res;
25+
res[1] = 0;
26+
res[2] = 0;
27+
res[3] = 0;
28+
29+
// compute thread mask
30+
unsigned threads = __clc__membermask();
31+
32+
// run the ballot operation
33+
res[0] = __nvvm_vote_ballot_sync(threads, predicate);
34+
35+
return res;
36+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PTX_NVIDIACL_MEMBERMASK_H
10+
#define PTX_NVIDIACL_MEMBERMASK_H
11+
12+
#include <spirv/spirv.h>
13+
#include <spirv/spirv_types.h>
14+
15+
_CLC_DEF _CLC_CONVERGENT uint __clc__membermask();
16+
17+
#endif

0 commit comments

Comments
 (0)