
[MLIR][NVVM] Add TMA Bulk Copy Ops #123186


Merged
merged 1 commit into llvm:main from durgadossr/mlir_tma_copy
Jan 21, 2025

Conversation

durga4github
Contributor

@durga4github durga4github commented Jan 16, 2025

PR #122344 adds intrinsics for Bulk Async Copy
(non-tensor variants) using TMA. This patch
adds the corresponding NVVM Dialect Ops.

Lit tests are added to verify the lowering to all
variants of the intrinsics.
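As a rough mental model (plain Python, not an LLVM API): the optional 16-bit `multicastMask` operand of the new global-to-shared-cluster op selects destination CTAs by bit position, where bit `i` corresponds to the CTA whose `nvvm.read.ptx.sreg.ctaid` is `i`, as described in the op documentation in the diff below. The helper name is illustrative.

```python
def multicast_targets(mask: int) -> list[int]:
    """Return the ctaid values selected by a 16-bit multicast mask.

    Bit i set in the mask means the CTA whose ctaid is i
    receives a copy of the data.
    """
    assert 0 <= mask < (1 << 16), "multicast mask is 16 bits"
    return [cta for cta in range(16) if mask & (1 << cta)]

# Mask 0b1010 selects the CTAs with ctaid 1 and 3.
print(multicast_targets(0b1010))  # → [1, 3]
```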

PR llvm#122344 adds intrinsics for Bulk Async Copy
(non-tensor variants) using TMA. This patch
adds the corresponding NVVM Dialect Ops.

Signed-off-by: Durgadoss R <[email protected]>
@llvmbot
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/123186.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+144)
  • (added) mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir (+35)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 04042903e343ed..8300d13c16820c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2138,6 +2138,150 @@ def NVVM_CpAsyncBulkTensorReduceOp :
   }];
 }
 
+def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
+  NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
+  let summary = "Async bulk copy from global memory to Shared cluster memory";
+  let description = [{
+    Initiates an asynchronous copy operation from global memory to cluster's
+    shared memory.
+
+    The `multicastMask` operand is optional. When it is present, the Op copies
+    data from global memory to shared memory of multiple CTAs in the cluster.
+    Operand `multicastMask` specifies the destination CTAs in the cluster such
+    that each bit position in the 16-bit `multicastMask` operand corresponds to
+    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
+
+    The `l2CacheHint` operand is optional; it specifies the cache
+    eviction policy that may be used during the memory access.
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
+  }];
+
+  let arguments = (ins
+    LLVM_PointerShared:$dstMem,
+    LLVM_PointerGlobal:$srcMem,
+    LLVM_PointerShared:$mbar,
+    I32:$size,
+    Optional<I16>:$multicastMask,
+    Optional<I64>:$l2CacheHint);
+
+  let assemblyFormat = [{
+    $dstMem `,` $srcMem `,` $mbar `,` $size
+    (`multicast_mask` `=` $multicastMask^ )?
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict  `:` type($dstMem) `,` type($srcMem)
+  }];
+
+  string llvmBuilder = [{
+    // Arguments to the intrinsic:
+    // dst, mbar, src, size
+    // multicast_mask, cache_hint,
+    // flag for multicast_mask,
+    // flag for cache_hint
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($dstMem);
+    translatedOperands.push_back($mbar);
+    translatedOperands.push_back($srcMem);
+    translatedOperands.push_back($size);
+
+    // Multicast, if available
+    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+    auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
+    bool isMulticast = op.getMulticastMask() ? true : false;
+    translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);
+
+    // Cachehint, if available
+    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
+    bool isCacheHint = op.getL2CacheHint() ? true : false;
+    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
+
+    // Flag arguments for multicast and cachehint
+    translatedOperands.push_back(builder.getInt1(isMulticast));
+    translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+    createIntrinsicCall(builder,
+      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands);
+  }];
+}
+
+def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
+  NVVM_Op<"cp.async.bulk.shared.cluster.shared.cta"> {
+  let summary = "Async bulk copy from Shared CTA memory to Shared cluster memory";
+  let description = [{
+    Initiates an asynchronous copy operation from Shared CTA memory to Shared
+    cluster memory.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
+  }];
+
+  let arguments = (ins
+    LLVM_PointerShared:$dstMem,
+    LLVM_PointerShared:$srcMem,
+    LLVM_PointerShared:$mbar,
+    I32:$size);
+
+  let assemblyFormat = [{
+    $dstMem `,` $srcMem `,` $mbar `,` $size
+    attr-dict  `:` type($dstMem) `,` type($srcMem)
+  }];
+
+  string llvmBuilder = [{
+    createIntrinsicCall(builder,
+      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster,
+      {$dstMem, $mbar, $srcMem, $size});
+  }];
+}
+
+def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
+  NVVM_Op<"cp.async.bulk.global.shared.cta"> {
+  let summary = "Async bulk copy from Shared CTA memory to Global memory";
+  let description = [{
+    Initiates an asynchronous copy operation from Shared CTA memory to
+    global memory.
+
+    The `l2CacheHint` operand is optional; it specifies the cache
+    eviction policy that may be used during the memory access.
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
+  }];
+
+  let arguments = (ins
+    LLVM_PointerGlobal:$dstMem,
+    LLVM_PointerShared:$srcMem,
+    I32:$size,
+    Optional<I64>:$l2CacheHint);
+
+  let assemblyFormat = [{
+    $dstMem `,` $srcMem `,` $size
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict  `:` type($dstMem) `,` type($srcMem)
+  }];
+
+  string llvmBuilder = [{
+    // Arguments to the intrinsic:
+    // dst, src, size, cache_hint,
+    // Flag for cache_hint
+    //
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($dstMem);
+    translatedOperands.push_back($srcMem);
+    translatedOperands.push_back($size);
+
+    // Cachehint, if available
+    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
+    bool isCacheHint = op.getL2CacheHint() ? true : false;
+    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
+
+    // Flag argument for cachehint
+    translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+    createIntrinsicCall(builder,
+      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
new file mode 100644
index 00000000000000..aa2d680f5117e8
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cluster
+llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %mc : i16, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i16 0, i64 0, i1 false, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 0, i64 %[[CH:.*]], i1 false, i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC:.*]], i64 0, i1 true, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cluster(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i16 %[[MC]], i64 %[[CH]], i1 true, i1 true)
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
+
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
+
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mc : !llvm.ptr<3>, !llvm.ptr<1>
+
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size multicast_mask = %mc l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
+  llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
+llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<3>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(3) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
+  nvvm.cp.async.bulk.shared.cluster.shared.cta %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<3>
+  llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
+llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
+  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>
+
+  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
+  llvm.return
+}
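A minimal Python sketch (a hypothetical helper, not an LLVM API) of the operand packing performed by the `llvmBuilder` for the global-to-shared-cluster op above: absent optional operands are replaced by zero placeholders, and two trailing `i1` flags tell the intrinsic whether the multicast mask and cache hint are meaningful.

```python
def pack_g2s_operands(dst, mbar, src, size,
                      multicast_mask=None, l2_cache_hint=None):
    """Model the intrinsic argument list built by the llvmBuilder:
    dst, mbar, src, size, multicast_mask, cache_hint,
    flag for multicast_mask, flag for cache_hint.
    """
    has_mask = multicast_mask is not None
    has_hint = l2_cache_hint is not None
    return [
        dst, mbar, src, size,
        multicast_mask if has_mask else 0,  # i16 placeholder when unused
        l2_cache_hint if has_hint else 0,   # i64 placeholder when unused
        has_mask,                           # i1 flag: multicast present?
        has_hint,                           # i1 flag: cache hint present?
    ]

# No optionals: zero placeholders, both flags false, matching the
# first CHECK line of the lit test above.
print(pack_g2s_operands("dst", "mbar", "src", 256))
# → ['dst', 'mbar', 'src', 256, 0, 0, False, False]
```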

@durga4github
Contributor Author

@grypp, please help with the review.

@durga4github durga4github merged commit 0f9e913 into llvm:main Jan 21, 2025
11 checks passed
@durga4github durga4github deleted the durgadossr/mlir_tma_copy branch January 21, 2025 08:27