
[MLIR][NVVM] Add Op for TMA Store with reduction #118853


Merged (1 commit) on Dec 11, 2024
101 changes: 101 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2029,6 +2029,107 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
}];
}

// List of modes supported for TMA Store and Reduction Ops
def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;

def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
[TMAStoreModeTile, TMAStoreModeIm2Col]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
let assemblyFormat = "`<` $value `>`";
}

// List of Reduction Ops supported with TMA Store
def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">;
def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">;
def TMAReduxKindMax : I32EnumAttrCase<"MAX", 2, "max">;
def TMAReduxKindInc : I32EnumAttrCase<"INC", 3, "inc">;
def TMAReduxKindDec : I32EnumAttrCase<"DEC", 4, "dec">;
def TMAReduxKindAnd : I32EnumAttrCase<"AND", 5, "and">;
def TMAReduxKindOr : I32EnumAttrCase<"OR", 6, "or">;
def TMAReduxKindXor : I32EnumAttrCase<"XOR", 7, "xor">;

def TMAReduxKind : I32EnumAttr<"TMAReduxKind", "NVVM TMA redux kind",
[TMAReduxKindAdd, TMAReduxKindMax, TMAReduxKindMin,
TMAReduxKindInc, TMAReduxKindDec, TMAReduxKindAnd,
TMAReduxKindOr, TMAReduxKindXor]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def TMAReduxKindAttr : EnumAttr<NVVM_Dialect, TMAReduxKind, "tma_redux_kind"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_CpAsyncBulkTensorReduceOp :
Contributor:

So, as a general design note - coming from the AMD side - I'd argue that there should be an NVVM operation that is an exact 1:1 map to the LLVM intrinsic, and an operation in some other dialect (NVGPU, I think, might be it) that takes the enum and has the user-friendly syntax.

Then the conversion to LLVM becomes an MLIR rewrite pattern instead of hiding non-trivial logic in `llvmBuilder`.

Collaborator @joker-eph (Dec 6, 2024):

LLVM intrinsics don't support polymorphism, quite unfortunately...
The NVVM dialect is embracing it though, which makes it much more "MLIR friendly".
It becomes 1:M in the translation to LLVM IR, where "name mangling" is required to select the right LLVM intrinsic.
However, that is still a purely "mechanical" translation, which remains trivial and can be completely round-tripped (there is no complex logic involved here).

Adding another layer in MLIR is doable, but:

  1. it's heavy in terms of engineering time;
  2. it's non-trivial in terms of compile-time impact (and this code is on the path of latency-sensitive JITs).

(In the same vein, we discussed at the Dev Summit adding a direct translation from SCF/arith to LLVM IR to avoid paying for an extra dialect conversion.)
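
To make the 1:M point concrete, here is a rough, illustrative sketch (not taken from the patch or its tests) of how one NVVM op could fan out to different intrinsics at translation time. The printed attribute syntax and SSA value names are assumed from the definitions added in this PR:

```mlir
// Hypothetical IR, assuming the assembly format defined in this patch.
// A 2-D tile-mode add-reduction would select the intrinsic
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:
nvvm.cp.async.bulk.tensor.reduce %desc, %src, box[%c0, %c1]
  {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>

// ...while a 3-D im2col-mode min-reduction would select
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d:
nvvm.cp.async.bulk.tensor.reduce %desc, %src, box[%c0, %c1, %c2]
  {redKind = #nvvm.tma_redux_kind<min>, mode = #nvvm.tma_store_mode<im2col>}
  : !llvm.ptr, !llvm.ptr<3>
```

The intrinsic names here follow the `nvvm_cp_async_bulk_tensor_<op>_<mode>_<dim>d` pattern used by the selection macros later in this patch.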

Contributor (author):

Thank you, Krzysztof, for your thoughts!

As you probably observed, this Op models one PTX instruction family only. As Mehdi also pointed out, we had to model this with many intrinsics at the LLVM level, mainly due to the differing arguments for each variant.

Since the MLIR infra offers enhanced type support such as Variadic<> and Optional<>, and the existing TMA Ops here already leverage these features, I followed the same approach to model this Op too.

Contributor:

Fair - I'm not going to get in the way of existing code and convention here. And perhaps "LLVM intrinsics" isn't exactly the right word. The approach I've taken with AMDGPU/ROCDL is that, when there's a "prickly" feature (like, on our side, DPP with its many magic constants, or MFMA with its implicit type conversions), you get a more user-friendly wrapper under amdgpu.* and the raw underlying primitives that trivially translate to the export format under rocdl.*.

I was advocating for that approach because I figure it's both a better developer experience and an improvement to debuggability.

Collaborator:

You're absolutely right that this is a more "pure" design; I tend to look at the tradeoff in terms of compile time here (dialect conversion is non-trivial, and that means one extra pass in the pipeline).
The developer experience should be identical, I believe, and the debuggability difference only comes from performing the 1:M expansion during translation rather than in MLIR. Since this is basically a single switch statement dispatching it, the impact is hopefully quite minimal: from a debugging standpoint, there isn't much difference between observing the effect of this 1:M mapping with mlir-opt as an MLIR->MLIR transformation and with mlir-translate as MLIR->LLVM IR; it's all at the same level of abstraction and a very local 1 op -> 1 instruction transformation.

Contributor:

I suppose another way to phrase things is that I expect mlir-translate to be a fairly simple operation - though something like this sort of switch statement for mapping an MLIR op + attributes to one of several underlying operations is probably fine. That is, I'd be surprised by a "translation" that has a bunch of complex logic in it, chipset-dependence, etc.

Member:

The design philosophy for NVGPU and NVVM has long been clear and practical:

NVGPU was created to serve as a bridge between high-level dialects like memref and vector and the NVVM dialect.

NVVM was designed to generate PTX code, but this was never a 1:1 mapping. It follows a 1:N mapping, where the N variants are always the same PTX instruction with different traits. We've relied on this extensively.

NVGPU can generate N NVVM ops where the N ops are distinct PTX instructions, e.g., for tensor cores: the fence + mma + commit + wait chain.

Also, NVGPU is a good place for things that don't fit neatly into the GPU dialects, such as NVIDIA-specific driver calls.

[RFC] Add NV-GPU dialect

Member:

@krzysz00, I see where you're coming from regarding moving the enum to the upper dialect. It might be a nice idea for a small ISA, but our situation is a bit different. We have Ops with multiple enums, which would mean creating numerous NVVM ops and moving the enums to the NVGPU dialect. Even if we wanted to do that, this approach doesn't add much benefit; unfortunately, the ROI is not clear to me.

NVVM_Op<"cp.async.bulk.tensor.reduce", [AttrSizedOperandSegments]> {
let arguments = (ins
LLVM_AnyPointer:$tmaDescriptor,
LLVM_PointerShared:$srcMem,
TMAReduxKindAttr:$redKind,
DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode,
Variadic<I32>:$coordinates,
Optional<I64>:$l2CacheHint);

let description = [{
Initiates an asynchronous reduction operation of tensor data in
global memory with tensor data in shared memory.

The `mode` attribute indicates whether the copy mode is tile or im2col.
The `redKind` attribute specifies the reduction operation to be applied.
The supported reduction operations are:
{add, min, max, inc, dec, and, or, xor}

The `l2CacheHint` operand is optional, and it is used to specify the cache
eviction policy that may be used during the memory access.

[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor)
}];

let assemblyFormat = [{
$tmaDescriptor `,`
$srcMem `,`
`box` `[`$coordinates `]`
(`l2_cache_hint` `=` $l2CacheHint^ )?
attr-dict `:` type($tmaDescriptor) `,` type($srcMem)
}];

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
NVVM::TMAReduxKind kind,
bool isIm2Col);
}];

let hasVerifier = 1;

string llvmBuilder = [{
// Arguments to the intrinsic:
// shared_mem_ptr, tmaDesc, tensorDims
// cache_hint(if applicable) and flag(boolean)
llvm::SmallVector<llvm::Value *> translatedOperands;
translatedOperands.push_back($srcMem);
translatedOperands.push_back($tmaDescriptor);

for (auto v : op.getCoordinates())
translatedOperands.push_back(moduleTranslation.lookupValue(v));

llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));

bool isCacheHint = op.getL2CacheHint() ? true : false;
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
translatedOperands.push_back(builder.getInt1(isCacheHint));

auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
op.getCoordinates().size(), $redKind,
(op.getMode() == NVVM::TMAStoreMode::IM2COL));
createIntrinsicCall(builder, intId, translatedOperands);
}];
}

//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
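As a quick illustration of the new op in use (a hypothetical example assuming the `assemblyFormat` above; it is not part of the patch), a 3-D max-reduction with an L2 cache hint might be written as:

```mlir
// %desc: TMA descriptor pointer, %src: shared-memory pointer (addrspace 3),
// %x/%y/%z: i32 tensor coordinates, %hint: i64 cache-eviction hint.
nvvm.cp.async.bulk.tensor.reduce %desc, %src, box[%x, %y, %z]
  l2_cache_hint = %hint
  {redKind = #nvvm.tma_redux_kind<max>} : !llvm.ptr, !llvm.ptr<3>
```

With the default `tile` mode and three coordinates, the translation described above would pick the tile, 3-d flavor of the reduce-max intrinsic.
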
82 changes: 73 additions & 9 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,30 +75,37 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// This verifier is shared across:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
size_t numIm2ColOffsets,
Location loc) {
if (tensorDims < 1 || tensorDims > 5)
return emitError(loc, "expects coordinates between 1 to 5 dimension");

if (numIm2ColOffsets) {
// For Im2Col mode, there are two constraints:
if (isIm2Col) {
// 1. Tensor must always be at least 3-d.
if (tensorDims < 3)
return emitError(
loc,
"to use im2col mode, the tensor has to be at least 3-dimensional");
if (tensorDims != (numIm2ColOffsets + 2))
// 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
return emitError(
loc, "im2col offsets must be 2 less than number of coordinates");
}
return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
getIm2colOffsets().size(), getLoc());
size_t numIm2ColOffsets = getIm2colOffsets().size();
bool isIm2Col = numIm2ColOffsets > 0;
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
@@ -119,8 +126,16 @@ LogicalResult CpAsyncOp::verify() {
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
getIm2colOffsets().size(), getLoc());
size_t numIm2ColOffsets = getIm2colOffsets().size();
bool isIm2Col = numIm2ColOffsets > 0;
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
getLoc());
}
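
For reference, here is a sketch of inputs the shared verifier would reject (hypothetical IR in the style of an `mlir-opt -verify-diagnostics` test; the printed syntax is assumed from the definitions in this PR):

```mlir
// More than 5 coordinates is rejected by the shared verifier.
// expected-error @below {{expects coordinates between 1 to 5 dimension}}
nvvm.cp.async.bulk.tensor.reduce %desc, %src,
  box[%c0, %c1, %c2, %c3, %c4, %c5]
  {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>

// im2col mode requires a tensor of at least 3 dimensions.
// expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
nvvm.cp.async.bulk.tensor.reduce %desc, %src, box[%c0, %c1]
  {redKind = #nvvm.tma_redux_kind<add>, mode = #nvvm.tma_store_mode<im2col>}
  : !llvm.ptr, !llvm.ptr<3>
```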

// Given the element type of an operand and whether or not it is an accumulator,
@@ -1094,6 +1109,55 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
}
}

#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col) \
is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
[&]() -> auto { \
switch (dims) { \
case 1: \
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
case 2: \
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
case 3: \
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
case 4: \
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
case 5: \
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
default: \
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
} \
}()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
using RedTy = NVVM::TMAReduxKind;
switch (kind) {
case RedTy::ADD:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
case RedTy::MIN:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
case RedTy::MAX:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
case RedTy::INC:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
case RedTy::DEC:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
case RedTy::AND:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
case RedTy::OR:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
case RedTy::XOR:
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
}
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//