[MLIR][NVVM] Add TMA Bulk Copy Ops #123186

Merged 1 commit on Jan 21, 2025
144 changes: 144 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2138,6 +2138,150 @@ def NVVM_CpAsyncBulkTensorReduceOp :
}];
}

def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
let summary = "Async bulk copy from global memory to Shared cluster memory";
let description = [{
Initiates an asynchronous copy operation from global memory to the cluster's
shared memory.

The `multicastMask` operand is optional. When present, the Op copies data
from global memory to the shared memory of multiple CTAs in the cluster.
Each bit position in the 16-bit `multicastMask` operand corresponds to the
`nvvm.read.ptx.sreg.ctaid` of a destination CTA in the cluster.

The `l2CacheHint` operand is optional; it specifies a cache eviction policy
that may be applied during the memory access.
[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
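
For illustration, a use of this Op (mirroring the accompanying lit test;
SSA value names are placeholders) might look like:

```mlir
nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size
    multicast_mask = %mc l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
```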
}];

let arguments = (ins
LLVM_PointerShared:$dstMem,
LLVM_PointerGlobal:$srcMem,
LLVM_PointerShared:$mbar,
I32:$size,
Optional<I16>:$multicastMask,
Optional<I64>:$l2CacheHint);

let assemblyFormat = [{
$dstMem `,` $srcMem `,` $mbar `,` $size
(`multicast_mask` `=` $multicastMask^ )?
(`l2_cache_hint` `=` $l2CacheHint^ )?
attr-dict `:` type($dstMem) `,` type($srcMem)
}];

string llvmBuilder = [{
// Arguments to the intrinsic:
// dst, mbar, src, size
// multicast_mask, cache_hint,
// flag for multicast_mask,
// flag for cache_hint
llvm::SmallVector<llvm::Value *> translatedOperands;
translatedOperands.push_back($dstMem);
translatedOperands.push_back($mbar);
translatedOperands.push_back($srcMem);
translatedOperands.push_back($size);

// Multicast, if available
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
bool isMulticast = op.getMulticastMask() ? true : false;
translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);

// Cachehint, if available
auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
bool isCacheHint = op.getL2CacheHint() ? true : false;
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);

// Flag arguments for multicast and cachehint
translatedOperands.push_back(builder.getInt1(isMulticast));
translatedOperands.push_back(builder.getInt1(isCacheHint));

createIntrinsicCall(builder,
llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands);
}];
}

def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
NVVM_Op<"cp.async.bulk.shared.cluster.shared.cta"> {
let summary = "Async bulk copy from Shared CTA memory to Shared cluster memory";
let description = [{
Initiates an asynchronous copy operation from Shared CTA memory to Shared
cluster memory.

[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
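
Example (adapted from the accompanying lit test; SSA value names are
placeholders):

```mlir
nvvm.cp.async.bulk.shared.cluster.shared.cta %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<3>
```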
}];

let arguments = (ins
LLVM_PointerShared:$dstMem,
LLVM_PointerShared:$srcMem,
LLVM_PointerShared:$mbar,
I32:$size);

let assemblyFormat = [{
$dstMem `,` $srcMem `,` $mbar `,` $size
attr-dict `:` type($dstMem) `,` type($srcMem)
}];

string llvmBuilder = [{
createIntrinsicCall(builder,
llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster,
{$dstMem, $mbar, $srcMem, $size});
}];
}

def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
NVVM_Op<"cp.async.bulk.global.shared.cta"> {
let summary = "Async bulk copy from Shared CTA memory to Global memory";
let description = [{
Initiates an asynchronous copy operation from Shared CTA memory to
global memory.

The `l2CacheHint` operand is optional; it specifies a cache eviction policy
that may be applied during the memory access.
[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
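
For illustration, a use of this Op with the optional cache hint (adapted
from the accompanying lit test; SSA value names are placeholders):

```mlir
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
```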
}];

let arguments = (ins
LLVM_PointerGlobal:$dstMem,
LLVM_PointerShared:$srcMem,
I32:$size,
Optional<I64>:$l2CacheHint);

let assemblyFormat = [{
$dstMem `,` $srcMem `,` $size
(`l2_cache_hint` `=` $l2CacheHint^ )?
attr-dict `:` type($dstMem) `,` type($srcMem)
}];

string llvmBuilder = [{
// Arguments to the intrinsic:
// dst, src, size, cache_hint,
// Flag for cache_hint
//
llvm::SmallVector<llvm::Value *> translatedOperands;
translatedOperands.push_back($dstMem);
translatedOperands.push_back($srcMem);
translatedOperands.push_back($size);

// Cachehint, if available
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
bool isCacheHint = op.getL2CacheHint() ? true : false;
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);

// Flag argument for cachehint
translatedOperands.push_back(builder.getInt1(isCacheHint));

createIntrinsicCall(builder,
llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
}];
}

//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
35 changes: 35 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -0,0 +1,35 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %mc : i16, %ch : i64) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i16 0, i64 0, i1 false, i1 false)
// CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 0, i64 %[[CH:.*]], i1 false, i1 true)
// CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC:.*]], i64 0, i1 true, i1 false)
// CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC]], i64 %[[CH]], i1 true, i1 true)
nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>

nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>

nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mc : !llvm.ptr<3>, !llvm.ptr<1>

nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mc l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<3>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(3) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
nvvm.cp.async.bulk.shared.cluster.shared.cta %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<3>
llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>

nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
llvm.return
}