-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics #78900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics #78900
Conversation
This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics. * Doc updated for the commit_group Op. * Tests are added to verify the lowering to the intrinsics. While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test. Signed-off-by: Durgadoss R <[email protected]>
@llvm/pr-subscribers-mlir-llvm Author: Durgadoss R (durga4github) ChangesThis patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.
While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test. Full diff: https://github.com/llvm/llvm-project/pull/78900.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..3916896382163ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
-def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
+def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
+ let description = [{
+ This Op commits all prior initiated but uncommitted cp.async.bulk
+ instructions into a cp.async.bulk-group.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
+ }];
+
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
}];
}
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
- OptionalAttr<UnitAttr>:$read)>
-{
+ OptionalAttr<UnitAttr>:$read)> {
let assemblyFormat = "$group attr-dict";
let description = [{
Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- auto ptx = std::string("cp.async.bulk.wait_group");
- if(getRead()) ptx += ".read";
- ptx += " %0;"; return ptx; }
+ string llvmBuilder = [{
+ auto intId = op.getRead() ?
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
+ createIntrinsicCall(builder, intId, builder.getInt32($group));
}];
}
-
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218a..40131af6826487a 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -638,23 +638,19 @@ func.func @set_max_register() {
// -----
-func.func @cp_bulk_commit() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
+func.func @cp_async_bulk_commit() {
+ // CHECK: nvvm.cp.async.bulk.commit.group
nvvm.cp.async.bulk.commit.group
func.return
}
// -----
-func.func @cp_bulk_wait_group() {
- // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
- // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
- // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
- // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
+func.func @cp_async_bulk_wait_group() {
+ // CHECK: nvvm.cp.async.bulk.wait_group 1
+ // CHECK: nvvm.cp.async.bulk.wait_group 0
+ // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+ // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
nvvm.cp.async.bulk.wait_group 1
nvvm.cp.async.bulk.wait_group 0
nvvm.cp.async.bulk.wait_group 5 {read}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f68..49f9426daabc21b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
nvvm.setmaxregister increase 256
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
nvvm.setmaxregister decrease 24
llvm.return
}
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
+llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
+ nvvm.cp.async.bulk.commit.group
+ llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
+llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
+ nvvm.cp.async.bulk.wait_group 0
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
+ nvvm.cp.async.bulk.wait_group 3
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
+ nvvm.cp.async.bulk.wait_group 0 {read}
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
+ nvvm.cp.async.bulk.wait_group 3 {read}
+ llvm.return
+}
+
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
|
@llvm/pr-subscribers-mlir Author: Durgadoss R (durga4github) ChangesThis patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.
While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test. Full diff: https://github.com/llvm/llvm-project/pull/78900.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..3916896382163e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
-def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
+def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
+ let description = [{
+ This Op commits all prior initiated but uncommitted cp.async.bulk
+ instructions into a cp.async.bulk-group.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
+ }];
+
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
}];
}
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
- OptionalAttr<UnitAttr>:$read)>
-{
+ OptionalAttr<UnitAttr>:$read)> {
let assemblyFormat = "$group attr-dict";
let description = [{
Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- auto ptx = std::string("cp.async.bulk.wait_group");
- if(getRead()) ptx += ".read";
- ptx += " %0;"; return ptx; }
+ string llvmBuilder = [{
+ auto intId = op.getRead() ?
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
+ createIntrinsicCall(builder, intId, builder.getInt32($group));
}];
}
-
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218..40131af6826487 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -638,23 +638,19 @@ func.func @set_max_register() {
// -----
-func.func @cp_bulk_commit() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
+func.func @cp_async_bulk_commit() {
+ // CHECK: nvvm.cp.async.bulk.commit.group
nvvm.cp.async.bulk.commit.group
func.return
}
// -----
-func.func @cp_bulk_wait_group() {
- // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
- // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
- // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
- // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
+func.func @cp_async_bulk_wait_group() {
+ // CHECK: nvvm.cp.async.bulk.wait_group 1
+ // CHECK: nvvm.cp.async.bulk.wait_group 0
+ // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+ // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
nvvm.cp.async.bulk.wait_group 1
nvvm.cp.async.bulk.wait_group 0
nvvm.cp.async.bulk.wait_group 5 {read}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f6..49f9426daabc21 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
nvvm.setmaxregister increase 256
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
nvvm.setmaxregister decrease 24
llvm.return
}
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
+llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
+ nvvm.cp.async.bulk.commit.group
+ llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
+llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
+ nvvm.cp.async.bulk.wait_group 0
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
+ nvvm.cp.async.bulk.wait_group 3
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
+ nvvm.cp.async.bulk.wait_group 0 {read}
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
+ nvvm.cp.async.bulk.wait_group 3 {read}
+ llvm.return
+}
+
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
|
@grypp , Please help review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks
This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.
While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test.