Skip to content

[MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics #78900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2024

Conversation

durga4github
Copy link
Contributor

This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.

  • Doc updated for the commit_group Op.
  • Tests are added to verify the lowering to the intrinsics.

While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test.

This patch updates the cp.async.bulk.{commit/wait}_group
Ops to use NVVM intrinsics.
* Doc updated for the commit_group Op.
* Tests are added to verify the lowering to the intrinsics.

While we are there, fix the FileCheck directive on the
'nvvm.setmaxregister' test.

Signed-off-by: Durgadoss R <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

Changes

This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.

  • Doc updated for the commit_group Op.
  • Tests are added to verify the lowering to the intrinsics.

While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test.


Full diff: https://github.com/llvm/llvm-project/pull/78900.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+18-12)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+7-11)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+22-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..3916896382163ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
+def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
+  let description = [{
+    This Op commits all prior initiated but uncommitted cp.async.bulk
+    instructions into a cp.async.bulk-group.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
+  }];
+
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
   }];
 }
 
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
   Arguments<(ins 
     ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group, 
-    OptionalAttr<UnitAttr>:$read)> 
-{
+    OptionalAttr<UnitAttr>:$read)> {
   let assemblyFormat = "$group attr-dict";
   let description = [{
     Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
   }];
   
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { 
-      auto ptx = std::string("cp.async.bulk.wait_group");
-      if(getRead()) ptx += ".read";
-      ptx += " %0;"; return ptx; }
+  string llvmBuilder = [{
+    auto intId = op.getRead() ?
+      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
+      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
+    createIntrinsicCall(builder, intId, builder.getInt32($group));
   }];
 }
 
-
 def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218a..40131af6826487a 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -638,23 +638,19 @@ func.func @set_max_register() {
 
 // -----
 
-func.func @cp_bulk_commit() {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
+func.func @cp_async_bulk_commit() {
+  // CHECK: nvvm.cp.async.bulk.commit.group
   nvvm.cp.async.bulk.commit.group
   func.return
 }
 
 // -----
 
-func.func @cp_bulk_wait_group() {
-  // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
-  // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
-  // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
-  // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
+func.func @cp_async_bulk_wait_group() {
+  // CHECK: nvvm.cp.async.bulk.wait_group 1
+  // CHECK: nvvm.cp.async.bulk.wait_group 0
+  // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+  // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
   nvvm.cp.async.bulk.wait_group 1
   nvvm.cp.async.bulk.wait_group 0
   nvvm.cp.async.bulk.wait_group 5 {read}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f68..49f9426daabc21b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
 
 // CHECK-LABEL: @llvm_nvvm_setmaxregister
 llvm.func @llvm_nvvm_setmaxregister() {
-  // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+  // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
   nvvm.setmaxregister increase 256
-  // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+  // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
   nvvm.setmaxregister decrease 24
   llvm.return
 }
 
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
+llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
+  nvvm.cp.async.bulk.commit.group
+  llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
+llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
+  nvvm.cp.async.bulk.wait_group 0
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
+  nvvm.cp.async.bulk.wait_group 3
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
+  nvvm.cp.async.bulk.wait_group 0 {read}
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
+  nvvm.cp.async.bulk.wait_group 3 {read}
+  llvm.return
+}
+
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})

@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2024

@llvm/pr-subscribers-mlir

Author: Durgadoss R (durga4github)

Changes

This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.

  • Doc updated for the commit_group Op.
  • Tests are added to verify the lowering to the intrinsics.

While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test.


Full diff: https://github.com/llvm/llvm-project/pull/78900.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+18-12)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+7-11)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+22-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..3916896382163e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
+def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
+  let description = [{
+    This Op commits all prior initiated but uncommitted cp.async.bulk
+    instructions into a cp.async.bulk-group.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
+  }];
+
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
   }];
 }
 
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
   Arguments<(ins 
     ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group, 
-    OptionalAttr<UnitAttr>:$read)> 
-{
+    OptionalAttr<UnitAttr>:$read)> {
   let assemblyFormat = "$group attr-dict";
   let description = [{
     Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
   }];
   
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { 
-      auto ptx = std::string("cp.async.bulk.wait_group");
-      if(getRead()) ptx += ".read";
-      ptx += " %0;"; return ptx; }
+  string llvmBuilder = [{
+    auto intId = op.getRead() ?
+      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
+      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
+    createIntrinsicCall(builder, intId, builder.getInt32($group));
   }];
 }
 
-
 def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218..40131af6826487 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -638,23 +638,19 @@ func.func @set_max_register() {
 
 // -----
 
-func.func @cp_bulk_commit() {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
+func.func @cp_async_bulk_commit() {
+  // CHECK: nvvm.cp.async.bulk.commit.group
   nvvm.cp.async.bulk.commit.group
   func.return
 }
 
 // -----
 
-func.func @cp_bulk_wait_group() {
-  // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
-  // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
-  // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
-  // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
+func.func @cp_async_bulk_wait_group() {
+  // CHECK: nvvm.cp.async.bulk.wait_group 1
+  // CHECK: nvvm.cp.async.bulk.wait_group 0
+  // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+  // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
   nvvm.cp.async.bulk.wait_group 1
   nvvm.cp.async.bulk.wait_group 0
   nvvm.cp.async.bulk.wait_group 5 {read}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f6..49f9426daabc21 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
 
 // CHECK-LABEL: @llvm_nvvm_setmaxregister
 llvm.func @llvm_nvvm_setmaxregister() {
-  // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+  // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
   nvvm.setmaxregister increase 256
-  // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+  // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
   nvvm.setmaxregister decrease 24
   llvm.return
 }
 
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
+llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
+  nvvm.cp.async.bulk.commit.group
+  llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
+llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
+  nvvm.cp.async.bulk.wait_group 0
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
+  nvvm.cp.async.bulk.wait_group 3
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
+  nvvm.cp.async.bulk.wait_group 0 {read}
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
+  nvvm.cp.async.bulk.wait_group 3 {read}
+  llvm.return
+}
+
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})

@durga4github
Copy link
Contributor Author

@grypp , Please help review.

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks

@grypp grypp merged commit aa4547f into llvm:main Jan 22, 2024
@durga4github durga4github deleted the durgadossr/mlir_cp_async_bulk_commit branch January 22, 2024 07:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants