Skip to content

Commit ae5d639

Browse files
authored
[mlir][nvvm] Introduce cp.async.bulk.wait_group (#77917)
1 parent 59d6f03 commit ae5d639

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,6 +1547,35 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.gro
15471547
}];
15481548
}
15491549

1550+
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
1551+
Arguments<(ins
1552+
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
1553+
OptionalAttr<UnitAttr>:$read)>
1554+
{
1555+
let assemblyFormat = "$group attr-dict";
1556+
let description = [{
1557+
Op waits for completion of the most recent bulk async-groups.
1558+
1559+
The `$group` attribute specifies that the op waits until `$group` or fewer
1560+
of the most recent bulk async-groups are pending. If `$group` is 0, the op
1561+
waits until all the most recent bulk async-groups have completed.
1562+
1563+
The optional `$read` attribute indicates that the waiting has to be done until all the bulk
1564+
async operations in the specified bulk async-group have completed reading
1565+
from their source locations.
1566+
1567+
[For more information, see PTX ISA]
1568+
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
1569+
}];
1570+
1571+
let extraClassDefinition = [{
1572+
std::string $cppClass::getPtx() {
1573+
auto ptx = std::string("cp.async.bulk.wait_group");
1574+
if(getRead()) ptx += ".read";
1575+
ptx += " %0;"; return ptx; }
1576+
}];
1577+
}
1578+
15501579

15511580
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
15521581
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,23 @@ func.func @cp_bulk_commit() {
644644
func.return
645645
}
646646

647+
// -----
648+
649+
func.func @cp_bulk_wait_group() {
650+
// CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
651+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
652+
// CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
653+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
654+
// CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
655+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
656+
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
657+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
658+
nvvm.cp.async.bulk.wait_group 1
659+
nvvm.cp.async.bulk.wait_group 0
660+
nvvm.cp.async.bulk.wait_group 5 {read}
661+
nvvm.cp.async.bulk.wait_group 0 {read}
662+
func.return
663+
}
647664

648665
// -----
649666

0 commit comments

Comments
 (0)