Commit d756d86
[MLIR][NVVM] Add Op for TMA Store with reduction
PR #116854 adds intrinsics for TMA Store with reduction. This patch adds
an NVVM Dialect Op for the same. Lit tests to verify the lowering to LLVM
intrinsics as well as verifier tests (for invalid cases) are added.

* The common verifier method for TMA Ops is updated to handle im2col
  modes without offsets. This helps Ops like TMA Store, TMA StoreReduce etc.
* The nvvmir.mlir test file is already large, so this patch adds the tests
  for this Op into a separate file under a new "nvvm/" directory:
  mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir

PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor

Signed-off-by: Durgadoss R <[email protected]>
1 parent 8ab2730 commit d756d86

File tree: 4 files changed, +503 -9 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 101 additions & 0 deletions
@@ -2029,6 +2029,107 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
   }];
 }
 
+// List of modes supported for TMA Store and Reduction Ops
+def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+
+def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
+                               [TMAStoreModeTile, TMAStoreModeIm2Col]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// List of Reduction Ops supported with TMA Store
+def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">;
+def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">;
+def TMAReduxKindMax : I32EnumAttrCase<"MAX", 2, "max">;
+def TMAReduxKindInc : I32EnumAttrCase<"INC", 3, "inc">;
+def TMAReduxKindDec : I32EnumAttrCase<"DEC", 4, "dec">;
+def TMAReduxKindAnd : I32EnumAttrCase<"AND", 5, "and">;
+def TMAReduxKindOr : I32EnumAttrCase<"OR", 6, "or">;
+def TMAReduxKindXor : I32EnumAttrCase<"XOR", 7, "xor">;
+
+def TMAReduxKind : I32EnumAttr<"TMAReduxKind", "NVVM TMA redux kind",
+                               [TMAReduxKindAdd, TMAReduxKindMax, TMAReduxKindMin,
+                                TMAReduxKindInc, TMAReduxKindDec, TMAReduxKindAnd,
+                                TMAReduxKindOr, TMAReduxKindXor]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMAReduxKindAttr : EnumAttr<NVVM_Dialect, TMAReduxKind, "tma_redux_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CpAsyncBulkTensorReduceOp :
+  NVVM_Op<"cp.async.bulk.tensor.reduce", [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    LLVM_AnyPointer:$tmaDescriptor,
+    LLVM_PointerShared:$srcMem,
+    TMAReduxKindAttr:$redKind,
+    DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode,
+    Variadic<I32>:$coordinates,
+    Optional<I64>:$l2CacheHint);
+
+  let description = [{
+    Initiates an asynchronous reduction operation of tensor data in
+    global memory with tensor data in shared memory.
+
+    The `mode` attribute indicates whether the copy mode is tile or im2col.
+    The `redKind` attribute specifies the reduction operation applied.
+    The supported reduction operations are:
+    {add, min, max, inc, dec, and, or, xor}
+
+    The `l2CacheHint` operand is optional, and it is used to specify the
+    cache eviction policy that may be used during the memory access.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor)
+  }];
+
+  let assemblyFormat = [{
+    $tmaDescriptor `,`
+    $srcMem `,`
+    `box` `[`$coordinates `]`
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict `:` type($tmaDescriptor) `,` type($srcMem)
+  }];
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
+                                              NVVM::TMAReduxKind kind,
+                                              bool isIm2Col);
+  }];
+
+  let hasVerifier = 1;
+
+  string llvmBuilder = [{
+    // Arguments to the intrinsic:
+    // shared_mem_ptr, tmaDesc, tensorDims,
+    // cache_hint (if applicable) and flag (boolean)
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($srcMem);
+    translatedOperands.push_back($tmaDescriptor);
+
+    for (auto v : op.getCoordinates())
+      translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
+
+    bool isCacheHint = op.getL2CacheHint() ? true : false;
+    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
+    translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+    auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
+        op.getCoordinates().size(), $redKind,
+        (op.getMode() == NVVM::TMAStoreMode::IM2COL));
+    createIntrinsicCall(builder, intId, translatedOperands);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
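For illustration, a minimal usage sketch of the new op in NVVM dialect assembly, following the assemblyFormat above. The function name and SSA values are hypothetical, and the pointer types assume the usual NVVM convention of addrspace(3) for shared memory:

// Hypothetical example: 2-D add-reduction in tile mode (the default),
// with the optional L2 cache hint supplied.
llvm.func @tma_store_reduce_2d(%desc : !llvm.ptr, %src : !llvm.ptr<3>,
                               %d0 : i32, %d1 : i32, %hint : i64) {
  nvvm.cp.async.bulk.tensor.reduce %desc, %src, box [%d0, %d1]
      l2_cache_hint = %hint {redKind = #nvvm.tma_redux_kind<add>}
      : !llvm.ptr, !llvm.ptr<3>
  llvm.return
}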

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,30 +75,37 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
-// This verifier is shared across:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
+// This verifier is shared among the following Ops:
+// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
+// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
+                                                     bool isIm2Col,
                                                      size_t numIm2ColOffsets,
                                                      Location loc) {
   if (tensorDims < 1 || tensorDims > 5)
     return emitError(loc, "expects coordinates between 1 to 5 dimension");
 
-  if (numIm2ColOffsets) {
+  // For Im2Col mode, there are two constraints:
+  if (isIm2Col) {
+    // 1. Tensor must always be at least 3-d.
     if (tensorDims < 3)
       return emitError(
           loc,
           "to use im2col mode, the tensor has to be at least 3-dimensional");
-    if (tensorDims != (numIm2ColOffsets + 2))
+    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
+    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
   }
   return success();
 }
 
 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
-                                         getIm2colOffsets().size(), getLoc());
+  size_t numIm2ColOffsets = getIm2colOffsets().size();
+  bool isIm2Col = numIm2ColOffsets > 0;
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
+                                         numIm2ColOffsets, getLoc());
 }
 
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
@@ -119,8 +126,16 @@ LogicalResult CpAsyncOp::verify() {
 }
 
 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
-  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
-                                         getIm2colOffsets().size(), getLoc());
+  size_t numIm2ColOffsets = getIm2colOffsets().size();
+  bool isIm2Col = numIm2ColOffsets > 0;
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
+                                         numIm2ColOffsets, getLoc());
+}
+
+LogicalResult CpAsyncBulkTensorReduceOp::verify() {
+  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
+                                         getLoc());
 }
 
 // Given the element type of an operand and whether or not it is an accumulator,
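As a sketch of what the updated verifier rejects (hypothetical function, not taken from the patch's verifier tests): an im2col-mode reduce on a 2-D tensor trips the "at least 3-dimensional" check, even though this op carries no im2col offsets, which is exactly the case the new isIm2Col parameter handles.

// Invalid: im2col mode with only two coordinates; the verifier emits
// "to use im2col mode, the tensor has to be at least 3-dimensional".
llvm.func @tma_reduce_im2col_invalid(%desc : !llvm.ptr, %src : !llvm.ptr<3>,
                                     %d0 : i32, %d1 : i32) {
  nvvm.cp.async.bulk.tensor.reduce %desc, %src, box [%d0, %d1]
      {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}
      : !llvm.ptr, !llvm.ptr<3>
  llvm.return
}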
@@ -1094,6 +1109,55 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
   }
 }
 
+#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                       \
+  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_##op##_##mode##_##dim##d
+
+#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                       \
+  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)               \
+            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
+
+#define GET_CP_ASYNC_BULK_TENSOR_RED_ID(op, dims, is_im2col)                  \
+  [&]() -> auto {                                                             \
+    switch (dims) {                                                           \
+    case 1:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                   \
+    case 2:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                   \
+    case 3:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                   \
+    case 4:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                   \
+    case 5:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                   \
+    default:                                                                  \
+      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");    \
+    }                                                                         \
+  }()
+
+llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
+    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
+  using RedTy = NVVM::TMAReduxKind;
+  switch (kind) {
+  case RedTy::ADD:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(add, tensorDims, isIm2Col);
+  case RedTy::MIN:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(min, tensorDims, isIm2Col);
+  case RedTy::MAX:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(max, tensorDims, isIm2Col);
+  case RedTy::INC:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(inc, tensorDims, isIm2Col);
+  case RedTy::DEC:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(dec, tensorDims, isIm2Col);
+  case RedTy::AND:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(and, tensorDims, isIm2Col);
+  case RedTy::OR:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(or, tensorDims, isIm2Col);
+  case RedTy::XOR:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(xor, tensorDims, isIm2Col);
+  }
+  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
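For reference, GET_CP_ASYNC_BULK_TENSOR_RED_ID(add, 2, false) resolves to llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d. Assuming the usual mapping of NVVM intrinsic enum names to IR names (underscores become dots), a 2-D tile-mode add-reduction should lower roughly as below; this is a sketch in the style of the added lit tests, not copied from them. With no l2CacheHint operand, the builder passes i64 undef plus an i1 false flag.

// CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.2d(ptr addrspace(3) %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
llvm.func @tma_store_reduce_add_2d(%desc : !llvm.ptr, %src : !llvm.ptr<3>,
                                   %d0 : i32, %d1 : i32) {
  nvvm.cp.async.bulk.tensor.reduce %desc, %src, box [%d0, %d1]
      {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>
  llvm.return
}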
