Skip to content

Commit 2114068

Browse files
committed
[mlir][NVVM] Add ops for vote all and any sync
1 parent 7288f1b commit 2114068

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,26 @@ def NVVM_VoteBallotOp :
819819
let hasCustomAssemblyFormat = 1;
820820
}
821821

822+
def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">,
823+
Results<(outs LLVM_Type:$res)>,
824+
Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
825+
string llvmBuilder = [{
826+
$res = createIntrinsicCall(builder,
827+
llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
828+
}];
829+
let hasCustomAssemblyFormat = 1;
830+
}
831+
832+
def NVVM_VoteAnySyncOp : NVVM_Op<"vote.any.sync">,
833+
Results<(outs LLVM_Type:$res)>,
834+
Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
835+
string llvmBuilder = [{
836+
$res = createIntrinsicCall(builder,
837+
llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
838+
}];
839+
let hasCustomAssemblyFormat = 1;
840+
}
841+
822842
def NVVM_SyncWarpOp :
823843
NVVM_Op<"bar.warp.sync">,
824844
Arguments<(ins LLVM_Type:$mask)> {

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
5858
p << " : " << op->getResultTypes();
5959
}
6060

61-
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
62-
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
61+
static ParseResult parseVoteOps(OpAsmParser &parser, OperationState &result) {
6362
MLIRContext *context = parser.getContext();
6463
auto int32Ty = IntegerType::get(context, 32);
6564
auto int1Ty = IntegerType::get(context, 1);
@@ -74,8 +73,27 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
7473
parser.getNameLoc(), result.operands));
7574
}
7675

76+
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
77+
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
78+
return parseVoteOps(parser, result);
79+
}
80+
7781
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
7882

83+
// <operation> ::= `llvm.nvvm.vote.all.sync %mask, %pred` : result_type
84+
ParseResult VoteAllSyncOp::parse(OpAsmParser &parser, OperationState &result) {
85+
return parseVoteOps(parser, result);
86+
}
87+
88+
void VoteAllSyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
89+
90+
// <operation> ::= `llvm.nvvm.vote.any.sync %mask, %pred` : result_type
91+
ParseResult VoteAnySyncOp::parse(OpAsmParser &parser, OperationState &result) {
92+
return parseVoteOps(parser, result);
93+
}
94+
95+
void VoteAnySyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
96+
7997
//===----------------------------------------------------------------------===//
8098
// Verifier methods
8199
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ func.func @nvvm_shfl_pred(
131131
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
132132
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
133133
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
134+
// CHECK: nvvm.vote.all.sync %{{.*}}, %{{.*}} : i32
135+
%1 = nvvm.vote.all.sync %arg0, %arg1 : i32
136+
// CHECK: nvvm.vote.any.sync %{{.*}}, %{{.*}} : i32
137+
%2 = nvvm.vote.any.sync %arg0, %arg1 : i32
134138
llvm.return %0 : i32
135139
}
136140

0 commit comments

Comments
 (0)