Commit 0f9e913

[MLIR][NVVM] Add TMA Bulk Copy Ops (#123186)
PR #122344 adds intrinsics for Bulk Async Copy (non-tensor variants) using TMA. This patch adds the corresponding NVVM Dialect Ops. lit tests are added to verify the lowering to all variants of the intrinsics.

Signed-off-by: Durgadoss R <[email protected]>
1 parent a588e20 commit 0f9e913

2 files changed, +179 -0 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 144 additions & 0 deletions
@@ -2209,6 +2209,150 @@ def NVVM_CpAsyncBulkTensorReduceOp :
  }];
}

def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
  NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
  let summary = "Async bulk copy from global memory to Shared cluster memory";
  let description = [{
    Initiates an asynchronous copy operation from global memory to cluster's
    shared memory.

    The `multicastMask` operand is optional. When it is present, the Op copies
    data from global memory to shared memory of multiple CTAs in the cluster.
    Operand `multicastMask` specifies the destination CTAs in the cluster such
    that each bit position in the 16-bit `multicastMask` operand corresponds to
    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.

    The `l2CacheHint` operand is optional, and it is used to specify cache
    eviction policy that may be used during the memory access.
    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
  }];

  let arguments = (ins
    LLVM_PointerShared:$dstMem,
    LLVM_PointerGlobal:$srcMem,
    LLVM_PointerShared:$mbar,
    I32:$size,
    Optional<I16>:$multicastMask,
    Optional<I64>:$l2CacheHint);

  let assemblyFormat = [{
    $dstMem `,` $srcMem `,` $mbar `,` $size
    (`multicast_mask` `=` $multicastMask^ )?
    (`l2_cache_hint` `=` $l2CacheHint^ )?
    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];

  string llvmBuilder = [{
    // Arguments to the intrinsic:
    // dst, mbar, src, size
    // multicast_mask, cache_hint,
    // flag for multicast_mask,
    // flag for cache_hint
    llvm::SmallVector<llvm::Value *> translatedOperands;
    translatedOperands.push_back($dstMem);
    translatedOperands.push_back($mbar);
    translatedOperands.push_back($srcMem);
    translatedOperands.push_back($size);

    // Multicast, if available
    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
    auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
    bool isMulticast = op.getMulticastMask() ? true : false;
    translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);

    // Cachehint, if available
    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
    bool isCacheHint = op.getL2CacheHint() ? true : false;
    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);

    // Flag arguments for multicast and cachehint
    translatedOperands.push_back(builder.getInt1(isMulticast));
    translatedOperands.push_back(builder.getInt1(isCacheHint));

    createIntrinsicCall(builder,
        llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands);
  }];
}

def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
  NVVM_Op<"cp.async.bulk.shared.cluster.shared.cta"> {
  let summary = "Async bulk copy from Shared CTA memory to Shared cluster memory";
  let description = [{
    Initiates an asynchronous copy operation from Shared CTA memory to Shared
    cluster memory.

    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
  }];

  let arguments = (ins
    LLVM_PointerShared:$dstMem,
    LLVM_PointerShared:$srcMem,
    LLVM_PointerShared:$mbar,
    I32:$size);

  let assemblyFormat = [{
    $dstMem `,` $srcMem `,` $mbar `,` $size
    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];

  string llvmBuilder = [{
    createIntrinsicCall(builder,
        llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster,
        {$dstMem, $mbar, $srcMem, $size});
  }];
}

def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
  NVVM_Op<"cp.async.bulk.global.shared.cta"> {
  let summary = "Async bulk copy from Shared CTA memory to Global memory";
  let description = [{
    Initiates an asynchronous copy operation from Shared CTA memory to
    global memory.

    The `l2CacheHint` operand is optional, and it is used to specify cache
    eviction policy that may be used during the memory access.
    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
  }];

  let arguments = (ins
    LLVM_PointerGlobal:$dstMem,
    LLVM_PointerShared:$srcMem,
    I32:$size,
    Optional<I64>:$l2CacheHint);

  let assemblyFormat = [{
    $dstMem `,` $srcMem `,` $size
    (`l2_cache_hint` `=` $l2CacheHint^ )?
    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];

  string llvmBuilder = [{
    // Arguments to the intrinsic:
    // dst, src, size, cache_hint,
    // Flag for cache_hint
    //
    llvm::SmallVector<llvm::Value *> translatedOperands;
    translatedOperands.push_back($dstMem);
    translatedOperands.push_back($srcMem);
    translatedOperands.push_back($size);

    // Cachehint, if available
    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
    bool isCacheHint = op.getL2CacheHint() ? true : false;
    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);

    // Flag argument for cachehint
    translatedOperands.push_back(builder.getInt1(isCacheHint));

    createIntrinsicCall(builder,
        llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
  }];
}

//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
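For orientation, the sketch below is not part of the commit: it illustrates how the optional operands of `nvvm.cp.async.bulk.shared.cluster.global` map onto the intrinsic call emitted by the `llvmBuilder` above, with the SSA value names (%dst, %src, %mbar, %size, %mc, %ch) assumed to have the same types as in the lit test further down. Note the operand reordering: the op lists (dst, src, mbar, size) while the intrinsic takes (dst, mbar, src, size), and absent optional operands become zero placeholders with their i1 flags set to false.

// Sketch only (not in the commit). Assumes %dst : !llvm.ptr<3>, %src : !llvm.ptr<1>,
// %mbar : !llvm.ptr<3>, %size : i32, %mc : i16 and %ch : i64 are already defined.

// Both optional operands present: the values are forwarded and both i1 flags are true.
nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mc l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
// lowers to:
//   call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(
//     ptr addrspace(3) %dst, ptr addrspace(3) %mbar, ptr addrspace(1) %src,
//     i32 %size, i16 %mc, i64 %ch, i1 true, i1 true)

// Both optional operands absent: unused zero placeholders are passed and both flags are false.
nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
// lowers to:
//   call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(
//     ptr addrspace(3) %dst, ptr addrspace(3) %mbar, ptr addrspace(1) %src,
//     i32 %size, i16 0, i64 0, i1 false, i1 false)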
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %mc : i16, %ch : i64) {
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i16 0, i64 0, i1 false, i1 false)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 0, i64 %[[CH:.*]], i1 false, i1 true)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC:.*]], i64 0, i1 true, i1 false)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC]], i64 %[[CH]], i1 true, i1 true)
  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>

  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>

  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mc : !llvm.ptr<3>, !llvm.ptr<1>

  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mc l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
  llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<3>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(3) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
  nvvm.cp.async.bulk.shared.cluster.shared.cta %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<3>
  llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>

  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
  llvm.return
}
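As a worked example of the `multicastMask` semantics documented above (a hypothetical sketch, not part of the commit; the constant and the reuse of the test's value names are assumptions): each bit position of the 16-bit mask corresponds to the `ctaid` of a destination CTA, so a mask value of 11 (binary 1011) would multicast the copy to the CTAs with ids 0, 1 and 3 in the cluster.

// Hypothetical sketch: multicast the bulk copy to CTAs 0, 1 and 3 of the cluster.
// %dst, %src, %mbar and %size are assumed to be defined as in the test above.
%mask = llvm.mlir.constant(11 : i16) : i16  // 0b1011 -> bit positions 0, 1 and 3
nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mask : !llvm.ptr<3>, !llvm.ptr<1>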
