Skip to content

Commit 486f9e0

Browse files
committed
[MLIR][NVVM] Add Ops for tcgen05 cp and shift
PR 127669 adds intrinsics for tcgen05.cp/shift. This PR adds NVVM Dialect Ops for the same. lit tests are added to verify the lowering to the intrinsics. Signed-off-by: Durgadoss R <[email protected]>
1 parent 3ce2e4d commit 486f9e0

File tree

5 files changed

+363
-0
lines changed

5 files changed

+363
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2810,6 +2810,113 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
28102810
}];
28112811
}
28122812

2813+
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
2814+
let summary = "Tcgen05 shift operation";
2815+
let description = [{
2816+
The `tcgen05.shift` is an asynchronous instruction which initiates
2817+
the shifting of 32-byte elements downwards across all the rows,
2818+
except the last, by one row. The operand `taddr` specifies the base
2819+
address of the matrix in Tensor Memory whose rows must be down shifted.
2820+
[For more information refer to the PTX ISA]
2821+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift)
2822+
}];
2823+
2824+
let arguments = (ins LLVM_PointerTensor:$taddr,
2825+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
2826+
2827+
let assemblyFormat = "$taddr attr-dict `:` type(operands)";
2828+
2829+
string llvmBuilder = [{
2830+
auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
2831+
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 :
2832+
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2;
2833+
createIntrinsicCall(builder, id, {$taddr});
2834+
}];
2835+
}
2836+
2837+
def Shape128x256b : I32EnumAttrCase<"SHAPE_128x256b", 0, "shape_128x256b">;
2838+
def Shape4x256b : I32EnumAttrCase<"SHAPE_4x256b", 1, "shape_4x256b">;
2839+
def Shape128x128b : I32EnumAttrCase<"SHAPE_128x128b", 2, "shape_128x128b">;
2840+
def Shape64x128b : I32EnumAttrCase<"SHAPE_64x128b", 3, "shape_64x128b">;
2841+
def Shape32x128b : I32EnumAttrCase<"SHAPE_32x128b", 4, "shape_32x128b">;
2842+
2843+
def Tcgen05CpShape : I32EnumAttr<"Tcgen05CpShape", "tcgen05 cp shapes",
2844+
[Shape128x256b, Shape4x256b, Shape128x128b, Shape64x128b, Shape32x128b]> {
2845+
let cppNamespace = "::mlir::NVVM";
2846+
let genSpecializedAttr = 0;
2847+
}
2848+
def Tcgen05CpShapeAttr : EnumAttr<NVVM_Dialect, Tcgen05CpShape, "tcgen05_cp_shape"> {
2849+
let assemblyFormat = "`<` $value `>`";
2850+
}
2851+
2852+
def Tcgen05CpMulticastNone: I32EnumAttrCase<"NONE", 0, "none">;
2853+
def Tcgen05CpMulticastWarpx2_02_13: I32EnumAttrCase<"WARPX2_02_13", 1, "warpx2_02_13">;
2854+
def Tcgen05CpMulticastWarpx2_01_23: I32EnumAttrCase<"WARPX2_01_23", 2, "warpx2_01_23">;
2855+
def Tcgen05CpMulticastWarpx4: I32EnumAttrCase<"WARPX4", 3, "warpx4">;
2856+
2857+
def Tcgen05CpMulticast : I32EnumAttr<"Tcgen05CpMulticast", "tcgen05 cp multicast",
2858+
[Tcgen05CpMulticastNone, Tcgen05CpMulticastWarpx2_02_13,
2859+
Tcgen05CpMulticastWarpx2_01_23, Tcgen05CpMulticastWarpx4]> {
2860+
let cppNamespace = "::mlir::NVVM";
2861+
let genSpecializedAttr = 0;
2862+
}
2863+
def Tcgen05CpMulticastAttr : EnumAttr<NVVM_Dialect, Tcgen05CpMulticast, "tcgen05_cp_multicast"> {
2864+
let assemblyFormat = "`<` $value `>`";
2865+
}
2866+
2867+
def FormatB6x16_P32: I32EnumAttrCase<"B6x16_P32", 0, "b6x16_p32">;
2868+
def FormatB4x16_P64: I32EnumAttrCase<"B4x16_P64", 1, "b4x16_p64">;
2869+
2870+
def Tcgen05CpSrcFormat : I32EnumAttr<"Tcgen05CpSrcFormat", "tcgen05 cp source format",
2871+
[FormatB6x16_P32, FormatB4x16_P64]> {
2872+
let cppNamespace = "::mlir::NVVM";
2873+
let genSpecializedAttr = 0;
2874+
}
2875+
def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05_cp_src_fmt"> {
2876+
let assemblyFormat = "`<` $value `>`";
2877+
}
2878+
2879+
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
2880+
let summary = "Tcgen05 copy operation";
2881+
let description = [{
2882+
Instruction tcgen05.cp initiates an asynchronous copy operation from
2883+
shared memory to the location specified by the address operand `taddr`
2884+
in the Tensor Memory. The 64-bit register operand `smem_desc` specifies
2885+
the matrix descriptor representing the source matrix in the shared memory
2886+
that needs to be copied.
2887+
2888+
usage:
2889+
nvvm.tcgen05.cp %taddr, %smem_desc {
2890+
group = #nvvm.tcgen05_group<cta_2>,
2891+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
2892+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
2893+
srcFormat = #nvvm.tcgen05_cp_format<b6x16_p32>
2894+
}
2895+
[For more information refer to the PTX ISA]
2896+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp)
2897+
}];
2898+
2899+
let arguments = (ins
2900+
Tcgen05CpShapeAttr:$shape,
2901+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
2902+
DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
2903+
OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
2904+
LLVM_PointerTensor:$taddr,
2905+
I64:$smem_desc);
2906+
2907+
let assemblyFormat = "$taddr`,` $smem_desc attr-dict";
2908+
let hasVerifier = 1;
2909+
2910+
let extraClassDeclaration = [{
2911+
static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
2912+
}];
2913+
2914+
string llvmBuilder = [{
2915+
auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op);
2916+
createIntrinsicCall(builder, id, {$taddr, $smem_desc});
2917+
}];
2918+
}
2919+
28132920
//===----------------------------------------------------------------------===//
28142921
// NVVM target attribute.
28152922
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
7575

7676
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
7777

78+
//===----------------------------------------------------------------------===//
79+
// Verifier methods for NVVMDialect Ops
80+
//===----------------------------------------------------------------------===//
81+
7882
// This verifier is shared among the following Ops:
7983
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
8084
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
@@ -1107,6 +1111,38 @@ LogicalResult NVVM::BarrierOp::verify() {
11071111
return success();
11081112
}
11091113

1114+
LogicalResult NVVM::Tcgen05CpOp::verify() {
1115+
auto mc = getMulticast();
1116+
1117+
using SH = Tcgen05CpShape;
1118+
using MC = Tcgen05CpMulticast;
1119+
switch (getShape()) {
1120+
case SH::SHAPE_128x256b:
1121+
case SH::SHAPE_128x128b:
1122+
case SH::SHAPE_4x256b:
1123+
if (mc != MC::NONE)
1124+
return emitError("Invalid multicast type for tcgen05.cp Op");
1125+
break;
1126+
case SH::SHAPE_64x128b:
1127+
if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1128+
return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1129+
"warpx2_02_13 for tcgen05.cp Op");
1130+
break;
1131+
case SH::SHAPE_32x128b:
1132+
if (mc != MC::WARPX4)
1133+
return emitError(
1134+
"Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1135+
break;
1136+
default:
1137+
return emitError("Invalid shape for tcgen05.cp Op");
1138+
}
1139+
return success();
1140+
}
1141+
1142+
//===----------------------------------------------------------------------===//
1143+
// NVVMDialect: getIntrinsicID/getIntrinsicIDAndArgs methods
1144+
//===----------------------------------------------------------------------===//
1145+
11101146
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
11111147
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
11121148

@@ -1314,6 +1350,47 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
13141350
return id;
13151351
}
13161352

1353+
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
1354+
llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
1355+
1356+
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
1357+
is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
1358+
: TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
1359+
1360+
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
1361+
[&]() -> auto { \
1362+
if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
1363+
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
1364+
if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
1365+
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
1366+
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
1367+
} \
1368+
()
1369+
1370+
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
1371+
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
1372+
bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1373+
auto srcFmt = curOp.getSrcFormat();
1374+
auto mc = curOp.getMulticast();
1375+
1376+
switch (curOp.getShape()) {
1377+
case Tcgen05CpShape::SHAPE_128x256b:
1378+
return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
1379+
case Tcgen05CpShape::SHAPE_128x128b:
1380+
return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
1381+
case Tcgen05CpShape::SHAPE_4x256b:
1382+
return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
1383+
case Tcgen05CpShape::SHAPE_32x128b:
1384+
return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
1385+
case Tcgen05CpShape::SHAPE_64x128b:
1386+
return (mc == Tcgen05CpMulticast::WARPX2_01_23)
1387+
? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
1388+
: GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
1389+
default:
1390+
llvm_unreachable("Invalid shape in tcgen05 cp Op");
1391+
}
1392+
}
1393+
13171394
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
13181395
/// have ConstantRangeAttr.
13191396
static void nvvmInferResultRanges(Operation *op, Value result,
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
2+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
3+
4+
// CHECK-LABEL: @nvvm_tcgen05_cp_128x256b
5+
llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
6+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg1(ptr addrspace(6) %0, i64 %1)
7+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
8+
9+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %0, i64 %1)
10+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.tcgen05_group<cta_2>}
11+
12+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
13+
nvvm.tcgen05.cp %taddr, %smem_desc {
14+
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
15+
group = #nvvm.tcgen05_group<cta_2>,
16+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
17+
}
18+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
19+
nvvm.tcgen05.cp %taddr, %smem_desc {
20+
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
21+
group = #nvvm.tcgen05_group<cta_2>,
22+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
23+
}
24+
llvm.return
25+
}
26+
27+
// CHECK-LABEL: @nvvm_tcgen05_cp_4x256b
28+
llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
29+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg1(ptr addrspace(6) %0, i64 %1)
30+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}
31+
32+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %0, i64 %1)
33+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.tcgen05_group<cta_2>}
34+
35+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
36+
nvvm.tcgen05.cp %taddr, %smem_desc {
37+
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
38+
group = #nvvm.tcgen05_group<cta_2>,
39+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
40+
}
41+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
42+
nvvm.tcgen05.cp %taddr, %smem_desc {
43+
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
44+
group = #nvvm.tcgen05_group<cta_2>,
45+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
46+
}
47+
llvm.return
48+
}
49+
50+
// CHECK-LABEL: @nvvm_tcgen05_cp_128x128b
51+
llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
52+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg1(ptr addrspace(6) %0, i64 %1)
53+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}
54+
55+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %0, i64 %1)
56+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.tcgen05_group<cta_2>}
57+
58+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
59+
nvvm.tcgen05.cp %taddr, %smem_desc {
60+
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
61+
group = #nvvm.tcgen05_group<cta_2>,
62+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
63+
}
64+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
65+
nvvm.tcgen05.cp %taddr, %smem_desc {
66+
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
67+
group = #nvvm.tcgen05_group<cta_2>,
68+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
69+
}
70+
llvm.return
71+
}
72+
73+
// CHECK-LABEL: @nvvm_tcgen05_cp_64x128b
74+
llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
75+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1(ptr addrspace(6) %0, i64 %1)
76+
nvvm.tcgen05.cp %taddr, %smem_desc {
77+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
78+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
79+
}
80+
81+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %0, i64 %1)
82+
nvvm.tcgen05.cp %taddr, %smem_desc {
83+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
84+
group = #nvvm.tcgen05_group<cta_2>,
85+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
86+
}
87+
88+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %0, i64 %1)
89+
nvvm.tcgen05.cp %taddr, %smem_desc {
90+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
91+
group = #nvvm.tcgen05_group<cta_1>,
92+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
93+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
94+
}
95+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
96+
nvvm.tcgen05.cp %taddr, %smem_desc {
97+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
98+
group = #nvvm.tcgen05_group<cta_2>,
99+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
100+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
101+
}
102+
103+
llvm.return
104+
}
105+
106+
// CHECK-LABEL: @nvvm_tcgen05_cp_32x128b
107+
llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
108+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg1(ptr addrspace(6) %0, i64 %1)
109+
nvvm.tcgen05.cp %taddr, %smem_desc {
110+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
111+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
112+
}
113+
114+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %0, i64 %1)
115+
nvvm.tcgen05.cp %taddr, %smem_desc {
116+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
117+
group = #nvvm.tcgen05_group<cta_2>,
118+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
119+
}
120+
121+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
122+
nvvm.tcgen05.cp %taddr, %smem_desc {
123+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
124+
group = #nvvm.tcgen05_group<cta_2>,
125+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
126+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
127+
}
128+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %0, i64 %1)
129+
nvvm.tcgen05.cp %taddr, %smem_desc {
130+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
131+
group = #nvvm.tcgen05_group<cta_1>,
132+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
133+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
134+
}
135+
136+
llvm.return
137+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
2+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
3+
4+
// CHECK-LABEL: @llvm_nvvm_tcgen05_shift
5+
llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) {
6+
// CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %{{.*}})
7+
nvvm.tcgen05.shift %taddr : !llvm.ptr<6>
8+
9+
// CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
10+
nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
11+
llvm.return
12+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,33 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
122122
%res = nvvm.cvt.float.to.tf32 %src
123123
llvm.return %res : i32
124124
}
125+
126+
// -----
127+
128+
llvm.func @nvvm_tcgen05_cp_128x256b_mc(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
129+
// expected-error @below {{Invalid multicast type for tcgen05.cp Op}}
130+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>}
131+
llvm.return
132+
}
133+
134+
// -----
135+
136+
llvm.func @nvvm_tcgen05_cp_32x128b_wx2(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
137+
// expected-error @below {{Shape 32x128b requires multicast warpx4 for tcgen05.cp Op}}
138+
nvvm.tcgen05.cp %taddr, %smem_desc {
139+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
140+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>
141+
}
142+
llvm.return
143+
}
144+
145+
// -----
146+
147+
llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
148+
// expected-error @below {{Shape 64x128b requires multicast warpx2_01_23 or warpx2_02_13 for tcgen05.cp Op}}
149+
nvvm.tcgen05.cp %taddr, %smem_desc {
150+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
151+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
152+
}
153+
llvm.return
154+
}

0 commit comments

Comments
 (0)