[mlir][nvvm] Improve cp.async.bulk.tensor.shared.cluster.global for multicast #72429


Merged: 2 commits into llvm:main on Nov 16, 2023

Conversation

@grypp (Member) commented Nov 15, 2023

This PR improves the cp.async.bulk.tensor.shared.cluster.global Op in the NVVM dialect to support multicast. When the multicast parameter is present, the Op uses TMA to load data from global memory into the shared memory of multiple CTAs in the cluster.

It resolves #72368

@llvmbot (Member) commented Nov 15, 2023

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

Changes

This PR introduces the cp.async.bulk.tensor.shared.cluster.global.multicast Op in the NVVM dialect. It uses TMA to load data from global memory into the shared memory of multiple CTAs in the cluster.

It resolves #72368


Full diff: https://github.com/llvm/llvm-project/pull/72429.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+37)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+5)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+45)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ffe6f25fcd944b6..c4d61492083bfc9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
+  Arguments<(ins  LLVM_PointerShared:$dstMem,
+                  LLVM_AnyPointer:$tmaDescriptor,
+                  LLVM_PointerShared:$mbar,
+                  I16:$multicastMask,
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    $multicastMask `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int dim = getCoordinates().size();
+      std::string ptx = "cp.async.bulk.tensor.";
+      ptx += std::to_string(dim) + "d.";
+      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+      if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+      if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+      if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+      if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+      if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3736978505707e3..1c4e2dc98bda602 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -80,6 +80,11 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
     return emitError("Maximum 5 coordinates and dimension is supported.");
   return success();
 }
+LogicalResult CpAsyncBulkTensorGlobalToSharedMulticastClusterOp::verify() {
+  if (getCoordinates().size() > 5)
+    return emitError("Maximum 5 coordinates and dimension is supported.");
+  return success();
+}
 
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   if (getCoordinates().size() > 5)
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index b907a86ebc48072..7160a612f25d14b 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -130,6 +130,51 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
   return
 }
 
+// CHECK-LABEL: @tma_load_multicast1d
+func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32,i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast2d
+func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast3d
+func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast4d
+func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast5d
+func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32, i1
+  return
+}
+
 // CHECK-LABEL: @tma_store_1d
 func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"

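Both the getPtx() string and the register-constraint strings checked in the tests follow one positional pattern: operand 0 is the destination, 1 the TMA descriptor, 2 the mbarrier, 3 the multicast mask, then one slot per coordinate, with an optional predicate last. A minimal Python sketch of that pattern (helper names are hypothetical, not part of the PR):

```python
def build_ptx(dim: int) -> str:
    # Mirrors getPtx() in the diff above: %0=dstMem, %1=tmaDescriptor,
    # %2=mbar, %3=multicastMask, %4..%(3+dim)=coordinates.
    assert 1 <= dim <= 5, "the verifier allows at most 5 coordinates"
    coords = ", ".join(f"%{4 + i}" for i in range(dim))
    return (f"cp.async.bulk.tensor.{dim}d."
            "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
            f" [%0], [%1, {{{coords}}} ], [%2], %3;")

def constraint_string(dim: int, predicated: bool) -> str:
    # "r" shared-memory pointer, "l" 64-bit descriptor pointer, "r" mbarrier,
    # "h" the i16 multicast mask, one "r" per i32 coordinate, "b" the i1 predicate.
    parts = ["r", "l", "r", "h"] + ["r"] * dim + (["b"] if predicated else [])
    return ",".join(parts)
```

For example, build_ptx(2) reproduces the 2d string that tma_load_multicast2d checks for (the inline-asm lowering then renames %N placeholders to $N), and constraint_string(2, True) yields "r,l,r,h,r,r,b".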
@llvmbot (Member) commented Nov 15, 2023

@llvm/pr-subscribers-mlir



A Collaborator commented:
Do we have to make it another op? Can't the existing op below be extended to have the multi-cast mark as an attribute and generate the right PTX in this case?

A Contributor commented:
+1 for extending the existing Op.

I am not sure we can have the mask as an attribute (since it may not always be a compile-time constant).

However, we can use the same Op with the mask as an optional operand. That way, if we have the mask available, we generate the multicast variant (but use the existing one otherwise).
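For intuition on the 16-bit mask operand being discussed: each bit selects one CTA in the cluster as a destination of the multicast load. A minimal sketch, assuming the PTX convention that bit i corresponds to the CTA with cluster rank i (the helper name is hypothetical):

```python
def multicast_targets(mask: int) -> list[int]:
    # Bit i of the 16-bit multicast mask selects the CTA whose rank in
    # the cluster is i as a destination shared memory for the TMA load.
    assert 0 <= mask < (1 << 16), "mask is an i16"
    return [rank for rank in range(16) if mask & (1 << rank)]
```

So a mask of 0b0101 would multicast the loaded tile into the shared memory of CTAs 0 and 2.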

grypp (Member, Author) replied:
I was about to raise this idea on the PR to gather preferences. Using the same Op is perfectly fine for this.

We've divided the original PTX instruction into two parts:

  • load (nvvm.cp.async.bulk.tensor.shared.cluster.global)
  • store (nvvm.cp.async.bulk.tensor.global.shared.cta)

I believe separating these two makes sense since they address different concerns.

I am not sure if we can have the mask as an attribute (since it may not be a compile-time constant always).
However, we can use the same Op with the mask as an optional operand. That way, if we have the mask available, we generate the multicast variant (but use the existing one otherwise).

I think this is the way I will go.

@grypp (Member, Author) commented Nov 16, 2023:
I have updated the PR to use the existing Op.

Yet, when we include other traits such as l2 cache hint and im2col, the Op will grow. Personally I find it consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.

For example the current op is below:

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
box [%crd0,%crd1,%crd2,%crd3] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32

with multicast_mask

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32

with multicast_mask + l2_cache_hint

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32

with multicast_mask + l2_cache_hint + im2col

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2] 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16

Same as above with predicate

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
box [%crd0,%crd1,%crd2,%crd3], 
predicate = %p 
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3], 
predicate = %p  
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3], 
predicate = %p  
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
multicast_mask = %multicastMask, l2_cache_hint = %cache, 
box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2], 
predicate = %p  
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16, i1
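The trailing type lists in the examples above grow in a fixed order as optional operands are added. A small sketch reconstructing that order (the helper is illustrative only, not part of the dialect):

```python
def operand_types(dim: int, multicast: bool = False, cache_hint: bool = False,
                  im2col: int = 0, predicate: bool = False) -> str:
    # Rebuilds the `: ...` type list from the examples: three pointers,
    # then i16 mask, i64 cache hint, i32 coordinates, i16 im2col offsets,
    # and the i1 predicate, each present only when its operand is.
    types = ["!llvm.ptr<3>", "!llvm.ptr", "!llvm.ptr<3>"]
    if multicast:
        types.append("i16")
    if cache_hint:
        types.append("i64")
    types += ["i32"] * dim
    types += ["i16"] * im2col
    if predicate:
        types.append("i1")
    return ", ".join(types)
```

For instance, operand_types(4, multicast=True, cache_hint=True, im2col=2, predicate=True) reproduces the longest type list shown above.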

A Contributor commented:
I have updated the PR to use the existing Op.

The updated version looks good to me.

Yet, when we include other traits such as l2 cache hint and im2col, the Op will grow. Personally I find it consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.

I do not see any concerns. We can extend it the same way for cache-hint.

I believe im2col itself will be a variadic operand (since it can be of size 1, 2, or 3). So, as long as we can have an operand that is both variadic and optional, we are good with this direction.


A Collaborator commented:
A shorter mnemonic that captures the operation, combined with attributes, seems more MLIR-like to me :)

Likely more friendly for the user to create as well!

@durga4github (Contributor) left a comment:
I am not sure if I have permission to click "approve" here.

This is +1/good-to-go from my side.

@grypp grypp merged commit 108380d into llvm:main Nov 16, 2023
@grypp grypp changed the title [mlir][nvvm] Add cp.async.bulk.tensor.shared.cluster.global.multicast [mlir][nvvm] Improve cp.async.bulk.tensor.shared.cluster.global for multicast Nov 16, 2023
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
…t` (llvm#72429)

This PR introduce `cp.async.bulk.tensor.shared.cluster.global.multicast`
Op in NVVM dialect. It loads data using TMA data from global memory to
shared memory of multiple CTAs in the cluster.

It resolves llvm#72368
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…t` (llvm#72429)

This PR introduce `cp.async.bulk.tensor.shared.cluster.global.multicast`
Op in NVVM dialect. It loads data using TMA data from global memory to
shared memory of multiple CTAs in the cluster.

It resolves llvm#72368

Successfully merging this pull request may close these issues.

NVVM Dialect NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp is missing the multicast operand
4 participants