[NVPTX] Add cta_group support to TMA G2S intrinsics #143178

Merged: 1 commit merged into llvm:main from durgadossr/nvptx_tma_load_2cta on Jun 12, 2025

Conversation

durga4github (Contributor) commented Jun 6, 2025

This patch extends the TMA G2S intrinsics with support for
cta_group::1/2, available from Blackwell onwards. The existing
intrinsics are auto-upgraded with a default value of '0' for
the cta_group flag operand. A brief usage sketch follows the
notes below.

  • lit tests are added for all combinations of the newer variants.
  • Negative tests are added to validate the error handling
    when the value of the cta_group flag falls out of range.
  • The generated PTX is verified with a 12.8 ptxas executable.
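
For illustration, here is a minimal usage sketch of the upgraded intrinsic form. It follows the tile.1d signature shown in the (truncated, earlier-revision) diff below, where the cta_group flag is the trailing i1 operand; the function name and operand values are placeholders, and the operand encoding for cta_group::1/2 in the merged revision may differ.

```llvm
; Minimal sketch: new-style G2S tile.1d call with the trailing cta_group flag.
; All three trailing flags must be compile-time constants.
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7), ptr addrspace(3), ptr, i32, i16, i64, i1, i1, i1)

define void @g2s_tile_1d_cta_group(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0) {
  ; multicast off, cache_hint off, cta_group flag on
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 0, i64 0, i1 0, i1 0, i1 1)
  ret void
}
```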

llvmbot (Member) commented Jun 6, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-nvptx

Author: Durgadoss R (durga4github)

Changes

This patch extends the TMA G2S intrinsics with 2-CTA mode
support, available from Blackwell onwards. The existing
intrinsics are auto-upgraded with a default value of '0'
for the is_2cta_mode flag operand.

lit tests are added for all combinations of the new variant.
The generated PTX is verified with a 12.8 ptxas executable.


Patch is 58.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143178.diff

9 Files Affected:

  • (modified) llvm/docs/NVPTXUsage.rst (+13-9)
  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+3-3)
  • (modified) llvm/lib/IR/AutoUpgrade.cpp (+83-21)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+7)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+14-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+12-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+8)
  • (added) llvm/test/CodeGen/NVPTX/cp-async-bulk-tensor-g2s-2cta.ll (+435)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 8bb0f2ed17c32..ec73939345731 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1016,7 +1016,7 @@ Syntax:
 
 .. code-block:: llvm
 
-  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch, i1 %flag_cta_group)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(..., i32 %d0, i32 %d1, ...)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.3d(..., i32 %d0, i32 %d1, i32 %d2, ...)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)
@@ -1034,18 +1034,22 @@ source tensor is preserved at the destination. The dimension of the
 tensor data ranges from 1d to 5d with the coordinates specified
 by the ``i32 %d0 ... i32 %d4`` arguments.
 
-* The last two arguments to these intrinsics are boolean flags
-  indicating support for cache_hint and/or multicast modifiers.
-  These flag arguments must be compile-time constants. The backend
-  looks through these flags and lowers the intrinsics appropriately.
+* The last three arguments to these intrinsics are boolean flags
+  indicating support for multicast, cache_hint and cta_group::2
+  modifiers. These flag arguments must be compile-time constants.
+  The backend looks through these flags and lowers the intrinsics
+  appropriately.
 
-* The Nth argument (denoted by ``i1 flag_ch``) when set, indicates
+* The argument denoted by ``i1 flag_ch`` when set, indicates
   a valid cache_hint (``i64 %ch``) and generates the ``.L2::cache_hint``
   variant of the PTX instruction.
 
-* The [N-1]th argument (denoted by ``i1 flag_mc``) when set, indicates
-  the presence of a multicast mask (``i16 %mc``) and generates the PTX
-  instruction with the ``.multicast::cluster`` modifier.
+* The argument denoted by ``i1 flag_mc`` when set, indicates
+  the presence of a multicast mask (``i16 %mc``) and generates
+  the PTX instruction with the ``.multicast::cluster`` modifier.
+
+* The argument denoted by ``i1 flag_cta_group`` when set, generates
+  the ``.cta_group::2`` variant of the PTX instruction.
 
 For more information, refer PTX ISA
 `<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`_.
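
To make the flag semantics documented above concrete, here is a small sketch (function and value names are illustrative) of how different flag settings are expected to select PTX variants:

```llvm
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7), ptr addrspace(3), ptr, i32, i16, i64, i1, i1, i1)

define void @flag_examples(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch) {
  ; flag_mc=1, flag_ch=0, flag_cta_group=0: expected to lower to the
  ; ".multicast::cluster" variant; %ch is ignored.
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch, i1 1, i1 0, i1 0)
  ; flag_mc=0, flag_ch=1, flag_cta_group=1: expected to lower to the
  ; ".L2::cache_hint" variant with the ".cta_group::2" modifier; %mc is ignored.
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch, i1 0, i1 1, i1 1)
  ret void
}
```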
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 91e7d188c8533..1f08b282bcc3b 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2025,10 +2025,10 @@ foreach dim = 1...5 in {
                       [llvm_i16_ty,                 // cta_mask
                        llvm_i64_ty]),               // cache_hint
           [llvm_i1_ty,                              // Flag for cta_mask
-           llvm_i1_ty],                             // Flag for cache_hint
+           llvm_i1_ty,                              // Flag for cache_hint
+           llvm_i1_ty],                             // Flag for is_2cta_mode
           [IntrConvergent,
-           WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
-           NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>]>;
+           WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>]>;
 
     def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
       DefaultAttrsIntrinsicFlags<[],
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index 7ba6d411bc7b5..2fc29bbb83bbc 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -939,6 +939,53 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
   return false; // No other 'arm.*', 'aarch64.*'.
 }
 
+static Intrinsic::ID shouldUpgradeNVPTXTMAG2SIntrinsics(Function *F,
+                                                        StringRef Name) {
+  if (Name.consume_front("cp.async.bulk.tensor.g2s.")) {
+    Intrinsic::ID ID =
+        StringSwitch<Intrinsic::ID>(Name)
+            .Case("im2col.3d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
+            .Case("im2col.4d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
+            .Case("im2col.5d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
+            .Case("tile.1d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
+            .Case("tile.2d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
+            .Case("tile.3d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
+            .Case("tile.4d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
+            .Case("tile.5d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
+            .Default(Intrinsic::not_intrinsic);
+
+    if (ID == Intrinsic::not_intrinsic)
+      return ID;
+
+    // These intrinsics may need upgrade for two reasons:
+    // (1) When the address-space of the first argument is shared[AS=3]
+    //     (and we upgrade it to use shared_cluster address-space[AS=7])
+    if (F->getArg(0)->getType()->getPointerAddressSpace() ==
+        NVPTXAS::ADDRESS_SPACE_SHARED)
+      return ID;
+
+    // (2) When there are only two boolean flag arguments at the end:
+    //
+    // The last three parameters of the older version of these
+    // intrinsics are: arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag
+    //
+    // The newer version has three boolean flags at the end:
+    // arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag, i1 cta_group_flag
+    //
+    // So, when the type of the [N-3]rd argument is "not i1", then
+    // it is the older version and we need to upgrade.
+    size_t FlagStartIndex = F->getFunctionType()->getNumParams() - 3;
+    Type *ArgType = F->getFunctionType()->getParamType(FlagStartIndex);
+    if (!ArgType->isIntegerTy(1))
+      return ID;
+  }
+
+  return Intrinsic::not_intrinsic;
+}
+
 static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
                                                               StringRef Name) {
   if (Name.consume_front("mapa.shared.cluster"))
@@ -953,22 +1000,6 @@ static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
                   Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster)
             .Case("shared.cta.to.cluster",
                   Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster)
-            .Case("tensor.g2s.im2col.3d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
-            .Case("tensor.g2s.im2col.4d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
-            .Case("tensor.g2s.im2col.5d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
-            .Case("tensor.g2s.tile.1d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
-            .Case("tensor.g2s.tile.2d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
-            .Case("tensor.g2s.tile.3d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
-            .Case("tensor.g2s.tile.4d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
-            .Case("tensor.g2s.tile.5d",
-                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
             .Default(Intrinsic::not_intrinsic);
 
     if (ID != Intrinsic::not_intrinsic)
@@ -1334,6 +1365,14 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
         return true;
       }
 
+      // Upgrade TMA copy G2S Intrinsics
+      IID = shouldUpgradeNVPTXTMAG2SIntrinsics(F, Name);
+      if (IID != Intrinsic::not_intrinsic) {
+        rename(F);
+        NewFn = Intrinsic::getOrInsertDeclaration(F->getParent(), IID);
+        return true;
+      }
+
       // The following nvvm intrinsics correspond exactly to an LLVM idiom, but
       // not to an intrinsic alone.  We expand them in UpgradeIntrinsicCall.
       //
@@ -4813,7 +4852,18 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
     return;
   }
   case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
-  case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster:
+  case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster: {
+    // Create a new call with the correct address space.
+    SmallVector<Value *, 4> Args(CI->args());
+    Args[0] = Builder.CreateAddrSpaceCast(
+        Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+
+    NewCall = Builder.CreateCall(NewFn, Args);
+    NewCall->takeName(CI);
+    CI->replaceAllUsesWith(NewCall);
+    CI->eraseFromParent();
+    return;
+  }
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
@@ -4822,10 +4872,22 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: {
-    // Create a new call with the correct address space.
-    SmallVector<Value *, 4> Args(CI->args());
-    Args[0] = Builder.CreateAddrSpaceCast(
-        Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+    SmallVector<Value *, 16> Args(CI->args());
+
+    // Create AddrSpaceCast to shared_cluster if needed.
+    // This handles case (1) in shouldUpgradeNVPTXTMAG2SIntrinsics().
+    unsigned AS = CI->getArgOperand(0)->getType()->getPointerAddressSpace();
+    if (AS == NVPTXAS::ADDRESS_SPACE_SHARED)
+      Args[0] = Builder.CreateAddrSpaceCast(
+          Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+
+    // Attach the flag argument for cta_group, with a
+    // default value of 0. This handles case (2) in
+    // shouldUpgradeNVPTXTMAG2SIntrinsics().
+    size_t NumArgs = CI->arg_size();
+    Value *FlagArg = CI->getArgOperand(NumArgs - 3);
+    if (!FlagArg->getType()->isIntegerTy(1))
+      Args.push_back(ConstantInt::get(Builder.getInt1Ty(), 0));
 
     NewCall = Builder.CreateCall(NewFn, Args);
     NewCall->takeName(CI);
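
A sketch of the two auto-upgrade cases in IR terms (names and values are illustrative, not taken from the patch's tests):

```llvm
; Shape being upgraded (older bitcode): shared (addrspace 3) destination and
; only two trailing i1 flags:
;
;   call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
;       ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tmap,
;       i32 %d0, i16 %mc, i64 %ch, i1 0, i1 0)

; Expected shape after auto-upgrade: the destination is addrspacecast to the
; shared_cluster space (addrspace 7) and a default-0 cta_group flag is appended.
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7), ptr addrspace(3), ptr, i32, i16, i64, i1, i1, i1)

define void @upgraded_form(ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch) {
  %dst.cast = addrspacecast ptr addrspace(3) %dst to ptr addrspace(7)
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst.cast, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch, i1 0, i1 0, i1 0)
  ret void
}
```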
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index b4616b64bad15..268f713a660b9 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -437,3 +437,10 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
   llvm_unreachable(
       "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
 }
+
+void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
+                                     raw_ostream &O) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  if (MO.getImm())
+    O << ".cta_group::2";
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index a2dd772cd86d0..f73af7a3f2c6e 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -51,6 +51,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
   void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
+  void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 32223bf3d601e..24929610fe5e4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2556,19 +2556,25 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
   // We have {Chain, Intrinsic-ID} followed by the actual intrisic args:
   // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2}
   // multicast, cache_hint,
-  // multicast_flag, cache_hint_flag}
+  // multicast_flag, cache_hint_flag, 2cta_mode_flag}
   // NumOperands = {Chain, IID} + {Actual intrinsic args}
-  //             = {2}          + {7 + dims + im2col_offsets}
+  //             = {2}          + {8 + dims + im2col_offsets}
   size_t NumOps = N->getNumOperands();
   size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
-                            : (NumOps - 9);
+                            : (NumOps - 10);
   // Offsets is always 'NumDims - 2' and only for im2col mode
   size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
-  bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
-  bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
+  bool Is2CTAMode = N->getConstantOperandVal(NumOps - 1) == 1;
+  bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1;
+  bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1;
   size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src}
   size_t MultiCastIdx = NumBaseArgs + 2;         // for Chain and IID
 
+  if (Is2CTAMode && !Subtarget->hasCpAsyncBulkTensor2CTASupport())
+    report_fatal_error(
+        formatv("CpAsyncBulkTensorG2S 2CTA mode is not supported on sm_{}",
+                Subtarget->getSmVersion()));
+
   SDLoc DL(N);
   SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs));
 
@@ -2580,6 +2586,9 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
   if (IsCacheHint)
     Ops.push_back(N->getOperand(MultiCastIdx + 1));
 
+  // Flag for 2-CTA mode
+  Ops.push_back(CurDAG->getTargetConstant(Is2CTAMode, DL, MVT::i1));
+
   // Finally, the chain operand
   Ops.push_back(N->getOperand(0));
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index d14c03791febb..e53e806816cf3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -593,10 +593,14 @@ class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
                      # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
 }
 
+def CTAGroupFlags : Operand<i1> {
+  let PrintMethod = "printCTAGroup";
+}
+
 multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> {
   defvar dims_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "d" # i));
   defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
-  defvar asm_str_default = " [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
+  defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
   defvar rc = !if(is_shared32, Int32Regs, Int64Regs);
 
   defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);
@@ -610,19 +614,22 @@ multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode>
     !strconcat(asm_str_default, im2col_asm_str), asm_str_default);
 
   def NAME: NVPTXInst<(outs),
-            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag),
+            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)),
             !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";"), []>,
             Requires<[hasPTX<80>, hasSM<90>]>;
   def NAME # _MC: NVPTXInst<(outs),
-                  !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc)),
+                  !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                       (ins Int16Regs:$mc, CTAGroupFlags:$cg)),
                   !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;"), []>,
                   Requires<[hasPTX<80>, hasSM<90>]>;
   def NAME # _CH: NVPTXInst<(outs),
-                  !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int64Regs:$ch)),
+                  !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                       (ins Int64Regs:$ch, CTAGroupFlags:$cg)),
                   !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;"), []>,
                   Requires<[hasPTX<80>, hasSM<90>]>;
   def NAME # _MC_CH: NVPTXInst<(outs),
-                     !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc, Int64Regs:$ch)),
+                     !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                          (ins Int16Regs:$mc, Int64Regs:$ch, CTAGroupFlags:$cg)),
                      !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;"), []>,
                      Requires<[hasPTX<80>, hasSM<90>]>;
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 5136b1ee28502..8e4b96e9ac380 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -117,6 +117,14 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
     return HasTcgen05 && PTXVersion >= 86;
   }
 
+  // TMA G2S copy 2cta mode support
+  bool hasCpAsyncBulkTensor2CTASupport() const {
+    // TODO: Update/tidy-up after the family-conditional support arrives
+    return ((FullSmVersion == 1001 || FullSmVersion == 1011) &&
+            PTXVersion >= 86) ||
+           (FullSmVersion == 1031 && PTXVersion >= 88);
+  }
+
   // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
   // terminates a basic block. Instead, it would assume that control flow
   // continued to the next instruction. The next instruction could be in the
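
Tied to the subtarget gate above and the negative tests mentioned in the description, here is a hypothetical negative-test sketch; the RUN line, the use of `not --crash`, and the exact diagnostic wording are assumptions based on the report_fatal_error added in NVPTXISelDAGToDAG.cpp:

```llvm
; Hypothetical negative test (not part of this patch); RUN-line details are assumptions.
; RUN: not --crash llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx80 2>&1 | FileCheck %s
; CHECK: CpAsyncBulkTensorG2S 2CTA mode is not supported on sm_90

declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7), ptr addrspace(3), ptr, i32, i16, i64, i1, i1, i1)

define void @cta_group_on_sm_90(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tmap, i32 %d0) {
  ; cta_group flag set on a target without 2-CTA support should be rejected.
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 0, i64 0, i1 0, i1 0, i1 1)
  ret void
}
```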
diff --git a/llvm/test/CodeGen/NVPTX/cp-async-bulk-tensor-g2s-2cta.ll b/llvm/test/CodeGen/NVPTX/cp-async-bulk-tensor-g2s-2cta.ll
new file mode 100644
index 0000000000000..d8b2d42c579e5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/cp-async-bulk-tensor-g2s-2cta.ll
@@ -0,0 +1,435 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86| FileCheck --check-prefixes=CHECK-PTX64 %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 --nvptx-short-ptr| FileCheck --check-prefixes=CHECK-PTX-SHARED32 %s
+; RUN: %if ptxas-12.3 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86| %ptxas-verify -arch=sm_100a %}
+; RUN: %if ptxas-12.3 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 --nvptx-short-ptr| %ptxas-verify -arch=sm_100a %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.3d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.4d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i32 %d3, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.5d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i16 %im2col0, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.4d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i32 %d3, i16 %im2col0, i16 %im2col1, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.5d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, i16 %im2col0, i16 %im2col1, i16 %im2col2, i16 %mc, i64 %ch, i1 %f1, i1 %f2, i1 %f3);
+
+; CHECK-LABEL: test_cp_async_bulk_tensor_g2s_tile_1d
+define void @test_cp_async_bulk_tensor_g2s_tile_1d(ptr addrspace(7) %d, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch) {
+; CHECK-PTX64-LABEL: test_cp_async_...
[truncated]

@durga4github durga4github force-pushed the durgadossr/nvptx_tma_load_2cta branch from 203f2b6 to b85959b on June 9, 2025 at 14:22
@durga4github durga4github changed the title from "[NVPTX] Add 2-CTA mode support to TMA G2S intrinsics" to "[NVPTX] Add cta_group support to TMA G2S intrinsics" on Jun 9, 2025
@durga4github durga4github force-pushed the durgadossr/nvptx_tma_load_2cta branch 2 times, most recently from 2a48e6b to 7d4012d on June 11, 2025 at 11:46
This patch extends the TMA G2S intrinsics with
the 2-CTA mode support available from Blackwell
onwards. The existing intrinsics are auto-upgraded with a
default value of '0' for the `cta_group` flag.

lit tests are added for all combinations of the new
variant. The generated PTX is verified with a 12.8
ptxas executable.

Signed-off-by: Durgadoss R <[email protected]>
@durga4github durga4github force-pushed the durgadossr/nvptx_tma_load_2cta branch from 7d4012d to f09fdd4 on June 12, 2025 at 07:45
@durga4github durga4github merged commit 3e5d50f into llvm:main on Jun 12, 2025
8 checks passed
@durga4github durga4github deleted the durgadossr/nvptx_tma_load_2cta branch on June 12, 2025 at 09:50
llvm-ci (Collaborator) commented Jun 12, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-flang-rhel-clang running on ppc64le-flang-rhel-test while building llvm at step 6 "test-build-unified-tree-check-flang".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/30587

Here is the relevant piece of the build log for reference:
Step 6 (test-build-unified-tree-check-flang) failure: test (failure)
******************** TEST 'Flang :: Semantics/modfile75.F90' FAILED ********************
Exit Code: 2

Command Output (stderr):
--
/home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=1 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 && /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=2 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 && /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -fc1 -fdebug-unparse /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 | /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 # RUN: at line 1
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=1 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=2 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -fc1 -fdebug-unparse /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
error: Semantic errors in /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
/home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90:15:11: error: Must be a constant value
    integer(c_int) n
            ^^^^^
FileCheck error: '<stdin>' is empty.
FileCheck command line:  /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90

--

********************


tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
This patch extends the TMA G2S intrinsics with the
support for cta_group::1/2 available from Blackwell onwards.
The existing intrinsics are auto-upgraded with a default
value of '0' for the `cta_group` flag operand.

* lit tests are added for all combinations of the newer variants.
* Negative tests are added to validate the error-handling 
   when the value of the cta_group flag falls out-of-range.
* The generated PTX is verified with a 12.8 ptxas executable.

Signed-off-by: Durgadoss R <[email protected]>
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
This patch extends the TMA G2S intrinsics with the
support for cta_group::1/2 available from Blackwell onwards.
The existing intrinsics are auto-upgraded with a default
value of '0' for the `cta_group` flag operand.

* lit tests are added for all combinations of the newer variants.
* Negative tests are added to validate the error-handling 
   when the value of the cta_group flag falls out-of-range.
* The generated PTX is verified with a 12.8 ptxas executable.

Signed-off-by: Durgadoss R <[email protected]>