Commit 203f2b6

[NVPTX] Extend TMA intrinsics with 2-CTA mode
This patch extends the TMA G2S intrinsics with support for the 2-CTA mode available from Blackwell onwards. The existing intrinsics are auto-upgraded with a default value of '0' for the `is_2cta_mode` flag operand. lit tests are added for all combinations of the new variant, and the generated PTX is verified with a CUDA 12.8 ptxas executable.

Signed-off-by: Durgadoss R <[email protected]>
1 parent 1540ed5 commit 203f2b6
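
To illustrate the auto-upgrade described above, here is a minimal IR-level sketch (the SSA values %dst, %bar, %tmap, %d0 and %ch are placeholders, not taken from the commit's tests): an existing two-flag call simply gains a trailing cta_group flag that defaults to false.

    ; Old form: two trailing flags (multicast, cache_hint)
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
        ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap,
        i32 %d0, i16 0, i64 %ch, i1 0, i1 1)

    ; After auto-upgrade: a third flag is appended with the default value 0,
    ; so existing IR keeps its old behaviour (no .cta_group::2)
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
        ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap,
        i32 %d0, i16 0, i64 %ch, i1 0, i1 1, i1 0)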

File tree

9 files changed (+576, -43 lines)

llvm/docs/NVPTXUsage.rst

Lines changed: 13 additions & 9 deletions
@@ -1016,7 +1016,7 @@ Syntax:
 
 .. code-block:: llvm
 
-  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch)
+  declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tensor_map, i32 %d0, i16 %mc, i64 %ch, i1 %flag_mc, i1 %flag_ch, i1 %flag_cta_group)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(..., i32 %d0, i32 %d1, ...)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.3d(..., i32 %d0, i32 %d1, i32 %d2, ...)
   declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.4d(..., i32 %d0, i32 %d1, i32 %d2, i32 %d3, ...)

@@ -1034,18 +1034,22 @@ source tensor is preserved at the destination. The dimension of the
 tensor data ranges from 1d to 5d with the coordinates specified
 by the ``i32 %d0 ... i32 %d4`` arguments.
 
-* The last two arguments to these intrinsics are boolean flags
-  indicating support for cache_hint and/or multicast modifiers.
-  These flag arguments must be compile-time constants. The backend
-  looks through these flags and lowers the intrinsics appropriately.
+* The last three arguments to these intrinsics are boolean flags
+  indicating support for multicast, cache_hint and cta_group::2
+  modifiers. These flag arguments must be compile-time constants.
+  The backend looks through these flags and lowers the intrinsics
+  appropriately.
 
-* The Nth argument (denoted by ``i1 flag_ch``) when set, indicates
+* The argument denoted by ``i1 flag_ch`` when set, indicates
   a valid cache_hint (``i64 %ch``) and generates the ``.L2::cache_hint``
   variant of the PTX instruction.
 
-* The [N-1]th argument (denoted by ``i1 flag_mc``) when set, indicates
-  the presence of a multicast mask (``i16 %mc``) and generates the PTX
-  instruction with the ``.multicast::cluster`` modifier.
+* The argument denoted by ``i1 flag_mc`` when set, indicates
+  the presence of a multicast mask (``i16 %mc``) and generates
+  the PTX instruction with the ``.multicast::cluster`` modifier.
+
+* The argument denoted by ``i1 flag_cta_group`` when set, generates
+  the ``.cta_group::2`` variant of the PTX instruction.
 
 For more information, refer PTX ISA
 `<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`_.
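
A frontend opts into the 2-CTA form by passing a true constant for the new flag. A minimal sketch with placeholder operand values (%dst, %bar, %tmap, %d0 are assumptions, not lifted from the commit's tests):

    ; Request the .cta_group::2 variant of the 1D tile copy; multicast and
    ; cache_hint are disabled here, so the i16/i64 operands are unused.
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
        ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap,
        i32 %d0, i16 0, i64 0, i1 0, i1 0, i1 1)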

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 3 additions & 3 deletions
@@ -2025,10 +2025,10 @@ foreach dim = 1...5 in {
       [llvm_i16_ty,   // cta_mask
        llvm_i64_ty]), // cache_hint
       [llvm_i1_ty,    // Flag for cta_mask
-      llvm_i1_ty], // Flag for cache_hint
+       llvm_i1_ty,    // Flag for cache_hint
+       llvm_i1_ty],   // Flag for is_2cta_mode
      [IntrConvergent,
-      WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
-      NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>]>;
+      WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>]>;
 
   def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
     DefaultAttrsIntrinsicFlags<[],

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 83 additions & 21 deletions
@@ -939,6 +939,53 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
   return false; // No other 'arm.*', 'aarch64.*'.
 }
 
+static Intrinsic::ID shouldUpgradeNVPTXTMAG2SIntrinsics(Function *F,
+                                                        StringRef Name) {
+  if (Name.consume_front("cp.async.bulk.tensor.g2s.")) {
+    Intrinsic::ID ID =
+        StringSwitch<Intrinsic::ID>(Name)
+            .Case("im2col.3d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
+            .Case("im2col.4d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
+            .Case("im2col.5d",
+                  Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
+            .Case("tile.1d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
+            .Case("tile.2d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
+            .Case("tile.3d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
+            .Case("tile.4d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
+            .Case("tile.5d", Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
+            .Default(Intrinsic::not_intrinsic);
+
+    if (ID == Intrinsic::not_intrinsic)
+      return ID;
+
+    // These intrinsics may need upgrade for two reasons:
+    // (1) When the address-space of the first argument is shared[AS=3]
+    // (and we upgrade it to use shared_cluster address-space[AS=7])
+    if (F->getArg(0)->getType()->getPointerAddressSpace() ==
+        NVPTXAS::ADDRESS_SPACE_SHARED)
+      return ID;
+
+    // (2) When there are only two boolean flag arguments at the end:
+    //
+    // The last three parameters of the older version of these
+    // intrinsics are: arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag
+    //
+    // The newer version has three boolean flags at the end:
+    // arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag, i1 cta_group_flag
+    //
+    // So, when the type of the [N-3]rd argument is "not i1", then
+    // it is the older version and we need to upgrade.
+    size_t FlagStartIndex = F->getFunctionType()->getNumParams() - 3;
+    Type *ArgType = F->getFunctionType()->getParamType(FlagStartIndex);
+    if (!ArgType->isIntegerTy(1))
+      return ID;
+  }
+
+  return Intrinsic::not_intrinsic;
+}
+
 static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
                                                               StringRef Name) {
   if (Name.consume_front("mapa.shared.cluster"))

@@ -953,22 +1000,6 @@ static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
                 Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster)
           .Case("shared.cta.to.cluster",
                 Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster)
-          .Case("tensor.g2s.im2col.3d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
-          .Case("tensor.g2s.im2col.4d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
-          .Case("tensor.g2s.im2col.5d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
-          .Case("tensor.g2s.tile.1d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
-          .Case("tensor.g2s.tile.2d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
-          .Case("tensor.g2s.tile.3d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
-          .Case("tensor.g2s.tile.4d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
-          .Case("tensor.g2s.tile.5d",
-                Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
           .Default(Intrinsic::not_intrinsic);
 
   if (ID != Intrinsic::not_intrinsic)

@@ -1334,6 +1365,14 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
       return true;
     }
 
+    // Upgrade TMA copy G2S Intrinsics
+    IID = shouldUpgradeNVPTXTMAG2SIntrinsics(F, Name);
+    if (IID != Intrinsic::not_intrinsic) {
+      rename(F);
+      NewFn = Intrinsic::getOrInsertDeclaration(F->getParent(), IID);
+      return true;
+    }
+
     // The following nvvm intrinsics correspond exactly to an LLVM idiom, but
     // not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
     //

@@ -4813,7 +4852,18 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
     return;
   }
   case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
-  case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster:
+  case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster: {
+    // Create a new call with the correct address space.
+    SmallVector<Value *, 4> Args(CI->args());
+    Args[0] = Builder.CreateAddrSpaceCast(
+        Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+
+    NewCall = Builder.CreateCall(NewFn, Args);
+    NewCall->takeName(CI);
+    CI->replaceAllUsesWith(NewCall);
+    CI->eraseFromParent();
+    return;
+  }
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:

@@ -4822,10 +4872,22 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: {
-    // Create a new call with the correct address space.
-    SmallVector<Value *, 4> Args(CI->args());
-    Args[0] = Builder.CreateAddrSpaceCast(
-        Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+    SmallVector<Value *, 16> Args(CI->args());
+
+    // Create AddrSpaceCast to shared_cluster if needed.
+    // This handles case (1) in shouldUpgradeNVPTXTMAG2SIntrinsics().
+    unsigned AS = CI->getArgOperand(0)->getType()->getPointerAddressSpace();
+    if (AS == NVPTXAS::ADDRESS_SPACE_SHARED)
+      Args[0] = Builder.CreateAddrSpaceCast(
+          Args[0], Builder.getPtrTy(NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
+
+    // Attach the flag argument for cta_group, with a
+    // default value of 0. This handles case (2) in
+    // shouldUpgradeNVPTXTMAG2SIntrinsics().
+    size_t NumArgs = CI->arg_size();
+    Value *FlagArg = CI->getArgOperand(NumArgs - 3);
+    if (!FlagArg->getType()->isIntegerTy(1))
+      Args.push_back(ConstantInt::get(Builder.getInt1Ty(), 0));
 
     NewCall = Builder.CreateCall(NewFn, Args);
     NewCall->takeName(CI);
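
At the IR level, the two upgrade cases handled above can be pictured as follows; this is a hedged sketch with placeholder values, not one of the commit's lit tests. A legacy call whose destination is still in the shared address space gets an addrspacecast to shared_cluster (case 1) and the default cta_group flag appended (case 2):

    ; Legacy IR: destination in shared (addrspace 3), only two trailing flags.
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
        ptr addrspace(3) %dst, ptr addrspace(3) %bar, ptr %tmap,
        i32 %d0, i16 %mc, i64 %ch, i1 1, i1 1)

    ; Conceptual result of UpgradeIntrinsicCall: cast to shared_cluster
    ; (addrspace 7) and a default-false cta_group flag at the end.
    %dst.cast = addrspacecast ptr addrspace(3) %dst to ptr addrspace(7)
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
        ptr addrspace(7) %dst.cast, ptr addrspace(3) %bar, ptr %tmap,
        i32 %d0, i16 %mc, i64 %ch, i1 1, i1 1, i1 0)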

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 7 additions & 0 deletions
@@ -437,3 +437,10 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
   llvm_unreachable(
       "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
 }
+
+void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
+                                     raw_ostream &O) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  if (MO.getImm())
+    O << ".cta_group::2";
+}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
   void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
+  void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
 };
 
 }

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 5 deletions
@@ -2556,19 +2556,25 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
   // We have {Chain, Intrinsic-ID} followed by the actual intrisic args:
   // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2}
   //  multicast, cache_hint,
-  //  multicast_flag, cache_hint_flag}
+  //  multicast_flag, cache_hint_flag, 2cta_mode_flag}
   // NumOperands = {Chain, IID} + {Actual intrinsic args}
-  //             = {2} + {7 + dims + im2col_offsets}
+  //             = {2} + {8 + dims + im2col_offsets}
   size_t NumOps = N->getNumOperands();
   size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
-                            : (NumOps - 9);
+                            : (NumOps - 10);
   // Offsets is always 'NumDims - 2' and only for im2col mode
   size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
-  bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
-  bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
+  bool Is2CTAMode = N->getConstantOperandVal(NumOps - 1) == 1;
+  bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1;
+  bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1;
   size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src}
   size_t MultiCastIdx = NumBaseArgs + 2;         // for Chain and IID
 
+  if (Is2CTAMode && !Subtarget->hasCpAsyncBulkTensor2CTASupport())
+    report_fatal_error(
+        formatv("CpAsyncBulkTensorG2S 2CTA mode is not supported on sm_{}",
+                Subtarget->getSmVersion()));
+
   SDLoc DL(N);
   SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs));
 

@@ -2580,6 +2586,9 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
   if (IsCacheHint)
     Ops.push_back(N->getOperand(MultiCastIdx + 1));
 
+  // Flag for 2-CTA mode
+  Ops.push_back(CurDAG->getTargetConstant(Is2CTAMode, DL, MVT::i1));
+
   // Finally, the chain operand
   Ops.push_back(N->getOperand(0));
 
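
Putting the operand layout together: the selector reads the last three constant operands as the multicast, cache_hint and 2-CTA flags and, when the subtarget allows it, appends the cta_group operand to the machine instruction. A hedged IR-level sketch (placeholder values) of a call that takes the multicast plus cache_hint path with .cta_group::2:

    ; All three flags must be compile-time constants; on subtargets without
    ; hasCpAsyncBulkTensor2CTASupport() the 2-CTA flag triggers report_fatal_error.
    call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.1d(
        ptr addrspace(7) %dst, ptr addrspace(3) %bar, ptr %tmap,
        i32 %d0, i16 %mc, i64 %ch, i1 1, i1 1, i1 1)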

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 12 additions & 5 deletions
@@ -593,10 +593,14 @@ class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
                       # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
 }
 
+def CTAGroupFlags : Operand<i1> {
+  let PrintMethod = "printCTAGroup";
+}
+
 multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> {
   defvar dims_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "d" # i));
   defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
-  defvar asm_str_default = " [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
+  defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
   defvar rc = !if(is_shared32, Int32Regs, Int64Regs);
 
   defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);

@@ -610,19 +614,22 @@ multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode>
                  !strconcat(asm_str_default, im2col_asm_str), asm_str_default);
 
   def NAME: NVPTXInst<(outs),
-            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag),
+            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)),
             !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";"), []>,
             Requires<[hasPTX<80>, hasSM<90>]>;
   def NAME # _MC: NVPTXInst<(outs),
-            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc)),
+            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                 (ins Int16Regs:$mc, CTAGroupFlags:$cg)),
             !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;"), []>,
             Requires<[hasPTX<80>, hasSM<90>]>;
   def NAME # _CH: NVPTXInst<(outs),
-            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int64Regs:$ch)),
+            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                 (ins Int64Regs:$ch, CTAGroupFlags:$cg)),
            !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;"), []>,
            Requires<[hasPTX<80>, hasSM<90>]>;
   def NAME # _MC_CH: NVPTXInst<(outs),
-            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag, (ins Int16Regs:$mc, Int64Regs:$ch)),
+            !con((ins rc:$dst, rc:$mbar, Int64Regs:$tmap), dims_dag, im2col_dag,
+                 (ins Int16Regs:$mc, Int64Regs:$ch, CTAGroupFlags:$cg)),
             !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;"), []>,
             Requires<[hasPTX<80>, hasSM<90>]>;
 }

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 8 additions & 0 deletions
@@ -117,6 +117,14 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
     return HasTcgen05 && PTXVersion >= 86;
   }
 
+  // TMA G2S copy 2cta mode support
+  bool hasCpAsyncBulkTensor2CTASupport() const {
+    // TODO: Update/tidy-up after the family-conditional support arrives
+    return ((FullSmVersion == 1001 || FullSmVersion == 1011) &&
+            PTXVersion >= 86) ||
+           (FullSmVersion == 1031 && PTXVersion >= 88);
+  }
+
   // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
   // terminates a basic block. Instead, it would assume that control flow
   // continued to the next instruction. The next instruction could be in the
// continued to the next instruction. The next instruction could be in the
