Skip to content

Commit cd2f85a

Browse files
authored
[mlir][NVVM] Add ops for vote all and any sync (#134309)
Add operations for `nvvm.vote.all.sync` and `nvvm.vote.any.sync` intrinsics similar to `nvvm.vote.ballot.sync`.
1 parent d62d15e commit cd2f85a

File tree

7 files changed

+87
-40
lines changed

7 files changed

+87
-40
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6616,7 +6616,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
66166616
mlir::Value arg1 =
66176617
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
66186618
return builder
6619-
.create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
6619+
.create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
6620+
mlir::NVVM::VoteSyncKind::ballot)
66206621
.getResult();
66216622
}
66226623

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ end subroutine
303303
! CHECK-LABEL: func.func @_QPtestvote()
304304
! CHECK: fir.call @llvm.nvvm.vote.all.sync
305305
! CHECK: fir.call @llvm.nvvm.vote.any.sync
306-
! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
306+
! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
307307

308308
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
309309
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -808,15 +808,49 @@ def NVVM_ShflOp :
808808
let hasVerifier = 1;
809809
}
810810

811-
def NVVM_VoteBallotOp :
812-
NVVM_Op<"vote.ballot.sync">,
813-
Results<(outs LLVM_Type:$res)>,
814-
Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
811+
def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
812+
def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
813+
def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
814+
def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;
815+
816+
def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
817+
[VoteSyncKindAny, VoteSyncKindAll,
818+
VoteSyncKindBallot, VoteSyncKindUni]> {
819+
let genSpecializedAttr = 0;
820+
let cppNamespace = "::mlir::NVVM";
821+
}
822+
823+
def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
824+
825+
def NVVM_VoteSyncOp
826+
: NVVM_Op<"vote.sync">,
827+
Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
828+
Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
829+
let summary = "Vote across thread group";
830+
let description = [{
831+
The `vote.sync` op will cause the executing thread to wait until all non-exited
832+
threads corresponding to membermask have executed `vote.sync` with the same
833+
qualifiers and same membermask value before resuming execution.
834+
835+
The vote operation kinds are:
836+
- `any`: True if source predicate is True for some thread in membermask.
837+
- `all`: True if source predicate is True for all non-exited threads in
838+
membermask.
839+
- `uni`: True if source predicate has the same value in all non-exited
840+
threads in membermask.
841+
- `ballot`: In the ballot form, the destination result is a 32-bit integer.
842+
In this form, the predicate from each thread in membermask is copied into
843+
the corresponding bit position of the result, where the bit position
844+
corresponds to the thread’s lane id.
845+
846+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync)
847+
}];
815848
string llvmBuilder = [{
816-
$res = createIntrinsicCall(builder,
817-
llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
849+
auto intId = getVoteSyncIntrinsicId($kind);
850+
$res = createIntrinsicCall(builder, intId, {$mask, $pred});
818851
}];
819-
let hasCustomAssemblyFormat = 1;
852+
let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
853+
let hasVerifier = 1;
820854
}
821855

822856
def NVVM_SyncWarpOp :

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,34 +48,6 @@ using namespace NVVM;
4848
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
4949
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
5050

51-
//===----------------------------------------------------------------------===//
52-
// Printing/parsing for NVVM ops
53-
//===----------------------------------------------------------------------===//
54-
55-
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
56-
p << " " << op->getOperands();
57-
if (op->getNumResults() > 0)
58-
p << " : " << op->getResultTypes();
59-
}
60-
61-
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
62-
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
63-
MLIRContext *context = parser.getContext();
64-
auto int32Ty = IntegerType::get(context, 32);
65-
auto int1Ty = IntegerType::get(context, 1);
66-
67-
SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
68-
Type type;
69-
return failure(parser.parseOperandList(ops) ||
70-
parser.parseOptionalAttrDict(result.attributes) ||
71-
parser.parseColonType(type) ||
72-
parser.addTypeToList(type, result.types) ||
73-
parser.resolveOperands(ops, {int32Ty, int1Ty},
74-
parser.getNameLoc(), result.operands));
75-
}
76-
77-
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
78-
7951
//===----------------------------------------------------------------------===//
8052
// Verifier methods
8153
//===----------------------------------------------------------------------===//
@@ -1160,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
11601132
return success();
11611133
}
11621134

1135+
/// Check that the result type agrees with the vote kind: the `ballot` kind
/// must produce an i32, every other kind must produce an i1.
LogicalResult NVVM::VoteSyncOp::verify() {
  const bool isBallot = getKind() == NVVM::VoteSyncKind::ballot;

  // Ballot form: the result must be a 32-bit integer.
  if (isBallot && !getType().isInteger(32))
    return emitOpError("vote.sync 'ballot' returns an i32");

  // All remaining kinds: the result must be a 1-bit integer.
  if (!isBallot && !getType().isInteger(1))
    return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");

  return success();
}
1147+
11631148
//===----------------------------------------------------------------------===//
11641149
// getIntrinsicID/getIntrinsicIDAndArgs methods
11651150
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
121121
}
122122
}
123123

124+
/// Return the LLVM intrinsic ID of the `nvvm.vote.*.sync` intrinsic that
/// corresponds to the given vote kind.
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
  // The switch lists every VoteSyncKind enumerator and deliberately has no
  // `default:` label, so the compiler can warn (-Wswitch) if a new kind is
  // added without being handled here, and -Wcovered-switch-default stays quiet.
  switch (kind) {
  case NVVM::VoteSyncKind::any:
    return llvm::Intrinsic::nvvm_vote_any_sync;
  case NVVM::VoteSyncKind::all:
    return llvm::Intrinsic::nvvm_vote_all_sync;
  case NVVM::VoteSyncKind::ballot:
    return llvm::Intrinsic::nvvm_vote_ballot_sync;
  case NVVM::VoteSyncKind::uni:
    return llvm::Intrinsic::nvvm_vote_uni_sync;
  }
  // Only reachable if `kind` holds a value outside the enumeration.
  llvm_unreachable("unsupported vote kind");
}
138+
124139
/// Return the intrinsic ID associated with ldmatrix for the given parameters.
125140
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
126141
int32_t num) {

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,14 @@ func.func @nvvm_shfl_pred(
129129

130130
// CHECK-LABEL: @nvvm_vote(
131131
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
132-
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
133-
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
132+
// CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
133+
%0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
134+
// CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
135+
%1 = nvvm.vote.sync all %arg0, %arg1 -> i1
136+
// CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
137+
%2 = nvvm.vote.sync any %arg0, %arg1 -> i1
138+
// CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
139+
%3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
134140
llvm.return %0 : i32
135141
}
136142

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred(
255255
// CHECK-LABEL: @nvvm_vote
256256
llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
257257
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
258-
%3 = nvvm.vote.ballot.sync %0, %1 : i32
258+
%3 = nvvm.vote.sync ballot %0, %1 -> i32
259+
// CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
260+
%4 = nvvm.vote.sync all %0, %1 -> i1
261+
// CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
262+
%5 = nvvm.vote.sync any %0, %1 -> i1
263+
// CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
264+
%6 = nvvm.vote.sync uni %0, %1 -> i1
259265
llvm.return %3 : i32
260266
}
261267

0 commit comments

Comments
 (0)