
Commit dcee6ef

grypp authored and sr-tream committed
[mlir][nvvm] Add cp.async.bulk.tensor.shared.cluster.global.multicast (llvm#72429)
This PR introduces the `cp.async.bulk.tensor.shared.cluster.global.multicast` Op in the NVVM dialect. The Op uses TMA to load data from global memory into the shared memory of multiple CTAs in a cluster. It resolves llvm#72368.
1 parent 474456d commit dcee6ef
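For context, a minimal sketch of the op's two printed forms after this change, mirroring the new 2d tests below (value names are illustrative):

  // Single-CTA form (unchanged).
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0, %crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32

  // Multicast form: bit i of the 16-bit mask targets the CTA whose ctaid is i.
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0, %crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32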

File tree

3 files changed (+77 / -7 lines)


mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 31 additions & 6 deletions

@@ -1405,13 +1405,29 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
   Arguments<(ins LLVM_PointerShared:$dstMem,
                  LLVM_AnyPointer:$tmaDescriptor,
                  LLVM_PointerShared:$mbar,
+                 Optional<I16>:$multicastMask,
                  Variadic<I32>:$coordinates,
                  PtxPredicate:$predicate)> {
+  let description = [{
+    Initiates an asynchronous copy operation on the tensor data from global
+    memory to shared memory.
+
+    The `multicastMask` operand is optional. When it is present, the Op copies
+    data from global memory to shared memory of multiple CTAs in the cluster.
+    Operand `multicastMask` specifies the destination CTAs in the cluster such
+    that each bit position in the 16-bit `multicastMask` operand corresponds to
+    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
+  }];
+
   let assemblyFormat = [{
     $dstMem `,`
     $tmaDescriptor `,`
     $mbar `,`
-    `box` `[`$coordinates `]`
+    ( `multicast_mask` `=` $multicastMask^ `,` )?
+    `box` `[`$coordinates `]`
     (`,` `predicate` `=` $predicate^)?
     attr-dict `:` type(operands)
   }];

@@ -1422,11 +1438,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
       std::string ptx = "cp.async.bulk.tensor.";
      ptx += std::to_string(dim) + "d.";
      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
-     if(dim == 1) ptx += " [%0], [%1, {%3} ], [%2];";
-     if(dim == 2) ptx += " [%0], [%1, {%3, %4} ], [%2];";
-     if(dim == 3) ptx += " [%0], [%1, {%3, %4, %5} ], [%2];";
-     if(dim == 4) ptx += " [%0], [%1, {%3, %4, %5, %6} ], [%2];";
-     if(dim == 5) ptx += " [%0], [%1, {%3, %4, %5, %6, %7} ], [%2];";
+     if(getMulticastMask()) {
+       ptx += ".multicast::cluster";
+       if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+       if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+       if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+       if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+       if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+     } else {
+       if(dim == 1) ptx += " [%0], [%1, {%3} ], [%2];";
+       if(dim == 2) ptx += " [%0], [%1, {%3, %4} ], [%2];";
+       if(dim == 3) ptx += " [%0], [%1, {%3, %4, %5} ], [%2];";
+       if(dim == 4) ptx += " [%0], [%1, {%3, %4, %5, %6} ], [%2];";
+       if(dim == 5) ptx += " [%0], [%1, {%3, %4, %5, %6, %7} ], [%2];";
+     }
      return ptx;
     }
   }];

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 1 addition & 1 deletion

@@ -975,7 +975,7 @@ struct NVGPUTmaAsyncLoadOpLowering
     }

     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
-        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords,
+        op, dest, adaptor.getTensorMapDescriptor(), barrier, Value(), coords,
         adaptor.getPredicate());
     return success();
   }
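The nvgpu lowering never multicasts, so it passes a default-constructed Value() for the new optional operand; ODS-generated builders treat a null Value as an absent optional operand. The op created here therefore keeps its pre-existing printed form, e.g. (a minimal sketch with illustrative value names):

  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %desc, %barrier, box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32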

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 45 additions & 0 deletions

@@ -130,6 +130,51 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
   return
 }

+// CHECK-LABEL: @tma_load_multicast1d
+func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast2d
+func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast3d
+func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast4d
+func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast5d
+func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32, i1
+  return
+}
+
 // CHECK-LABEL: @tma_store_1d
 func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
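These FileCheck tests exercise the NVVM-to-LLVM conversion. Assuming the file's usual RUN line (which is outside this hunk), they can be reproduced with an invocation along these lines:

  mlir-opt nvvm-to-llvm.mlir --convert-nvvm-to-llvm | FileCheck nvvm-to-llvm.mlir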
