-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][NVVM] Add ops for vote all and any sync #134309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-llvm Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesAdd operations for Full diff: https://github.com/llvm/llvm-project/pull/134309.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..4a549d02dc281 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -819,6 +819,26 @@ def NVVM_VoteBallotOp :
let hasCustomAssemblyFormat = 1;
}
+def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+ string llvmBuilder = [{
+ $res = createIntrinsicCall(builder,
+ llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
+def NVVM_VoteAnySyncOp : NVVM_Op<"vote.any.sync">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+ string llvmBuilder = [{
+ $res = createIntrinsicCall(builder,
+ llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
def NVVM_SyncWarpOp :
NVVM_Op<"bar.warp.sync">,
Arguments<(ins LLVM_Type:$mask)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..8ef74fcef90e8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -58,8 +58,7 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << " : " << op->getResultTypes();
}
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseVoteOps(OpAsmParser &parser, OperationState &result) {
MLIRContext *context = parser.getContext();
auto int32Ty = IntegerType::get(context, 32);
auto int1Ty = IntegerType::get(context, 1);
@@ -74,8 +73,27 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(), result.operands));
}
+// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
+ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseVoteOps(parser, result);
+}
+
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+// <operation> ::= `llvm.nvvm.vote.all.sync %mask, %pred` : result_type
+ParseResult VoteAllSyncOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseVoteOps(parser, result);
+}
+
+void VoteAllSyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+
+// <operation> ::= `llvm.nvvm.vote.any.sync %mask, %pred` : result_type
+ParseResult VoteAnySyncOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseVoteOps(parser, result);
+}
+
+void VoteAnySyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 18bf39424f0bf..9eec62ff67561 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -131,6 +131,10 @@ func.func @nvvm_shfl_pred(
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
+ // CHECK: nvvm.vote.all.sync %{{.*}}, %{{.*}} : i32
+ %1 = nvvm.vote.all.sync %arg0, %arg1 : i32
+ // CHECK: nvvm.vote.any.sync %{{.*}}, %{{.*}} : i32
+ %2 = nvvm.vote.any.sync %arg0, %arg1 : i32
llvm.return %0 : i32
}
|
@@ -819,6 +819,26 @@ def NVVM_VoteBallotOp : | |||
let hasCustomAssemblyFormat = 1; | |||
} | |||
|
|||
def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can use the existing op nvvm.vote.ballot.sync
and extend its mode. so it's close to ptx.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. I can update it this way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a VoteSyncKind and also support for uni
version of the op.
2114068
to
1ade5b5
Compare
The overall change looks good. Thank you for extending the Op! |
@clementval @durga4github @grypp With this change I am seeing the following compilation warning for which I created a small fix in #134600:
|
…VoteSyncIntrinsicId` (#134600) Fixes the following warning after the changes in llvm/llvm-project#134309: ``` llvm-project/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp:134:3: warning: default label in switch which covers all enumeration values [-Wcovered-switch-default] default: ^ 1 warning generated. ```
Add operations for
nvvm.vote.all.sync
andnvvm.vote.any.sync
intrinsics similar tonvvm.vote.ballot.sync
.