Skip to content

Commit 2feced1

Browse files
authored
[MLIR][NVVM] Add tcgen05 wait/fence Ops (#126265)
PR #126091 adds intrinsics for tcgen05 wait/fence/commit operations. This patch adds NVVM Dialect Ops for them. Signed-off-by: Durgadoss R <[email protected]>
1 parent 101b3ff commit 2feced1

File tree

3 files changed

+195
-0
lines changed

3 files changed

+195
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2617,6 +2617,30 @@ def Tcgen05GroupKindAttr :
26172617
let assemblyFormat = "`<` $value `>`";
26182618
}
26192619

2620+
def Tcgen05FenceBefore : I32EnumAttrCase<"BEFORE_THREAD_SYNC", 0, "before">;
2621+
def Tcgen05FenceAfter : I32EnumAttrCase<"AFTER_THREAD_SYNC", 1, "after">;
2622+
def Tcgen05FenceKind : I32EnumAttr<"Tcgen05FenceKind", "NVVM Tcgen05 fence kind",
2623+
[Tcgen05FenceBefore, Tcgen05FenceAfter]> {
2624+
let genSpecializedAttr = 0;
2625+
let cppNamespace = "::mlir::NVVM";
2626+
}
2627+
def Tcgen05FenceKindAttr :
2628+
EnumAttr<NVVM_Dialect, Tcgen05FenceKind, "tcgen05_fence"> {
2629+
let assemblyFormat = "`<` $value `>`";
2630+
}
2631+
2632+
def Tcgen05WaitLoad : I32EnumAttrCase<"LOAD", 0, "load">;
2633+
def Tcgen05WaitStore : I32EnumAttrCase<"STORE", 1, "store">;
2634+
def Tcgen05WaitKind : I32EnumAttr<"Tcgen05WaitKind", "NVVM Tcgen05 wait kind",
2635+
[Tcgen05WaitLoad, Tcgen05WaitStore]> {
2636+
let genSpecializedAttr = 0;
2637+
let cppNamespace = "::mlir::NVVM";
2638+
}
2639+
def Tcgen05WaitKindAttr :
2640+
EnumAttr<NVVM_Dialect, Tcgen05WaitKind, "tcgen05_wait"> {
2641+
let assemblyFormat = "`<` $value `>`";
2642+
}
2643+
26202644
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
26212645
let summary = "Tcgen05 alloc operation";
26222646
let description = [{
@@ -2701,6 +2725,91 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
27012725
}];
27022726
}
27032727

2728+
def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
2729+
let summary = "Tcgen05 fence operations";
2730+
let description = [{
2731+
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
2732+
with respect to the subsequent tcgen05 and execution ordering operations.
2733+
The `tcgen05.fence<after>` orders all subsequent async tcgen05 operations
2734+
with respect to the prior tcgen05 and execution ordering operations.
2735+
2736+
[For more information refer to the PTX ISA]
2737+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-fence)
2738+
}];
2739+
2740+
let arguments = (ins Tcgen05FenceKindAttr:$kind);
2741+
let assemblyFormat = "$kind attr-dict";
2742+
2743+
string llvmBuilder = [{
2744+
auto id = ($kind == NVVM::Tcgen05FenceKind::BEFORE_THREAD_SYNC)
2745+
? llvm::Intrinsic::nvvm_tcgen05_fence_before_thread_sync
2746+
: llvm::Intrinsic::nvvm_tcgen05_fence_after_thread_sync;
2747+
createIntrinsicCall(builder, id);
2748+
}];
2749+
}
2750+
2751+
def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
2752+
let summary = "Tcgen05 wait operations";
2753+
let description = [{
2754+
The `tcgen05.wait<load>` causes the executing thread to block until
2755+
all prior `tcgen05.ld` operations issued by the executing thread
2756+
have completed. Similarly, the `tcgen05.wait<store>` causes the executing
2757+
thread to block until all prior `tcgen05.st` operations issued by the
2758+
executing thread have completed.
2759+
[For more information refer PTX ISA]
2760+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-wait)
2761+
}];
2762+
2763+
let arguments = (ins Tcgen05WaitKindAttr:$kind);
2764+
let assemblyFormat = "$kind attr-dict";
2765+
2766+
string llvmBuilder = [{
2767+
auto id = ($kind == NVVM::Tcgen05WaitKind::LOAD)
2768+
? llvm::Intrinsic::nvvm_tcgen05_wait_ld
2769+
: llvm::Intrinsic::nvvm_tcgen05_wait_st;
2770+
createIntrinsicCall(builder, id);
2771+
}];
2772+
}
2773+
2774+
def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
2775+
let summary = "Tcgen05 commit operations";
2776+
let description = [{
2777+
The `tcgen05.commit` makes the mbarrier object, specified by
2778+
the operand `addr`, track the completion of all the prior
2779+
async-tcgen05 operations initiated by the executing thread.
2780+
The multicast variants allow signaling on the mbarrier objects
2781+
of multiple CTAs within the cluster. Operand `multicastMask`,
2782+
when present, specifies the destination CTAs in the cluster such
2783+
that each bit position in the 16-bit `multicastMask` operand
2784+
corresponds to the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
2785+
[For more information refer PTX ISA]
2786+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
2787+
}];
2788+
2789+
let arguments = (ins
2790+
AnyTypeOf<[LLVM_AnyPointer, LLVM_PointerShared]>:$addr,
2791+
Optional<I16>:$multicastMask,
2792+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
2793+
2794+
let assemblyFormat = [{
2795+
$addr (`,` `multicast_mask` `=` $multicastMask^)?
2796+
attr-dict `:` type(operands)
2797+
}];
2798+
2799+
let extraClassDeclaration = [{
2800+
static llvm::Intrinsic::ID
2801+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2802+
llvm::SmallVector<llvm::Value *> &args);
2803+
}];
2804+
2805+
string llvmBuilder = [{
2806+
llvm::SmallVector<llvm::Value *> args;
2807+
auto id = NVVM::Tcgen05CommitOp::getIntrinsicIDAndArgs(
2808+
*op, moduleTranslation, args);
2809+
createIntrinsicCall(builder, id, args);
2810+
}];
2811+
}
2812+
27042813
//===----------------------------------------------------------------------===//
27052814
// NVVM target attribute.
27062815
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,6 +1284,36 @@ llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
12841284
return id;
12851285
}
12861286

1287+
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
1288+
is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
1289+
: llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
1290+
1291+
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
1292+
has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
1293+
: TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
1294+
1295+
llvm::Intrinsic::ID
1296+
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
1297+
LLVM::ModuleTranslation &mt,
1298+
llvm::SmallVector<llvm::Value *> &args) {
1299+
auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
1300+
unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1301+
.getAddressSpace();
1302+
bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
1303+
bool hasMulticast = curOp.getMulticastMask() ? true : false;
1304+
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1305+
1306+
auto id = is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
1307+
: GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
1308+
1309+
// Fill the Intrinsic Args
1310+
args.push_back(mt.lookupValue(curOp.getAddr()));
1311+
if (hasMulticast)
1312+
args.push_back(mt.lookupValue(curOp.getMulticastMask()));
1313+
1314+
return id;
1315+
}
1316+
12871317
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
12881318
/// have ConstantRangeAttr.
12891319
static void nvvmInferResultRanges(Operation *op, Value result,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
2+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
3+
4+
// CHECK-LABEL: @llvm_nvvm_tcgen05_fence
5+
llvm.func @llvm_nvvm_tcgen05_fence() {
6+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.fence.before.thread.sync()
7+
nvvm.tcgen05.fence #nvvm.tcgen05_fence<before>
8+
9+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.fence.after.thread.sync()
10+
nvvm.tcgen05.fence #nvvm.tcgen05_fence<after>
11+
12+
llvm.return
13+
}
14+
15+
// CHECK-LABEL: @llvm_nvvm_tcgen05_wait
16+
llvm.func @llvm_nvvm_tcgen05_wait() {
17+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.wait.ld()
18+
nvvm.tcgen05.wait #nvvm.tcgen05_wait<load>
19+
20+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.wait.st()
21+
nvvm.tcgen05.wait #nvvm.tcgen05_wait<store>
22+
23+
llvm.return
24+
}
25+
26+
// CHECK-LABEL: @llvm_nvvm_tcgen05_commit_generic
27+
llvm.func @llvm_nvvm_tcgen05_commit_generic(%barrier : !llvm.ptr, %cta_mask : i16) {
28+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.cg1(ptr %{{.*}})
29+
nvvm.tcgen05.commit %barrier : !llvm.ptr
30+
31+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.cg2(ptr %{{.*}})
32+
nvvm.tcgen05.commit %barrier {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr
33+
34+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.cg1(ptr %{{.*}}, i16 %{{.*}})
35+
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask : !llvm.ptr, i16
36+
37+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.cg2(ptr %{{.*}}, i16 %{{.*}})
38+
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr, i16
39+
llvm.return
40+
}
41+
42+
// CHECK-LABEL: @llvm_nvvm_tcgen05_commit_shared
43+
llvm.func @llvm_nvvm_tcgen05_commit_shared(%barrier : !llvm.ptr<3>, %cta_mask : i16) {
44+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.shared.cg1(ptr addrspace(3) %{{.*}})
45+
nvvm.tcgen05.commit %barrier : !llvm.ptr<3>
46+
47+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.shared.cg2(ptr addrspace(3) %{{.*}})
48+
nvvm.tcgen05.commit %barrier {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<3>
49+
50+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.shared.cg1(ptr addrspace(3) %{{.*}}, i16 %{{.*}})
51+
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask : !llvm.ptr<3>, i16
52+
53+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.shared.cg2(ptr addrspace(3) %{{.*}}, i16 %{{.*}})
54+
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<3>, i16
55+
llvm.return
56+
}

0 commit comments

Comments
 (0)