Skip to content

[MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics #78900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//

def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
let description = [{
This Op commits all prior initiated but uncommitted cp.async.bulk
instructions into a cp.async.bulk-group.

[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
}];

string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
}];
}

def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
OptionalAttr<UnitAttr>:$read)>
{
OptionalAttr<UnitAttr>:$read)> {
let assemblyFormat = "$group attr-dict";
let description = [{
Op waits for completion of the most recent bulk async-groups.
Expand All @@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
}];

let extraClassDefinition = [{
std::string $cppClass::getPtx() {
auto ptx = std::string("cp.async.bulk.wait_group");
if(getRead()) ptx += ".read";
ptx += " %0;"; return ptx; }
string llvmBuilder = [{
auto intId = op.getRead() ?
llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
createIntrinsicCall(builder, intId, builder.getInt32($group));
}];
}


def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
Expand Down
18 changes: 7 additions & 11 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -638,23 +638,19 @@ func.func @set_max_register() {

// -----

func.func @cp_bulk_commit() {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
func.func @cp_async_bulk_commit() {
// CHECK: nvvm.cp.async.bulk.commit.group
nvvm.cp.async.bulk.commit.group
func.return
}

// -----

func.func @cp_bulk_wait_group() {
// CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
// CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
// CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
func.func @cp_async_bulk_wait_group() {
// CHECK: nvvm.cp.async.bulk.wait_group 1
// CHECK: nvvm.cp.async.bulk.wait_group 0
// CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
// CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
nvvm.cp.async.bulk.wait_group 1
nvvm.cp.async.bulk.wait_group 0
nvvm.cp.async.bulk.wait_group 5 {read}
Expand Down
24 changes: 22 additions & 2 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p

// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
// CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
// CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
nvvm.setmaxregister increase 256
// CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
// CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
nvvm.setmaxregister decrease 24
llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
// CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
nvvm.cp.async.bulk.commit.group
llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
nvvm.cp.async.bulk.wait_group 0
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
nvvm.cp.async.bulk.wait_group 3
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
nvvm.cp.async.bulk.wait_group 0 {read}
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
nvvm.cp.async.bulk.wait_group 3 {read}
llvm.return
}

// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
Expand Down