
[MLIR][NVVM] Add Op for TMA Store with reduction #118853


Merged
joker-eph merged 1 commit into llvm:main from durgadossr/mlir_nvvm_tma_reduce on Dec 11, 2024

Conversation

durga4github (Contributor)

PR #116854 adds intrinsics for TMA Store with reduction.
This patch adds an NVVM Dialect Op for the same.

  • Lit tests are added to verify the lowering to LLVM intrinsics, along with verifier tests for the invalid cases.
  • The common verifier method is updated to handle im2col modes without offsets.
    This helps Ops like TMA Store, TMA StoreReduce, etc.
  • The nvvmir.mlir test file is already large, so this patch adds the tests for this Op
    in a new file under a separate "nvvm/" directory:
    mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir

PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor
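
For reference, a minimal usage sketch of the new op (mirroring the lit tests added in this patch; the SSA value names are illustrative):

  // Reduce-add the shared-memory data into the 1-D tensor described by %tma_desc,
  // with an optional L2 cache-eviction hint.
  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch
      {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>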

@llvmbot (Member) commented Dec 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)


Patch is 57.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118853.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+101)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+73-9)
  • (added) mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir (+313)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 296a3c305e5bf4..14880a1a66ba57 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2029,6 +2029,107 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
   }];
 }
 
+// List of modes supported for TMA Store and Reduction Ops
+def TMAStoreModeTile   : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+
+def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
+    [TMAStoreModeTile, TMAStoreModeIm2Col]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// List of Reduction Ops supported with TMA Store
+def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">;
+def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">;
+def TMAReduxKindMax : I32EnumAttrCase<"MAX", 2, "max">;
+def TMAReduxKindInc : I32EnumAttrCase<"INC", 3, "inc">;
+def TMAReduxKindDec : I32EnumAttrCase<"DEC", 4, "dec">;
+def TMAReduxKindAnd : I32EnumAttrCase<"AND", 5, "and">;
+def TMAReduxKindOr  : I32EnumAttrCase<"OR",  6, "or">;
+def TMAReduxKindXor : I32EnumAttrCase<"XOR", 7, "xor">;
+
+def TMAReduxKind : I32EnumAttr<"TMAReduxKind", "NVVM TMA redux kind",
+    [TMAReduxKindAdd, TMAReduxKindMax, TMAReduxKindMin,
+     TMAReduxKindInc, TMAReduxKindDec, TMAReduxKindAnd,
+     TMAReduxKindOr,  TMAReduxKindXor]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMAReduxKindAttr : EnumAttr<NVVM_Dialect, TMAReduxKind, "tma_redux_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CpAsyncBulkTensorReduceOp :
+  NVVM_Op<"cp.async.bulk.tensor.reduce", [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    LLVM_AnyPointer:$tmaDescriptor,
+    LLVM_PointerShared:$srcMem,
+    TMAReduxKindAttr:$redKind,
+    DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode,
+    Variadic<I32>:$coordinates,
+    Optional<I64>:$l2CacheHint);
+
+  let description = [{
+    Initiates an asynchronous reduction operation of tensor data in
+    global memory with tensor data in shared memory.
+
+    The `mode` attribute indicates whether the copy mode is tile or im2col.
+    The `redOp` attribute specifies the reduction operations applied.
+    The supported reduction operations are:
+    {add, min, max, inc, dec, and, or, xor}
+
+    The `l2CacheHint` operand is optional, and it is used to specify cache
+    eviction policy that may be used during the memory access.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor)
+  }];
+
+  let assemblyFormat = [{
+    $tmaDescriptor `,`
+    $srcMem `,`
+    `box` `[`$coordinates `]`
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict  `:` type($tmaDescriptor) `,` type($srcMem)
+  }];
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
+                                              NVVM::TMAReduxKind kind,
+                                              bool isIm2Col);
+  }];
+
+  let hasVerifier = 1;
+
+  string llvmBuilder = [{
+    // Arguments to the intrinsic:
+    // shared_mem_ptr, tmaDesc, tensorDims
+    // cache_hint(if applicable) and flag(boolean)
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($srcMem);
+    translatedOperands.push_back($tmaDescriptor);
+
+    for (auto v : op.getCoordinates())
+      translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
+
+    bool isCacheHint = op.getL2CacheHint() ? true : false;
+    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
+    translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+    auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
+                 op.getCoordinates().size(), $redKind,
+                 (op.getMode() == NVVM::TMAStoreMode::IM2COL));
+    createIntrinsicCall(builder, intId, translatedOperands);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ca04af0b060b4f..d8a9d513aa858b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,21 +75,26 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
-// This verifier is shared across:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
+// This verifier is shared among the following Ops:
+// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
+// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
+                                                     bool isIm2Col,
                                                      size_t numIm2ColOffsets,
                                                      Location loc) {
   if (tensorDims < 1 || tensorDims > 5)
     return emitError(loc, "expects coordinates between 1 to 5 dimension");
 
-  if (numIm2ColOffsets) {
+  // For Im2Col mode, there are two constraints:
+  if (isIm2Col) {
+    // 1. Tensor must always be at least 3-d.
     if (tensorDims < 3)
       return emitError(
           loc,
           "to use im2col mode, the tensor has to be at least 3-dimensional");
-    if (tensorDims != (numIm2ColOffsets + 2))
+    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
+    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
       return emitError(
           loc, "im2col offsets must be 2 less than number of coordinates");
   }
@@ -97,8 +102,10 @@ static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
 }
 
 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
-                                         getIm2colOffsets().size(), getLoc());
+  size_t numIm2ColOffsets = getIm2colOffsets().size();
+  bool isIm2Col = numIm2ColOffsets > 0;
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
+                                         numIm2ColOffsets, getLoc());
 }
 
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
@@ -119,8 +126,16 @@ LogicalResult CpAsyncOp::verify() {
 }
 
 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
-  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
-                                         getIm2colOffsets().size(), getLoc());
+  size_t numIm2ColOffsets = getIm2colOffsets().size();
+  bool isIm2Col = numIm2ColOffsets > 0;
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
+                                         numIm2ColOffsets, getLoc());
+}
+
+LogicalResult CpAsyncBulkTensorReduceOp::verify() {
+  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
+                                         getLoc());
 }
 
 // Given the element type of an operand and whether or not it is an accumulator,
@@ -1094,6 +1109,55 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
   }
 }
 
+#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
+  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_##op##_##mode##_##dim##d
+
+#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
+  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
+            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
+
+#define GET_CP_ASYNC_BULK_TENSOR_RED_ID(op, dims, is_im2col)                   \
+  [&]() -> auto {                                                              \
+    switch (dims) {                                                            \
+    case 1:                                                                    \
+      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
+    case 2:                                                                    \
+      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
+    case 3:                                                                    \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
+    case 4:                                                                    \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
+    case 5:                                                                    \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
+    default:                                                                   \
+      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
+    }                                                                          \
+  }()
+
+llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
+    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
+  using RedTy = NVVM::TMAReduxKind;
+  switch (kind) {
+  case RedTy::ADD:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(add, tensorDims, isIm2Col);
+  case RedTy::MIN:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(min, tensorDims, isIm2Col);
+  case RedTy::MAX:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(max, tensorDims, isIm2Col);
+  case RedTy::INC:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(inc, tensorDims, isIm2Col);
+  case RedTy::DEC:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(dec, tensorDims, isIm2Col);
+  case RedTy::AND:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(and, tensorDims, isIm2Col);
+  case RedTy::OR:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(or, tensorDims, isIm2Col);
+  case RedTy::XOR:
+    return GET_CP_ASYNC_BULK_TENSOR_RED_ID(xor, tensorDims, isIm2Col);
+  }
+  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir
new file mode 100644
index 00000000000000..3809bc0bce8974
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir
@@ -0,0 +1,313 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: define void @tma_store_reduce_1d(
+llvm.func @tma_store_reduce_1d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.1d(ptr addrspace(3) %[[SRC:.*]], ptr %[[DST:.*]], i32 %[[D0:.*]], i64 %[[CH:.*]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 %[[CH]], i1 true)
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<min>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<max>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<inc>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<dec>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<and>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<or>}  : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<xor>} : !llvm.ptr, !llvm.ptr<3>
+
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false)
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<add>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<min>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<max>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<inc>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<dec>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<or>, mode = #nvvm.tma_store_mode<tile>}  : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<xor>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: define void @tma_store_reduce_2d(
+llvm.func @tma_store_reduce_2d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.2d(ptr addrspace(3) %[[SRC:.*]], ptr %[[DST:.*]], i32 %[[D0:.*]], i32 %[[D1:.*]], i64 %[[CH:.*]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 %[[CH]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 %[[CH]], i1 true)
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<min>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<max>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<inc>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<dec>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<and>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<or>}  : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<xor>} : !llvm.ptr, !llvm.ptr<3>
+
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false)
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>
+  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind...
[truncated]

@durga4github (Contributor, Author)

@grypp, could you please help with a review?

@krzysz00 (Contributor) left a comment

I've got a broad note

let assemblyFormat = "`<` $value `>`";
}

def NVVM_CpAsyncBulkTensorReduceOp :
Contributor

So, as a general design note (coming from the AMD side), I'd argue that there should be an NVVM operation that's an exact 1:1 map to the LLVM intrinsic, and an operation in some other dialect (NVGPU, I think, might be it) that takes the enum and has the user-friendly syntax.

Then the conversion to LLVM becomes an MLIR rewrite pattern instead of hiding non-trivial logic in the LLVMBuilder.

@joker-eph (Collaborator) commented Dec 6, 2024

LLVM intrinsics don't support polymorphism, quite unfortunately...
The NVVM dialect is embracing it, though, which makes it much more "MLIR friendly".
It becomes 1:M in the translation to LLVM IR, where "name mangling" is required to select the right LLVM intrinsic.
However, that is still a purely "mechanical" translation, which remains trivial and can be completely round-tripped (there is no complex logic involved here).

Adding another layer in MLIR is doable, but:

  1. it's heavy in terms of engineering time
  2. it's non-trivial in terms of compile-time impact (and this code is on the path of latency-sensitive JITs).

(In the same vein, we discussed at the Dev Summit adding a direct translation from SCF/arith to LLVM IR to avoid paying for an extra dialect conversion.)
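
As an illustration of that 1:M mapping, a sketch following the tests and intrinsic-name macros in this patch (the SSA names are illustrative):

  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2]
      {redKind = #nvvm.tma_redux_kind<add>, mode = #nvvm.tma_store_mode<tile>} : !llvm.ptr, !llvm.ptr<3>
  // translates to a call to @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.3d, while the same op
  // with mode = #nvvm.tma_store_mode<im2col> selects @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.3d.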

Contributor Author

Thank you, Krzysztof, for your thoughts!

As you probably observed, this Op models only one PTX instruction family.
As Mehdi also pointed out, we had to model this with many intrinsics at
the LLVM level, mainly due to the differing arguments of each variant.

Since the MLIR infrastructure offers richer type support such as Variadic<> and
Optional<>, and the existing TMA Ops here already leverage these features,
I followed the same approach to model this Op too.
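
For concreteness, a sketch of what that buys (patterned on the lit tests in this patch; SSA names are illustrative): the variadic coordinates cover every tensor rank, and the optional l2_cache_hint operand can simply be omitted, so a single op covers all variants.

  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<min>} : !llvm.ptr, !llvm.ptr<3>
  nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<min>} : !llvm.ptr, !llvm.ptr<3>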

Contributor

Fair, I'm not going to get in the way of existing code and convention here. And perhaps "LLVM intrinsics" isn't exactly the right term here. The approach I've taken with AMDGPU/ROCDL is that, when there's a "prickly" feature (like, on our side, DPP, where there are a lot of magic constants, or MFMA with its implicit type conversions), you get a more user-friendly wrapper under amdgpu.* and the raw underlying primitives that trivially translate to the export format under rocdl.*.

I was advocating for that approach because I figure it's both a better developer experience and it improves debuggability.

Collaborator

You're absolutely right that this is a more "pure" design; I tend to look at the tradeoff in terms of compile time here (dialect conversion is non-trivial, and that means one extra pass in the pipeline).
The developer experience should be identical, I believe, and the debuggability difference only comes from the fact that the 1:M expansion happens during the translation rather than in MLIR. Since this is basically a single switch statement dispatching it, the impact is hopefully quite minimal: from a debugging standpoint there isn't much difference between observing the effect of this 1:M mapping with mlir-opt as an MLIR->MLIR transformation and with mlir-translate as MLIR->LLVM IR; it's all at the same level of abstraction and a very local transformation (1 op -> 1 instruction).

Contributor

I suppose another way to phrase things is that I expect mlir-translate to be a fairly simple operation - though something like this sort of switch statement for mapping an MLIR op + attributes to one of several underlying operations is probably fine. That is, I'd be surprised by a "translation" that has a bunch of complex logic in it, chipset-dependence, etc.

Member

The design philosophy for NVGPU and NVVM has long been clear and practical:

NVGPU was created to serve as a bridge between high-level dialects like memref and vector and the NVVM dialect.

NVVM was designed to generate PTX code, but this was never a 1:1 mapping. It follows a 1:N mapping, where the N variants are always the same PTX instruction with different traits. We've relied on this extensively.

NVGPU can generate N NVVM ops, where the N ops are distinct PTX instructions, e.g., for tensor cores: the fence + mma + commit + wait chain.

Also, NVGPU is a good place for things that don't fit neatly into the GPU dialects, such as NVIDIA-specific driver calls.

[RFC] Add NV-GPU dialect

Member

@krzysz00, I see where you're coming from regarding moving the enum to the upper dialect. It might be a nice idea for a small ISA, but our situation is a bit different. We have Ops with multiple enums, which would mean creating numerous NVVM ops and moving the enums to the NVGPU dialect. Even if we wanted to do that, this approach doesn't add much benefit; unfortunately, the ROI is not clear to me.

PR llvm#116854 adds intrinsics for TMA Store with reduction.
This patch adds an NVVM Dialect Op for the same.

Lit tests to verify the lowering to LLVM intrinsics as well as
verifier tests (for invalid cases) are added.

* The Common verifier method for TMA Ops is updated
  to handle im2col modes without offsets. This helps
  Ops like TMA Store, TMA StoreReduce etc.
* The nvvmir.mlir test file is already large. So, this
  patch adds the tests for this Op into a separate file
  under a separate "nvvm/" directory.
  [mlir/test/Target/LLVMIR/"nvvm"/tma_store_reduce.mlir]

PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor

Signed-off-by: Durgadoss R <[email protected]>
@durga4github force-pushed the durgadossr/mlir_nvvm_tma_reduce branch from d756d86 to e9e26c0 on December 6, 2024 at 10:47
@durga4github (Contributor, Author)

Fixed build issues on Windows in the latest revision.

@joker-eph merged commit b4b819c into llvm:main on Dec 11, 2024
8 checks passed