Skip to content

Commit 84f7241

Browse files
committed
[MLIR][NVVM] Add TMA Bulk Copy Ops
PR llvm#122344 adds intrinsics for Bulk Async Copy (non-tensor variants) using TMA. This patch adds the corresponding NVVM Dialect Ops.

Signed-off-by: Durgadoss R <[email protected]>
1 parent c25bd6e commit 84f7241

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,6 +2138,150 @@ def NVVM_CpAsyncBulkTensorReduceOp :
21382138
}];
21392139
}
21402140

2141+
def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
  NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
  let summary = "Async bulk copy from global memory to Shared cluster memory";
  let description = [{
    Initiates an asynchronous copy operation from global memory to cluster's
    shared memory.

    The `multicastMask` operand is optional. When it is present, the Op copies
    data from global memory to shared memory of multiple CTAs in the cluster.
    Operand `multicastMask` specifies the destination CTAs in the cluster such
    that each bit position in the 16-bit `multicastMask` operand corresponds to
    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.

    The `l2CacheHint` operand is optional, and it is used to specify cache
    eviction policy that may be used during the memory access.
    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
  }];

  let arguments = (ins
    LLVM_PointerShared:$dstMem,
    LLVM_PointerGlobal:$srcMem,
    LLVM_PointerShared:$mbar,
    I32:$size,
    Optional<I16>:$multicastMask,
    Optional<I64>:$l2CacheHint);

  let assemblyFormat = [{
    $dstMem `,` $srcMem `,` $mbar `,` $size
    (`multicast_mask` `=` $multicastMask^ )?
    (`l2_cache_hint` `=` $l2CacheHint^ )?
    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];

  string llvmBuilder = [{
    // Intrinsic operand order (differs from the Op's operand order):
    //   dst, mbar, src, size,
    //   multicast_mask, cache_hint,
    //   i1 flag for multicast_mask,
    //   i1 flag for cache_hint
    llvm::SmallVector<llvm::Value *> translatedOperands;
    translatedOperands.push_back($dstMem);
    translatedOperands.push_back($mbar);
    translatedOperands.push_back($srcMem);
    translatedOperands.push_back($size);

    // Multicast mask, if available. A zero placeholder is passed otherwise;
    // the trailing i1 flag tells the intrinsic whether the value is valid.
    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
    auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
    bool isMulticast = static_cast<bool>(op.getMulticastMask());
    translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);

    // Cache hint, if available; zero placeholder otherwise (see above).
    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
    bool isCacheHint = static_cast<bool>(op.getL2CacheHint());
    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);

    // Flag arguments for multicast and cachehint.
    translatedOperands.push_back(builder.getInt1(isMulticast));
    translatedOperands.push_back(builder.getInt1(isCacheHint));

    createIntrinsicCall(builder,
        llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster,
        translatedOperands);
  }];
}
2206+
2207+
def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
  NVVM_Op<"cp.async.bulk.shared.cluster.shared.cta"> {
  let summary = "Async bulk copy from Shared CTA memory to Shared cluster memory";
  let description = [{
    Initiates an asynchronous copy operation from Shared CTA memory to Shared
    cluster memory.

    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
  }];

  let arguments = (ins
    LLVM_PointerShared:$dstMem,
    LLVM_PointerShared:$srcMem,
    LLVM_PointerShared:$mbar,
    I32:$size);

  let assemblyFormat = [{
    $dstMem `,` $srcMem `,` $mbar `,` $size
    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];

  string llvmBuilder = [{
    // The intrinsic takes the mbarrier before the source pointer:
    // dst, mbar, src, size.
    llvm::SmallVector<llvm::Value *> args{$dstMem, $mbar, $srcMem, $size};
    createIntrinsicCall(builder,
        llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster, args);
  }];
}
2235+
2236+
def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
  NVVM_Op<"cp.async.bulk.global.shared.cta"> {
  let summary = "Async bulk copy from Shared CTA memory to Global memory";
  let description = [{
    Initiates an asynchronous copy operation from Shared CTA memory to
    global memory.

    The `l2CacheHint` operand is optional, and it is used to specify cache
    eviction policy that may be used during the memory access.
    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
  }];

  let arguments = (ins
    LLVM_PointerGlobal:$dstMem,
    LLVM_PointerShared:$srcMem,
    I32:$size,
    Optional<I64>:$l2CacheHint);

  let assemblyFormat = [{
    $dstMem `,` $srcMem `,` $size
    (`l2_cache_hint` `=` $l2CacheHint^ )?
    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];

  string llvmBuilder = [{
    // Intrinsic operand order:
    //   dst, src, size, cache_hint,
    //   i1 flag for cache_hint
    llvm::SmallVector<llvm::Value *> translatedOperands;
    translatedOperands.push_back($dstMem);
    translatedOperands.push_back($srcMem);
    translatedOperands.push_back($size);

    // Cache hint, if available. A zero placeholder is passed otherwise;
    // the trailing i1 flag tells the intrinsic whether the value is valid.
    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
    bool isCacheHint = static_cast<bool>(op.getL2CacheHint());
    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);

    // Flag argument for cachehint.
    translatedOperands.push_back(builder.getInt1(isCacheHint));

    createIntrinsicCall(builder,
        llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global,
        translatedOperands);
  }];
}
2284+
21412285
//===----------------------------------------------------------------------===//
21422286
// NVVM Wgmma Ops
21432287
//===----------------------------------------------------------------------===//
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
2+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
3+
4+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dest : !llvm.ptr<3>, %source : !llvm.ptr<1>, %barrier : !llvm.ptr<3>, %bytes : i32, %mask : i16, %hint : i64) {
  // Exercise all four combinations of the optional multicast_mask and
  // l2_cache_hint operands; the trailing i1 flags in the intrinsic call
  // record which optionals were present.
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i16 0, i64 0, i1 false, i1 false)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 0, i64 %[[CH:.*]], i1 false, i1 true)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC:.*]], i64 0, i1 true, i1 false)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC]], i64 %[[CH]], i1 true, i1 true)
  nvvm.cp.async.bulk.shared.cluster.global %dest, %source, %barrier, %bytes : !llvm.ptr<3>, !llvm.ptr<1>

  nvvm.cp.async.bulk.shared.cluster.global %dest, %source, %barrier, %bytes l2_cache_hint = %hint : !llvm.ptr<3>, !llvm.ptr<1>

  nvvm.cp.async.bulk.shared.cluster.global %dest, %source, %barrier, %bytes multicast_mask = %mask : !llvm.ptr<3>, !llvm.ptr<1>

  nvvm.cp.async.bulk.shared.cluster.global %dest, %source, %barrier, %bytes multicast_mask = %mask l2_cache_hint = %hint : !llvm.ptr<3>, !llvm.ptr<1>
  llvm.return
}
19+
20+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dest : !llvm.ptr<3>, %source : !llvm.ptr<3>, %barrier : !llvm.ptr<3>, %bytes : i32) {
  // The intrinsic takes the mbarrier (%1) before the source pointer (%2),
  // swapping the Op's operand order.
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(3) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
  nvvm.cp.async.bulk.shared.cluster.shared.cta %dest, %source, %barrier, %bytes : !llvm.ptr<3>, !llvm.ptr<3>
  llvm.return
}
26+
27+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dest : !llvm.ptr<1>, %source : !llvm.ptr<3>, %bytes : i32, %hint : i64) {
  // Both forms: without and with the optional l2_cache_hint operand; the
  // trailing i1 flag in the intrinsic call records whether it was present.
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
  nvvm.cp.async.bulk.global.shared.cta %dest, %source, %bytes : !llvm.ptr<1>, !llvm.ptr<3>

  nvvm.cp.async.bulk.global.shared.cta %dest, %source, %bytes l2_cache_hint = %hint : !llvm.ptr<1>, !llvm.ptr<3>
  llvm.return
}

0 commit comments

Comments
 (0)