Skip to content

[MLIR][NVVM] Extend TMA Bulk Copy Op #140232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ enum NVVMMemorySpace {
kSharedClusterMemorySpace = 7,
};

/// Pairs an LLVM intrinsic ID (`first`) with the list of LLVM values
/// (`second`) that are passed as the arguments of that intrinsic call.
/// This type is returned by the getIntrinsicIDAndArgs() methods, so that a
/// caller can unpack it and forward both parts to the intrinsic-call builder.
using IDArgPair =
std::pair<llvm::Intrinsic::ID, llvm::SmallVector<llvm::Value *>>;

/// Return the element type and number of elements associated with a wmma matrix
/// of given characteristics. This matches the logic in IntrinsicsNVVM.td
/// WMMA_REGS structure.
Expand Down
66 changes: 39 additions & 27 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2599,51 +2599,63 @@ def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
}

def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
NVVM_Op<"cp.async.bulk.global.shared.cta"> {
NVVM_Op<"cp.async.bulk.global.shared.cta", [AttrSizedOperandSegments]> {
let summary = "Async bulk copy from Shared CTA memory to Global memory";
let description = [{
Initiates an asynchronous copy operation from Shared CTA memory to
global memory.
global memory. The 32-bit operand `size` specifies the amount of
memory to be copied, in terms of number of bytes. `size` must be a
multiple of 16. The `l2CacheHint` operand is optional, and it is used
to specify cache eviction policy that may be used during the memory
access. The `byteMask` operand is optional. The i-th bit in the 16-bit
wide `byteMask` specifies whether the i-th byte of each 16-byte wide
chunk of source data is copied to the destination. If the bit is set,
the byte is copied.

Example:
```mlir
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size
: !llvm.ptr<1>, !llvm.ptr<3>

// with l2_cache_hint
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch
: !llvm.ptr<1>, !llvm.ptr<3>

// with byte_mask
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask
: !llvm.ptr<1>, !llvm.ptr<3>

// with both l2_cache_hint and byte_mask
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask
: !llvm.ptr<1>, !llvm.ptr<3>
```

The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
}];

let arguments = (ins
LLVM_PointerGlobal:$dstMem,
LLVM_PointerShared:$srcMem,
I32:$size,
Optional<I64>:$l2CacheHint);
Optional<I64>:$l2CacheHint,
Optional<I16>:$byteMask);

let assemblyFormat = [{
$dstMem `,` $srcMem `,` $size
(`l2_cache_hint` `=` $l2CacheHint^ )?
attr-dict `:` type($dstMem) `,` type($srcMem)
(`byte_mask` `=` $byteMask^ )?
attr-dict `:` type($dstMem) `,` type($srcMem)
}];

let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
}];
string llvmBuilder = [{
// Arguments to the intrinsic:
// dst, src, size, cache_hint,
// Flag for cache_hint
//
llvm::SmallVector<llvm::Value *> translatedOperands;
translatedOperands.push_back($dstMem);
translatedOperands.push_back($srcMem);
translatedOperands.push_back($size);

// Cachehint, if available
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
bool isCacheHint = op.getL2CacheHint() ? true : false;
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);

// Flag argument for cachehint
translatedOperands.push_back(builder.getInt1(isCacheHint));

createIntrinsicCall(builder,
llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
auto [id, args] = NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
createIntrinsicCall(builder, id, args);
}];
}

Expand Down
28 changes: 28 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,34 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
return id;
}

/// Computes the LLVM intrinsic ID and the matching call arguments for the
/// shared-CTA-to-global bulk copy op.
///
/// The argument list is always laid out as: dst, src, size, cache-hint, and an
/// i1 flag that is true only when the op carries an `l2CacheHint`. When the
/// hint is absent, an i64 zero fills its slot. If the op has a `byteMask`, the
/// mask is appended as the last argument and the `_bytemask` variant of the
/// intrinsic is selected instead of the plain one.
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto copyOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);

  // Mandatory operands, already translated to LLVM values.
  llvm::SmallVector<llvm::Value *> operands = {
      mt.lookupValue(copyOp.getDstMem()), mt.lookupValue(copyOp.getSrcMem()),
      mt.lookupValue(copyOp.getSize())};

  // The cache-hint slot is never omitted: start from a zero placeholder and
  // replace it with the real value only when the op provides one.
  llvm::Value *hintArg =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  mlir::Value l2Hint = copyOp.getL2CacheHint();
  const bool hintPresent = static_cast<bool>(l2Hint);
  if (hintPresent)
    hintArg = mt.lookupValue(l2Hint);
  operands.push_back(hintArg);
  operands.push_back(builder.getInt1(hintPresent));

  // The optional byte mask both extends the argument list and switches the
  // intrinsic to its bytemask variant.
  mlir::Value mask = copyOp.getByteMask();
  const bool masked = static_cast<bool>(mask);
  if (masked)
    operands.push_back(mt.lookupValue(mask));

  const llvm::Intrinsic::ID id =
      masked ? llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask
             : llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
  return {id, std::move(operands)};
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
bool isIm2Col) {
switch (tensorDims) {
Expand Down
12 changes: 11 additions & 1 deletion mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,19 @@ llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>

nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64, %mask : i16) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false, i16 %[[MASK:.*]])
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true, i16 %[[MASK]])
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>

nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
llvm.return
}