Skip to content

Commit aa4547f

Browse files
authored
[MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics (#78900)
This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics. * Doc updated for the commit_group Op. * Tests are added to verify the lowering to the intrinsics. While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test. Signed-off-by: Durgadoss R <[email protected]>
1 parent 12c241b commit aa4547f

File tree

3 files changed

+47
-25
lines changed

3 files changed

+47
-25
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
15911591
// NVVM TMA Ops
15921592
//===----------------------------------------------------------------------===//
15931593

1594-
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
1594+
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
15951595
Arguments<(ins )> {
15961596
let assemblyFormat = "attr-dict";
1597-
let extraClassDefinition = [{
1598-
std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
1597+
let description = [{
1598+
This Op commits all prior initiated but uncommitted cp.async.bulk
1599+
instructions into a cp.async.bulk-group.
1600+
1601+
[For more information, see PTX ISA]
1602+
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
1603+
}];
1604+
1605+
string llvmBuilder = [{
1606+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
15991607
}];
16001608
}
16011609

1602-
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
1610+
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
16031611
Arguments<(ins
16041612
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
1605-
OptionalAttr<UnitAttr>:$read)>
1606-
{
1613+
OptionalAttr<UnitAttr>:$read)> {
16071614
let assemblyFormat = "$group attr-dict";
16081615
let description = [{
16091616
Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
16201627
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
16211628
}];
16221629

1623-
let extraClassDefinition = [{
1624-
std::string $cppClass::getPtx() {
1625-
auto ptx = std::string("cp.async.bulk.wait_group");
1626-
if(getRead()) ptx += ".read";
1627-
ptx += " %0;"; return ptx; }
1630+
string llvmBuilder = [{
1631+
auto intId = op.getRead() ?
1632+
llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
1633+
llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
1634+
createIntrinsicCall(builder, intId, builder.getInt32($group));
16281635
}];
16291636
}
16301637

1631-
16321638
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
16331639
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
16341640
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -638,23 +638,19 @@ func.func @set_max_register() {
638638

639639
// -----
640640

641-
func.func @cp_bulk_commit() {
642-
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
641+
func.func @cp_async_bulk_commit() {
642+
// CHECK: nvvm.cp.async.bulk.commit.group
643643
nvvm.cp.async.bulk.commit.group
644644
func.return
645645
}
646646

647647
// -----
648648

649-
func.func @cp_bulk_wait_group() {
650-
// CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
651-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
652-
// CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
653-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
654-
// CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
655-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
656-
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
657-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
649+
func.func @cp_async_bulk_wait_group() {
650+
// CHECK: nvvm.cp.async.bulk.wait_group 1
651+
// CHECK: nvvm.cp.async.bulk.wait_group 0
652+
// CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
653+
// CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
658654
nvvm.cp.async.bulk.wait_group 1
659655
nvvm.cp.async.bulk.wait_group 0
660656
nvvm.cp.async.bulk.wait_group 5 {read}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
398398

399399
// CHECK-LABEL: @llvm_nvvm_setmaxregister
400400
llvm.func @llvm_nvvm_setmaxregister() {
401-
// CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
401+
// CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
402402
nvvm.setmaxregister increase 256
403-
// CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
403+
// CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
404404
nvvm.setmaxregister decrease 24
405405
llvm.return
406406
}
407407

408+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
409+
llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
410+
// CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
411+
nvvm.cp.async.bulk.commit.group
412+
llvm.return
413+
}
414+
415+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
416+
llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
417+
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
418+
nvvm.cp.async.bulk.wait_group 0
419+
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
420+
nvvm.cp.async.bulk.wait_group 3
421+
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
422+
nvvm.cp.async.bulk.wait_group 0 {read}
423+
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
424+
nvvm.cp.async.bulk.wait_group 3 {read}
425+
llvm.return
426+
}
427+
408428
// CHECK-LABEL: @ld_matrix
409429
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
410430
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})

0 commit comments

Comments
 (0)