[mlir][nvvm] Improve cp.async.bulk.tensor.shared.cluster.global for multicast #72429
Conversation
This PR introduces the `cp.async.bulk.tensor.shared.cluster.global.multicast` Op in the NVVM dialect. It uses TMA to load data from global memory into the shared memory of multiple CTAs in the cluster. It resolves llvm#72368.
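As a minimal usage sketch (the SSA names are illustrative rather than taken from the patch; per the PTX `multicast::cluster` semantics, each bit of the 16-bit mask selects the CTA with that id in the cluster):

// Load a 2-D tile through TMA into shared memory, multicasting it to the
// CTAs selected by %mask; completion is signaled on the mbarrier %barrier.
nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast
    %dest, %tmaDescriptor, %barrier, %mask, box [%crd0, %crd1]
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32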
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Full diff: https://github.com/llvm/llvm-project/pull/72429.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ffe6f25fcd944b6..c4d61492083bfc9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp :
+ NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>,
+ Arguments<(ins LLVM_PointerShared:$dstMem,
+ LLVM_AnyPointer:$tmaDescriptor,
+ LLVM_PointerShared:$mbar,
+ I16:$multicastMask,
+ Variadic<I32>:$coordinates,
+ PtxPredicate:$predicate)> {
+ let assemblyFormat = [{
+ $dstMem `,`
+ $tmaDescriptor `,`
+ $mbar `,`
+ $multicastMask `,`
+ `box` `[`$coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type(operands)
+ }];
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ int dim = getCoordinates().size();
+ std::string ptx = "cp.async.bulk.tensor.";
+ ptx += std::to_string(dim) + "d.";
+ ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+ if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+ if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+ if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+ if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+ if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+ return ptx;
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3736978505707e3..1c4e2dc98bda602 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -80,6 +80,11 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
return emitError("Maximum 5 coordinates and dimension is supported.");
return success();
}
+LogicalResult CpAsyncBulkTensorGlobalToSharedMulticastClusterOp::verify() {
+ if (getCoordinates().size() > 5)
+ return emitError("Maximum 5 coordinates and dimension is supported.");
+ return success();
+}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
if (getCoordinates().size() > 5)
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index b907a86ebc48072..7160a612f25d14b 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -130,6 +130,51 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
return
}
+// CHECK-LABEL: @tma_load_multicast1d
+func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast2d
+func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast3d
+func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast4d
+func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast5d
+func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32, i1
+ return
+}
+
// CHECK-LABEL: @tma_store_1d
func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
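The CHECK lines above also document how BasicPtxBuilder numbers the inline-asm operands: they are assigned in declaration order, with the predicate last. A sketch of the mapping for the predicated 2-D case (inferred from the tests; not spelled out in the patch itself):

// Generated asm: "@$6 cp.async.bulk.tensor.2d.shared::cluster.global
//   .mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;"
// with constraints "r,l,r,h,r,r,b":
//   $0 <- %dest            (r: 32-bit register, shared-memory address)
//   $1 <- %tmaDescriptor   (l: 64-bit register, global pointer)
//   $2 <- %barrier         (r: mbarrier address in shared memory)
//   $3 <- %multicastMask   (h: 16-bit register)
//   $4, $5 <- %crd0, %crd1 (r: tensor coordinates)
//   $6 <- %p               (b: predicate register guarding the instruction)
nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast
    %dest, %tmaDescriptor, %barrier, %multicastMask,
    box [%crd0, %crd1], predicate = %p
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i1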
}];
let hasVerifier = 1;
}
Do we have to make it another op? Can't the existing op below be extended to take the multicast mask as an attribute and generate the right PTX in that case?
+1 for extending the existing Op.
I am not sure we can have the mask as an attribute (since it may not always be a compile-time constant).
However, we can use the same Op with the mask as an optional operand. That way, if the mask is available, we generate the multicast variant (and use the existing one otherwise).
I was about to add this idea to the PR to gather preferences. Using the same Op is perfectly fine for this.

We've divided the original PTX instruction into two parts (a sketch of both directions follows below):
- load (`nvvm.cp.async.bulk.tensor.shared.cluster.global`)
- store (`nvvm.cp.async.bulk.tensor.global.shared.cta`)

I believe separating these two makes sense, since they address different concerns.

> I am not sure if we can have the mask as an attribute (since it may not always be a compile-time constant).
> However, we can use the same Op with the mask as an optional operand. That way, if we have the mask available, we generate the multicast variant (but use the existing one otherwise).

I think this is the way I will go.
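For concreteness, a sketch of the two directions in their 1-D forms (operand names illustrative; the load form matches the tests in this PR, and the store form follows the dialect's assembly format):

// Load: global -> shared::cluster through TMA; completion is observed on an
// mbarrier (mbarrier::complete_tx::bytes in the generated PTX).
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
// Store: shared::cta -> global through TMA; completion is tracked by the
// bulk async-group (bulk_group in the generated PTX).
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src,
    box [%crd0] : !llvm.ptr, !llvm.ptr<3>, i32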
I have updated the PR to use the existing Op.
Yet, when we include other traits such as the l2 cache hint and im2col, the Op will grow. Personally, I find this consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.

For example, the current op is below:

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    box [%crd0,%crd1,%crd2,%crd3]
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32

With multicast_mask:

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3]
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32

With multicast_mask + l2_cache_hint:

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    multicast_mask = %multicastMask, l2_cache_hint = %cache,
    box [%crd0,%crd1,%crd2,%crd3]
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32

With multicast_mask + l2_cache_hint + im2col:

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    multicast_mask = %multicastMask, l2_cache_hint = %cache,
    box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2]
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16

Same as above, with predicate:

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    box [%crd0,%crd1,%crd2,%crd3],
    predicate = %p
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3],
    predicate = %p
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    multicast_mask = %multicastMask, l2_cache_hint = %cache,
    box [%crd0,%crd1,%crd2,%crd3],
    predicate = %p
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    multicast_mask = %multicastMask, l2_cache_hint = %cache,
    box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2],
    predicate = %p
    : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16, i1
> I have updated the PR to use the existing Op.

The updated version looks good to me.

> Yet, when we include other traits such as l2 cache hint and im2col, the Op will grow. Personally I find it consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.

I do not see any concerns. We can extend it the same way for the cache hint.
I believe im2col itself will be a variadic operand (since it can be of size 1, 2, or 3). So, as long as we can have an operand that is both variadic and optional, we are good with this direction.
A shorter mnemonic that captures the operation, and then using attributes, seems more like MLIR to me :)
Likely more friendly for the user to create as well!
I am not sure if I have permission to click "approve" here.
This is +1/good-to-go from my side.
This PR improves the `cp.async.bulk.tensor.shared.cluster.global` Op in the NVVM dialect to leverage multicast. When the multicast mask is present, the Op uses TMA to load data from global memory into the shared memory of multiple CTAs in the cluster. It resolves #72368.