[NVPTX] Add cta_group support to TMA G2S intrinsics #143178

Merged: 1 commit, Jun 12, 2025
32 changes: 20 additions & 12 deletions llvm/docs/NVPTXUsage.rst
@@ -1016,7 +1016,7 @@ Syntax:

.. code-block:: llvm

declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch, i32 %flag_cta_group)
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(..., i32 %d0, i32 %d1, ...)
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.3d(..., i32 %d0, i32 %d1, i32 %d2, ...)
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)
@@ -1034,18 +1034,26 @@ source tensor is preserved at the destination. The dimension of the
tensor data ranges from 1d to 5d with the coordinates specified
by the ``i32 %d0 ... i32 %d4`` arguments.

* The last two arguments to these intrinsics are boolean flags
indicating support for cache_hint and/or multicast modifiers.
These flag arguments must be compile-time constants. The backend
looks through these flags and lowers the intrinsics appropriately.
* The last three arguments to these intrinsics are flags
indicating support for multicast, cache_hint and cta_group::1/2
modifiers. These flag arguments must be compile-time constants.
The backend looks through these flags and lowers the intrinsics
appropriately.

* The Nth argument (denoted by ``i1 flag_ch``) when set, indicates
* The argument denoted by ``i1 %flag_ch`` when set, indicates
a valid cache_hint (``i64 %ch``) and generates the ``.L2::cache_hint``
variant of the PTX instruction.

* The [N-1]th argument (denoted by ``i1 flag_mc``) when set, indicates
the presence of a multicast mask (``i16 %mc``) and generates the PTX
instruction with the ``.multicast::cluster`` modifier.
* The argument denoted by ``i1 %flag_mc`` when set, indicates
the presence of a multicast mask (``i16 %mc``) and generates
the PTX instruction with the ``.multicast::cluster`` modifier.

* The argument denoted by ``i32 %flag_cta_group`` takes values within
the range [0, 3) i.e. {0,1,2}. When the value of ``%flag_cta_group``
is not within the range, it may raise an error from the Verifier.
The default value is '0' with no cta_group modifier in the
instruction. The values of '1' and '2' lower to ``cta_group::1``
and ``cta_group::2`` variants of the PTX instruction respectively.

For more information, refer PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`_.
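
To make the new argument order concrete, here is a minimal sketch (not taken from the patch's tests; the operands are placeholders) of tile-mode calls carrying the trailing ``i32`` cta_group flag:

.. code-block:: llvm

    ; default: no multicast, no cache_hint, no cta_group modifier (flag = 0)
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 0, i64 0, i1 0, i1 0, i32 0)

    ; same copy, requesting the cta_group::2 variant (flag = 2)
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 0, i64 0, i1 0, i1 0, i32 2)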
@@ -1058,7 +1066,7 @@ Syntax:

.. code-block:: llvm

declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i32 %d1, i32 %d2, i16 %im2col0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch)
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i32 %d1, i32 %d2, i16 %im2col0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch, i32 %flag_cta_group)
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i16 %im2col0, i16 %im2col1, ...)
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.5d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, i16 %im2col0, i16 %im2col1, i16 %im2col2, ...)

@@ -1074,8 +1082,8 @@ are unrolled into a single dimensional column at the destination. In this
mode, the tensor has to be at least three-dimensional. Along with the tensor
coordinates, im2col offsets are also specified (denoted by
``i16 im2col0...i16 %im2col2``). The number of im2col offsets is two less
than the number of dimensions of the tensor operation. The last two arguments
to these intrinsics are boolean flags, with the same functionality as described
than the number of dimensions of the tensor operation. The last three arguments
to these intrinsics are flags, with the same functionality as described
in the ``tile`` mode intrinsics above.

For more information, refer PTX ISA
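
The same trailing flag applies in im2col mode; a hypothetical 3d call (placeholder operands, shown only to illustrate the argument order) might look like:

.. code-block:: llvm

    ; multicast and cache_hint enabled, cta_group::1 requested
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i32 %d1, i32 %d2, i16 %off0, i16 %mc, i64 %ch, i1 1, i1 1, i32 1)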
32 changes: 19 additions & 13 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2020,20 +2020,26 @@ foreach dim = 1...5 in {
defvar num_im2col_offsets = !if(is_im2col, !add(dim, -2), 0);
defvar im2col_offsets_args = !listsplat(llvm_i16_ty, num_im2col_offsets);

defvar g2s_params = !listconcat(
[llvm_shared_cluster_ptr_ty, // dst_ptr
llvm_shared_ptr_ty, // mbarrier_ptr
llvm_ptr_ty], // tensormap_ptr
tensor_dim_args, // actual tensor dims
im2col_offsets_args, // im2col offsets
[llvm_i16_ty, // cta_mask
llvm_i64_ty]); // cache_hint
defvar g2s_flags = [llvm_i1_ty, // Flag for cta_mask
llvm_i1_ty, // Flag for cache_hint
llvm_i32_ty]; // Flag for cta_group
defvar cta_group_idx = !add(
!size(g2s_params),
!sub(!size(g2s_flags), 1));
defvar g2s_props = [IntrConvergent,
WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
// Allowed values for cta_group are {0,1,2} i.e [0, 3).
Range<ArgIndex<cta_group_idx>, 0, 3>];
def int_nvvm_cp_async_bulk_tensor_g2s_ # mode # _ # dim # d :
DefaultAttrsIntrinsicFlags<[],
!listconcat([llvm_shared_cluster_ptr_ty, // dst_shared_cluster_ptr
llvm_shared_ptr_ty, // mbarrier_smem_ptr
llvm_ptr_ty], // tensormap_ptr
tensor_dim_args, // actual tensor dims
im2col_offsets_args, // im2col offsets
[llvm_i16_ty, // cta_mask
llvm_i64_ty]), // cache_hint
[llvm_i1_ty, // Flag for cta_mask
llvm_i1_ty], // Flag for cache_hint
[IntrConvergent,
WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>]>;
DefaultAttrsIntrinsicFlags<[], g2s_params, g2s_flags, g2s_props>;

def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
DefaultAttrsIntrinsicFlags<[],
9 changes: 9 additions & 0 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -38,6 +38,15 @@ enum class TMAReductionOp : uint8_t {
XOR = 7,
};

// Enum to represent the cta_group::1 and
// cta_group::2 variants in TMA/TCGEN05 family of
// PTX instructions.
enum class CTAGroupKind : uint8_t {
CG_NONE = 0, // default with no cta_group modifier
CG_1 = 1, // cta_group::1 modifier
CG_2 = 2, // cta_group::2 modifier
};

inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
case Intrinsic::nvvm_f2i_rm_ftz:
104 changes: 83 additions & 21 deletions llvm/lib/IR/AutoUpgrade.cpp
@@ -945,6 +945,53 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
return false; // No other 'arm.*', 'aarch64.*'.
}

static Intrinsic::ID shouldUpgradeNVPTXTMAG2SIntrinsics(Function *F,
StringRef Name) {
if (Name.consume_front("cp.async.bulk.tensor.g2s.")) {
Intrinsic::ID ID =
StringSwitch<Intrinsic::ID>(Name)
.Case("im2col.3d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
.Case("im2col.4d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
.Case("im2col.5d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
.Case("tile.1d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
.Case("tile.2d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
.Case("tile.3d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
.Case("tile.4d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
.Case("tile.5d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
.Default(Intrinsic::not_intrinsic);

if (ID == Intrinsic::not_intrinsic)
return ID;

// These intrinsics may need upgrade for two reasons:
// (1) When the address-space of the first argument is shared[AS=3]
// (and we upgrade it to use shared_cluster address-space[AS=7])
if (F->getArg(0)->getType()->getPointerAddressSpace() ==
NVPTXAS::ADDRESS_SPACE_SHARED)
return ID;

// (2) When there are only two boolean flag arguments at the end:
//
// The last three parameters of the older version of these
// intrinsics are: arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag
//
// The newer version reads as:
// arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag, i32 cta_group_flag
//
// So, when the type of the [N-3]rd argument is "not i1", then
// it is the older version and we need to upgrade.
size_t FlagStartIndex = F->getFunctionType()->getNumParams() - 3;
Type *ArgType = F->getFunctionType()->getParamType(FlagStartIndex);
if (!ArgType->isIntegerTy(1))
return ID;
}

return Intrinsic::not_intrinsic;
}
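
To make the two upgrade triggers concrete, the following hedged sketch (illustrative declarations, not from the patch) shows older signatures this routine would flag for upgrade:

.. code-block:: llvm

    ; Case (1): destination pointer still in the plain shared address space (AS 3)
    declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(3), ptr addrspace(3), ptr, i32, i16, i64, i1, i1)

    ; Case (2): shared_cluster destination (AS 7) but only two trailing i1 flags,
    ; so the third-from-last parameter is i64 rather than i1
    declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(7), ptr addrspace(3), ptr, i32, i32, i32, i16, i16, i64, i1, i1)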

static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
StringRef Name) {
if (Name.consume_front("mapa.shared.cluster"))
@@ -959,22 +1006,6 @@ static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster)
.Case("shared.cta.to.cluster",
Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster)
.Case("tensor.g2s.im2col.3d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
.Case("tensor.g2s.im2col.4d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
.Case("tensor.g2s.im2col.5d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
.Case("tensor.g2s.tile.1d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
.Case("tensor.g2s.tile.2d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
.Case("tensor.g2s.tile.3d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
.Case("tensor.g2s.tile.4d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
.Case("tensor.g2s.tile.5d",
Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
.Default(Intrinsic::not_intrinsic);

if (ID != Intrinsic::not_intrinsic)
@@ -1339,6 +1370,14 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
return true;
}

// Upgrade TMA copy G2S Intrinsics
IID = shouldUpgradeNVPTXTMAG2SIntrinsics(F, Name);
if (IID != Intrinsic::not_intrinsic) {
rename(F);
NewFn = Intrinsic::getOrInsertDeclaration(F->getParent(), IID);
return true;
}

// The following nvvm intrinsics correspond exactly to an LLVM idiom, but
// not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
//
@@ -4831,7 +4870,18 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
return;
}
case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster:
case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster: {
// Create a new call with the correct address space.
SmallVector<Value *, 4> Args(CI->args());
Args[0] = Builder.CreateAddrSpaceCast(
Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));

NewCall = Builder.CreateCall(NewFn, Args);
NewCall->takeName(CI);
CI->replaceAllUsesWith(NewCall);
CI->eraseFromParent();
return;
}
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
Expand All @@ -4840,10 +4890,22 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: {
// Create a new call with the correct address space.
SmallVector<Value *, 4> Args(CI->args());
Args[0] = Builder.CreateAddrSpaceCast(
Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
SmallVector<Value *, 16> Args(CI->args());

// Create AddrSpaceCast to shared_cluster if needed.
// This handles case (1) in shouldUpgradeNVPTXTMAG2SIntrinsics().
unsigned AS = CI->getArgOperand(0)->getType()->getPointerAddressSpace();
if (AS == NVPTXAS::ADDRESS_SPACE_SHARED)
Args[0] = Builder.CreateAddrSpaceCast(
Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));

// Attach the flag argument for cta_group, with a
// default value of 0. This handles case (2) in
// shouldUpgradeNVPTXTMAG2SIntrinsics().
size_t NumArgs = CI->arg_size();
Value *FlagArg = CI->getArgOperand(NumArgs - 3);
if (!FlagArg->getType()->isIntegerTy(1))
Args.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));

NewCall = Builder.CreateCall(NewFn, Args);
NewCall->takeName(CI);
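
A hedged before/after sketch of the call-site rewrite above for an older 1d tile call (operand names are illustrative, not from the patch's tests):

.. code-block:: llvm

    ; before upgrade: shared (AS 3) destination, only two trailing i1 flags
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch, i1 0, i1 0)

    ; after upgrade: destination cast to shared_cluster (AS 7) and a default
    ; i32 0 (no cta_group modifier) appended as the new final flag
    %dst.cast = addrspacecast ptr addrspace(3) %dst to ptr addrspace(7)
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst.cast, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i16 %mc, i64 %ch, i1 0, i1 0, i32 0)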
19 changes: 19 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -437,3 +437,22 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
llvm_unreachable(
"Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
}

void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
raw_ostream &O) {
const MCOperand &MO = MI->getOperand(OpNum);
using CGTy = nvvm::CTAGroupKind;

switch (static_cast<CGTy>(MO.getImm())) {
case CGTy::CG_NONE:
O << "";
return;
case CGTy::CG_1:
O << ".cta_group::1";
return;
case CGTy::CG_2:
O << ".cta_group::2";
return;
}
llvm_unreachable("Invalid cta_group in printCTAGroup");
}
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -51,6 +51,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
};

}
19 changes: 14 additions & 5 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2556,19 +2556,25 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
// We have {Chain, Intrinsic-ID} followed by the actual intrisic args:
// {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2}
// multicast, cache_hint,
// multicast_flag, cache_hint_flag}
// multicast_flag, cache_hint_flag, cta_group_flag}
// NumOperands = {Chain, IID} + {Actual intrinsic args}
// = {2} + {7 + dims + im2col_offsets}
// = {2} + {8 + dims + im2col_offsets}
size_t NumOps = N->getNumOperands();
size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
: (NumOps - 9);
: (NumOps - 10);
// Offsets is always 'NumDims - 2' and only for im2col mode
size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1;
bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1;
size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src}
size_t MultiCastIdx = NumBaseArgs + 2; // for Chain and IID

unsigned CTAGroupVal = N->getConstantOperandVal(NumOps - 1);
if ((CTAGroupVal > 0) && !Subtarget->hasCpAsyncBulkTensorCTAGroupSupport())
report_fatal_error(
formatv("CpAsyncBulkTensorG2S cta_group::1/2 is not supported on sm_{}",
Subtarget->getSmVersion()));

SDLoc DL(N);
SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs));

@@ -2580,6 +2586,9 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
if (IsCacheHint)
Ops.push_back(N->getOperand(MultiCastIdx + 1));

// Flag for CTA Group
Ops.push_back(getI32Imm(CTAGroupVal, DL));

// Finally, the chain operand
Ops.push_back(N->getOperand(0));
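
As a sanity check on the operand arithmetic above (an illustrative count, not taken from the patch's tests), a 3d tile intrinsic node now carries:

.. code-block:: llvm

    ; {dst, mbar, tmap} + {d0, d1, d2} + {mc, ch} + {mc_flag, ch_flag, cta_group_flag}
    ;   NumOps  = 2 (Chain, IID) + 8 + 3 dims = 13
    ;   NumDims = NumOps - 10 = 3
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.3d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap, i32 %d0, i32 %d1, i32 %d2, i16 %mc, i64 %ch, i1 0, i1 0, i32 2)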

17 changes: 12 additions & 5 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -578,10 +578,14 @@ class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
# !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
}

def CTAGroupFlags : Operand<i32> {
let PrintMethod = "printCTAGroup";
}

multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> {
defvar dims_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "d" # i));
defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
defvar asm_str_default = " [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
defvar rc = !if(is_shared32, Int32Regs, Int64Regs);

defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);
@@ -595,19 +599,22 @@ multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode>
!strconcat(asm_str_default, im2col_asm_str), asm_str_default);

def "" : NVPTXInst<(outs),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)),
!strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";"), []>,
Requires<[hasPTX<80>, hasSM<90>]>;
def _MC : NVPTXInst<(outs),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc)),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
(ins Int16Regs:$mc, CTAGroupFlags:$cg)),
!strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;"), []>,
Requires<[hasPTX<80>, hasSM<90>]>;
def _CH : NVPTXInst<(outs),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int64Regs:$ch)),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
(ins Int64Regs:$ch, CTAGroupFlags:$cg)),
!strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;"), []>,
Requires<[hasPTX<80>, hasSM<90>]>;
def _MC_CH : NVPTXInst<(outs),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc, Int64Regs:$ch)),
!con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
(ins Int16Regs:$mc, Int64Regs:$ch, CTAGroupFlags:$cg)),
!strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;"), []>,
Requires<[hasPTX<80>, hasSM<90>]>;
}
8 changes: 8 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -117,6 +117,14 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
return HasTcgen05 && PTXVersion >= 86;
}

// TMA G2S copy with cta_group::1/2 support
bool hasCpAsyncBulkTensorCTAGroupSupport() const {
// TODO: Update/tidy-up after the family-conditional support arrives
return ((FullSmVersion == 1001 || FullSmVersion == 1011) &&
PTXVersion >= 86) ||
(FullSmVersion == 1031 && PTXVersion >= 88);
}

// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
// terminates a basic block. Instead, it would assume that control flow
// continued to the next instruction. The next instruction could be in the