
Commit 3e5d50f

[NVPTX] Add cta_group support to TMA G2S intrinsics (#143178)
This patch extends the TMA G2S intrinsics with support for the cta_group::1/2 modifiers available from Blackwell onwards. The existing intrinsics are auto-upgraded with a default value of '0' for the `cta_group` flag operand.

* lit tests are added for all combinations of the newer variants.
* Negative tests are added to validate the error-handling when the value of the cta_group flag falls out of range.
* The generated PTX is verified with a 12.8 ptxas executable.

Signed-off-by: Durgadoss R <[email protected]>
1 parent 8e4fdff commit 3e5d50f

13 files changed: +1078 -64 lines changed
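Before the per-file diffs, a minimal IR sketch of the extended intrinsic surface may help; the value names below are illustrative placeholders, not taken from the patch. The call requests the cta_group::2 variant via the new trailing i32 flag:

  ; Sketch only: the three trailing flag operands must be compile-time constants.
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
      ptr addrspace(7) %dst,  ; destination in shared::cluster space
      ptr addrspace(3) %bar,  ; mbarrier in shared memory
      ptr %tmap,              ; tensor-map descriptor
      i32 %d0,                ; 1d tensor coordinate
      i16 0, i64 0,           ; multicast mask / cache hint, unused here
      i1 0, i1 0,             ; flag_mc and flag_ch disabled
      i32 2)                  ; flag_cta_group = 2 -> cta_group::2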

llvm/docs/NVPTXUsage.rst

Lines changed: 20 additions & 12 deletions
@@ -1016,7 +1016,7 @@ Syntax:
 
 .. code-block:: llvm
 
-  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch, i32 %flag_cta_group)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(..., i32 %d0, i32 %d1, ...)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.3d(..., i32 %d0, i32 %d1, i32 %d2, ...)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)
@@ -1034,18 +1034,26 @@ source tensor is preserved at the destination. The dimension of the
 tensor data ranges from 1d to 5d with the coordinates specified
 by the ``i32 %d0 ... i32 %d4`` arguments.
 
-* The last two arguments to these intrinsics are boolean flags
-  indicating support for cache_hint and/or multicast modifiers.
-  These flag arguments must be compile-time constants. The backend
-  looks through these flags and lowers the intrinsics appropriately.
+* The last three arguments to these intrinsics are flags
+  indicating support for multicast, cache_hint and cta_group::1/2
+  modifiers. These flag arguments must be compile-time constants.
+  The backend looks through these flags and lowers the intrinsics
+  appropriately.
 
-* The Nth argument (denoted by ``i1 flag_ch``) when set, indicates
+* The argument denoted by ``i1 %flag_ch`` when set, indicates
   a valid cache_hint (``i64 %ch``) and generates the ``.L2::cache_hint``
   variant of the PTX instruction.
 
-* The [N-1]th argument (denoted by ``i1 flag_mc``) when set, indicates
-  the presence of a multicast mask (``i16 %mc``) and generates the PTX
-  instruction with the ``.multicast::cluster`` modifier.
+* The argument denoted by ``i1 %flag_mc`` when set, indicates
+  the presence of a multicast mask (``i16 %mc``) and generates
+  the PTX instruction with the ``.multicast::cluster`` modifier.
+
+* The argument denoted by ``i32 %flag_cta_group`` takes values within
+  the range [0, 3) i.e. {0,1,2}. When the value of ``%flag_cta_group``
+  is not within the range, it may raise an error from the Verifier.
+  The default value is '0' with no cta_group modifier in the
+  instruction. The values of '1' and '2' lower to ``cta_group::1``
+  and ``cta_group::2`` variants of the PTX instruction respectively.
 
 For more information, refer PTX ISA
 `<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`_.
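As a concrete reading of the three flags (a sketch, not part of the committed docs): a 2d tile copy with a multicast mask and cta_group::1 would pass flag_mc = 1, flag_ch = 0 and flag_cta_group = 1, so only the %mc operand is meaningful:

  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(
      ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap,
      i32 %d0, i32 %d1,      ; tensor coordinates
      i16 %mc,               ; multicast mask, consumed since flag_mc = 1
      i64 0,                 ; cache hint ignored since flag_ch = 0
      i1 1, i1 0, i32 1)     ; flag_mc, flag_ch, flag_cta_group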
@@ -1058,7 +1066,7 @@ Syntax:
 
 .. code-block:: llvm
 
-  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i32 %d1, i32 %d2, i16 %im2col0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i32 %d1, i32 %d2, i16 %im2col0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch, i32 %flag_cta_group)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i16 %im2col0, i16 %im2col1, ...)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.5d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, i16 %im2col0, i16 %im2col1, i16 %im2col2, ...)
 
@@ -1074,8 +1082,8 @@ are unrolled into a single dimensional column at the destination. In this
 mode, the tensor has to be at least three-dimensional. Along with the tensor
 coordinates, im2col offsets are also specified (denoted by
 ``i16 im2col0...i16 %im2col2``). The number of im2col offsets is two less
-than the number of dimensions of the tensor operation. The last two arguments
-to these intrinsics are boolean flags, with the same functionality as described
+than the number of dimensions of the tensor operation. The last three arguments
+to these intrinsics are flags, with the same functionality as described
 in the ``tile`` mode intrinsics above.
 
 For more information, refer PTX ISA
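For the im2col mode, a matching sketch (again with illustrative placeholder values): the 3d form carries 3 - 2 = 1 im2col offset, here with only the cache_hint flag set:

  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(
      ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap,
      i32 %d0, i32 %d1, i32 %d2,  ; tensor coordinates
      i16 %im2col0,               ; one im2col offset for a 3d tensor
      i16 0, i64 %ch,             ; multicast unused; cache-hint value
      i1 0, i1 1, i32 0)          ; only flag_ch set; default cta_group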

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 19 additions & 13 deletions
@@ -2020,20 +2020,26 @@ foreach dim = 1...5 in {
   defvar num_im2col_offsets = !if(is_im2col, !add(dim, -2), 0);
   defvar im2col_offsets_args = !listsplat(llvm_i16_ty, num_im2col_offsets);
 
+  defvar g2s_params = !listconcat(
+      [llvm_shared_cluster_ptr_ty, // dst_ptr
+       llvm_shared_ptr_ty,         // mbarrier_ptr
+       llvm_ptr_ty],               // tensormap_ptr
+      tensor_dim_args,             // actual tensor dims
+      im2col_offsets_args,         // im2col offsets
+      [llvm_i16_ty,                // cta_mask
+       llvm_i64_ty]);              // cache_hint
+  defvar g2s_flags = [llvm_i1_ty,   // Flag for cta_mask
+                      llvm_i1_ty,   // Flag for cache_hint
+                      llvm_i32_ty]; // Flag for cta_group
+  defvar cta_group_idx = !add(
+      !size(g2s_params),
+      !sub(!size(g2s_flags), 1));
+  defvar g2s_props = [IntrConvergent,
+      WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
+      // Allowed values for cta_group are {0,1,2} i.e [0, 3).
+      Range<ArgIndex<cta_group_idx>, 0, 3>];
   def int_nvvm_cp_async_bulk_tensor_g2s_ # mode # _ # dim # d :
-    DefaultAttrsIntrinsicFlags<[],
-        !listconcat([llvm_shared_cluster_ptr_ty, // dst_shared_cluster_ptr
-                     llvm_shared_ptr_ty,         // mbarrier_smem_ptr
-                     llvm_ptr_ty],               // tensormap_ptr
-                    tensor_dim_args,             // actual tensor dims
-                    im2col_offsets_args,         // im2col offsets
-                    [llvm_i16_ty,                // cta_mask
-                     llvm_i64_ty]),              // cache_hint
-        [llvm_i1_ty,  // Flag for cta_mask
-         llvm_i1_ty], // Flag for cache_hint
-        [IntrConvergent,
-         WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
-         NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>]>;
+    DefaultAttrsIntrinsicFlags<[], g2s_params, g2s_flags, g2s_props>;
 
   def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
     DefaultAttrsIntrinsicFlags<[],
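Reading off the definition above for tile.1d (a worked example, not text from the patch): g2s_params has six entries (three pointers, one tensor dim, no im2col offsets, plus cta_mask and cache_hint) and g2s_flags has three, so cta_group_idx = 6 + (3 - 1) = 8, i.e. the zero-based index of the final cta_group operand that the Range property constrains to [0, 3):

  ; Zero-based operand indices for the generated tile.1d intrinsic:
  ;   0 %dst, 1 %bar, 2 %tmap, 3 %d0, 4 %mc, 5 %ch,
  ;   6 %flag_mc, 7 %flag_ch, 8 %flag_cta_group  <- Range<ArgIndex<8>, 0, 3>
  ; Sketch: i32 1 is in range; i32 3 would be rejected by the Verifier.
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
      ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap,
      i32 %d0, i16 0, i64 0, i1 0, i1 0, i32 1)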

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 9 additions & 0 deletions
@@ -38,6 +38,15 @@ enum class TMAReductionOp : uint8_t {
   XOR = 7,
 };
 
+// Enum to represent the cta_group::1 and
+// cta_group::2 variants in TMA/TCGEN05 family of
+// PTX instructions.
+enum class CTAGroupKind : uint8_t {
+  CG_NONE = 0, // default with no cta_group modifier
+  CG_1 = 1,    // cta_group::1 modifier
+  CG_2 = 2,    // cta_group::2 modifier
+};
+
 inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
   switch (IntrinsicID) {
   case Intrinsic::nvvm_f2i_rm_ftz:

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 83 additions & 21 deletions
@@ -945,6 +945,53 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
   return false; // No other 'arm.*', 'aarch64.*'.
 }
 
+static Intrinsic::ID shouldUpgradeNVPTXTMAG2SIntrinsics(Function *F,
+                                                        StringRef Name) {
+  if (Name.consume_front("cp.async.bulk.tensor.g2s.")) {
+    Intrinsic::ID ID =
+        StringSwitch<Intrinsic::ID>(Name)
+            .Case("im2col.3d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
+            .Case("im2col.4d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
+            .Case("im2col.5d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
+            .Case("tile.1d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
+            .Case("tile.2d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
+            .Case("tile.3d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
+            .Case("tile.4d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
+            .Case("tile.5d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
+            .Default(Intrinsic::not_intrinsic);
+
+    if (ID == Intrinsic::not_intrinsic)
+      return ID;
+
+    // These intrinsics may need upgrade for two reasons:
+    // (1) When the address-space of the first argument is shared[AS=3]
+    //     (and we upgrade it to use shared_cluster address-space[AS=7])
+    if (F->getArg(0)->getType()->getPointerAddressSpace() ==
+        NVPTXAS::ADDRESS_SPACE_SHARED)
+      return ID;
+
+    // (2) When there are only two boolean flag arguments at the end:
+    //
+    //     The last three parameters of the older version of these
+    //     intrinsics are: arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag
+    //
+    //     The newer version reads as:
+    //     arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag, i32 cta_group_flag
+    //
+    //     So, when the type of the [N-3]rd argument is "not i1", then
+    //     it is the older version and we need to upgrade.
+    size_t FlagStartIndex = F->getFunctionType()->getNumParams() - 3;
+    Type *ArgType = F->getFunctionType()->getParamType(FlagStartIndex);
+    if (!ArgType->isIntegerTy(1))
+      return ID;
+  }
+
+  return Intrinsic::not_intrinsic;
+}
+
 static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
                                                               StringRef Name) {
   if (Name.consume_front("mapa.shared.cluster"))
@@ -959,22 +1006,6 @@ static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
                 Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster)
           .Case("shared.cta.to.cluster",
                 Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster)
-          .Case("tensor.g2s.im2col.3d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
-          .Case("tensor.g2s.im2col.4d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
-          .Case("tensor.g2s.im2col.5d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
-          .Case("tensor.g2s.tile.1d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
-          .Case("tensor.g2s.tile.2d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
-          .Case("tensor.g2s.tile.3d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
-          .Case("tensor.g2s.tile.4d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
-          .Case("tensor.g2s.tile.5d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
           .Default(Intrinsic::not_intrinsic);
 
   if (ID != Intrinsic::not_intrinsic)
@@ -1339,6 +1370,14 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
       return true;
     }
 
+    // Upgrade TMA copy G2S Intrinsics
+    IID = shouldUpgradeNVPTXTMAG2SIntrinsics(F, Name);
+    if (IID != Intrinsic::not_intrinsic) {
+      rename(F);
+      NewFn = Intrinsic::getOrInsertDeclaration(F->getParent(), IID);
+      return true;
+    }
+
     // The following nvvm intrinsics correspond exactly to an LLVM idiom, but
     // not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
     //
@@ -4831,7 +4870,18 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
       return;
     }
   case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
-  case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster:
+  case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster: {
+    // Create a new call with the correct address space.
+    SmallVector<Value *, 4> Args(CI->args());
+    Args[0] = Builder.CreateAddrSpaceCast(
+        Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+
+    NewCall = Builder.CreateCall(NewFn, Args);
+    NewCall->takeName(CI);
+    CI->replaceAllUsesWith(NewCall);
+    CI->eraseFromParent();
+    return;
+  }
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
@@ -4840,10 +4890,22 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: {
-    // Create a new call with the correct address space.
-    SmallVector<Value *, 4> Args(CI->args());
-    Args[0] = Builder.CreateAddrSpaceCast(
-        Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+    SmallVector<Value *, 16> Args(CI->args());
+
+    // Create AddrSpaceCast to shared_cluster if needed.
+    // This handles case (1) in shouldUpgradeNVPTXTMAG2SIntrinsics().
+    unsigned AS = CI->getArgOperand(0)->getType()->getPointerAddressSpace();
+    if (AS == NVPTXAS::ADDRESS_SPACE_SHARED)
+      Args[0] = Builder.CreateAddrSpaceCast(
+          Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+
+    // Attach the flag argument for cta_group, with a
+    // default value of 0. This handles case (2) in
+    // shouldUpgradeNVPTXTMAG2SIntrinsics().
+    size_t NumArgs = CI->arg_size();
+    Value *FlagArg = CI->getArgOperand(NumArgs - 3);
+    if (!FlagArg->getType()->isIntegerTy(1))
+      Args.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));
 
     NewCall = Builder.CreateCall(NewFn, Args);
     NewCall->takeName(CI);
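Putting the two upgrade paths together, a before/after sketch of what this does to an old-style call (hypothetical values; the real coverage lives in the lit tests mentioned in the commit message):

  ; Before: old 8-operand form with a shared[AS=3] destination.
  ;   call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
  ;       ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tmap,
  ;       i32 %d0, i16 0, i64 0, i1 0, i1 0)
  ; After: case (1) inserts an addrspacecast to shared_cluster[AS=7],
  ; and case (2) appends the default cta_group flag (i32 0).
  ;   %dst.cast = addrspacecast ptr addrspace(3) %dst to ptr addrspace(7)
  ;   call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
  ;       ptr addrspace(7) %dst.cast, ptr addrspace(3) %bar, ptr %tmap,
  ;       i32 %d0, i16 0, i64 0, i1 0, i1 0, i32 0)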

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 19 additions & 0 deletions
@@ -437,3 +437,22 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
   llvm_unreachable(
       "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
 }
+
+void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
+                                     raw_ostream &O) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  using CGTy = nvvm::CTAGroupKind;
+
+  switch (static_cast<CGTy>(MO.getImm())) {
+  case CGTy::CG_NONE:
+    O << "";
+    return;
+  case CGTy::CG_1:
+    O << ".cta_group::1";
+    return;
+  case CGTy::CG_2:
+    O << ".cta_group::2";
+    return;
+  }
+  llvm_unreachable("Invalid cta_group in printCTAGroup");
+}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
   void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
+  void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
 };
 
 }

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 5 deletions
@@ -2556,19 +2556,25 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
   // We have {Chain, Intrinsic-ID} followed by the actual intrisic args:
   // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2}
   //  multicast, cache_hint,
-  //  multicast_flag, cache_hint_flag}
+  //  multicast_flag, cache_hint_flag, cta_group_flag}
   // NumOperands = {Chain, IID} + {Actual intrinsic args}
-  //             = {2} + {7 + dims + im2col_offsets}
+  //             = {2} + {8 + dims + im2col_offsets}
   size_t NumOps = N->getNumOperands();
   size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
-                            : (NumOps - 9);
+                            : (NumOps - 10);
   // Offsets is always 'NumDims - 2' and only for im2col mode
   size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
-  bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
-  bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
+  bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1;
+  bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1;
   size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src}
   size_t MultiCastIdx = NumBaseArgs + 2;         // for Chain and IID
 
+  unsigned CTAGroupVal = N->getConstantOperandVal(NumOps - 1);
+  if ((CTAGroupVal > 0) && !Subtarget->hasCpAsyncBulkTensorCTAGroupSupport())
+    report_fatal_error(
+        formatv("CpAsyncBulkTensorG2S cta_group::1/2 is not supported on sm_{}",
+                Subtarget->getSmVersion()));
+
   SDLoc DL(N);
   SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs));
 
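As a sanity check on the updated operand accounting (worked through here, not stated in the patch): a tile.3d node carries {Chain, IID} plus {dst, mbar, src}, three dims, {cta_mask, cache_hint} and the three flags, i.e. NumOps = 2 + (8 + 3) = 13, so NumDims = NumOps - 10 = 3, the cache-hint and multicast flags sit at NumOps - 2 and NumOps - 3, and the new cta_group flag is the last operand at NumOps - 1.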
@@ -2580,6 +2586,9 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
   if (IsCacheHint)
     Ops.push_back(N->getOperand(MultiCastIdx + 1));
 
+  // Flag for CTA Group
+  Ops.push_back(getI32Imm(CTAGroupVal, DL));
+
   // Finally, the chain operand
   Ops.push_back(N->getOperand(0));
 

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 12 additions & 5 deletions
@@ -578,10 +578,14 @@ class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
        # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
 }
 
+def CTAGroupFlags : Operand<i32> {
+  let PrintMethod = "printCTAGroup";
+}
+
 multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> {
   defvar dims_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "d" # i));
   defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
-  defvar asm_str_default = " [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
+  defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
   defvar rc = !if(is_shared32, Int32Regs, Int64Regs);
 
   defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);
@@ -595,19 +599,22 @@ multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode>
                !strconcat(asm_str_default, im2col_asm_str), asm_str_default);
 
   def "" : NVPTXInst<(outs),
-           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag),
+           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)),
            !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";"), []>,
            Requires<[hasPTX<80>, hasSM<90>]>;
   def _MC : NVPTXInst<(outs),
-           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc)),
+           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                (ins Int16Regs:$mc, CTAGroupFlags:$cg)),
            !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;"), []>,
            Requires<[hasPTX<80>, hasSM<90>]>;
   def _CH : NVPTXInst<(outs),
-           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int64Regs:$ch)),
+           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                (ins Int64Regs:$ch, CTAGroupFlags:$cg)),
            !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;"), []>,
            Requires<[hasPTX<80>, hasSM<90>]>;
   def _MC_CH : NVPTXInst<(outs),
-           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc, Int64Regs:$ch)),
+           !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                (ins Int16Regs:$mc, Int64Regs:$ch, CTAGroupFlags:$cg)),
            !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;"), []>,
            Requires<[hasPTX<80>, hasSM<90>]>;
 }

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 8 additions & 0 deletions
@@ -117,6 +117,14 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
     return HasTcgen05 && PTXVersion >= 86;
   }
 
+  // TMA G2S copy with cta_group::1/2 support
+  bool hasCpAsyncBulkTensorCTAGroupSupport() const {
+    // TODO: Update/tidy-up after the family-conditional support arrives
+    return ((FullSmVersion == 1001 || FullSmVersion == 1011) &&
+            PTXVersion >= 86) ||
+           (FullSmVersion == 1031 && PTXVersion >= 88);
+  }
+
   // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
   // terminates a basic block. Instead, it would assume that control flow
   // continued to the next instruction. The next instruction could be in the
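For context on the FullSmVersion values above (an inference from NVPTX's encoding convention, not spelled out in the patch): FullSmVersion appears to encode the SM version times ten, plus one for arch-accelerated variants, so 1001, 1011 and 1031 would correspond to sm_100a, sm_101a and sm_103a. The check thus gates cta_group support to those Blackwell targets, paired with the minimum PTX versions that can encode the modifier.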
