Skip to content

Commit ffb6151

Browse files
committed
[NVPTX] Add TMA bulk tensor copy intrinsics
This patch adds NVVM intrinsics and NVPTX CodeGen for:

* cp.async.bulk.tensor.S2G.1D -> 5D variants, with optional support for cache_hints.
* cp.async.bulk.tensor.G2S.1D -> 5D variants, with optional support for multicast and cache_hints. Moreover, the 3D -> 5D variants also have support for an 'im2col' mode, with its own set of offsets.
* The first argument of these intrinsics is an immediate i32 flag. The bit-fields of the flag control enabling of optional features like multicast, cache_hints and im2col offsets, when applicable. The backend looks through these flag-bits and lowers to the appropriate PTX instruction.
* Lit tests are added for all combinations of these intrinsics in cp-async-bulk-tensor-g2s/s2g.ll.
* The generated PTX is verified with a 12.3 ptxas executable.

TODO: Update documentation for these intrinsics in the NVPTX guide.

Signed-off-by: Durgadoss R <[email protected]>
1 parent 177ce19 commit ffb6151

File tree

7 files changed

+729
-0
lines changed

7 files changed

+729
-0
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,22 @@ class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
552552
[OpType, llvm_i32_ty, llvm_i32_ty]);
553553
}
554554

555+
// Converts an intrinsic's dotted name ("llvm.nvvm.foo.bar") into the
// TableGen record name NVPTX conventionally uses ("int_nvvm_foo_bar").
class NVVM_INTRINSIC_RECORD<string intr> {
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}

// Name helper for the G2S (global-to-shared) TMA bulk tensor copy
// intrinsics; produces both the dotted intrinsic name and its record name
// for a given tensor dimensionality 'dim'.
class NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_NAME<int dim> {
  string intr = "llvm.nvvm.cp.async.bulk.tensor.gmem.to.smem"
                # "." # dim # "d";
  string record = NVVM_INTRINSIC_RECORD<intr>.record;
}

// Name helper for the S2G (shared-to-global) direction, mirroring the
// G2S helper above.
class NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_NAME<int dim> {
  string intr = "llvm.nvvm.cp.async.bulk.tensor.smem.to.gmem"
                # "." # dim # "d";
  string record = NVVM_INTRINSIC_RECORD<intr>.record;
}
570+
555571
let TargetPrefix = "nvvm" in {
556572
def int_nvvm_prmt : ClangBuiltin<"__nvvm_prmt">,
557573
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
@@ -4828,4 +4844,42 @@ def int_nvvm_setmaxnreg_dec_sync_aligned_u32
48284844
def int_nvvm_exit : ClangBuiltin<"__nvvm_exit">,
48294845
Intrinsic<[], [], [IntrConvergent, IntrInaccessibleMemOnly, IntrNoReturn]>;
48304846

4847+
// -------- llvm.nvvm.cp.async.bulk.tensor.gmem.to.smem
// G2S (global-to-shared) TMA bulk tensor copy. Operand layout, in order:
//   flags          : immediate i32; bit-fields enable the optional
//                    cache_hint / multicast / im2col features
//   dst_smem_ptr   : destination pointer in shared memory (written)
//   barrier_ptr    : mbarrier pointer in shared memory
//   tensor_map_ptr : source tensor-map pointer in gmem (read-only)
//   tensor_dims    : 'dim' i32 values, dim in [1,5]
//   im2col offsets : (dim - 2) i16 values, present only when dim >= 3
//   cta_mask       : i16 (used only when the multicast flag is set)
//   cache_policy   : i64 (used only when the cache_hint flag is set)
class NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_INTR<int dim> :
    DefaultAttrsIntrinsic<[],
      !listconcat(
          // flags, dst_smem_ptr, barrier_ptr, tensor_map_ptr
          [llvm_i32_ty, llvm_shared_ptr_ty, llvm_shared_ptr_ty, llvm_ptr_ty],
          !listsplat(llvm_i32_ty, dim),                                // tensor_dims
          !if(!ge(dim, 3), !listsplat(llvm_i16_ty, !add(dim, -2)), []), // im2col
          [llvm_i16_ty, llvm_i64_ty]),                  // cta_mask, cache_policy
      [IntrConvergent, IntrArgMemOnly, ImmArg<ArgIndex<0>>,
       WriteOnly<ArgIndex<1>>, ReadOnly<ArgIndex<3>>,
       NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>,
       NoCapture<ArgIndex<3>>],
      NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_NAME<dim>.intr>;

// Instantiate one intrinsic record per supported dimensionality (1D-5D).
foreach dim = [1, 2, 3, 4, 5] in {
  def NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_NAME<dim>.record :
      NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_INTR<dim>;
}
4866+
4867+
// -------- llvm.nvvm.cp.async.bulk.tensor.smem.to.gmem
// S2G (shared-to-global) TMA bulk tensor copy. Operand layout, in order:
//   flags          : immediate i32; bit-fields enable the optional
//                    cache_hint feature (see NVVMIntrinsicFlags.h)
//   src_smem_ptr   : source pointer in shared memory (read-only)
//   tensor_map_ptr : destination tensor-map pointer in gmem (written)
//   tensor_dims    : 'dim' i32 values, dim in [1,5]
//   cache_policy   : i64 (used only when the cache_hint flag is set)
class NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_INTR<int dim> :
    DefaultAttrsIntrinsic<[],
      !listconcat(
          // flags, src_smem_ptr, tensor_map_ptr
          [llvm_i32_ty, llvm_shared_ptr_ty, llvm_ptr_ty],
          !listsplat(llvm_i32_ty, dim),  // tensor_dims
          [llvm_i64_ty]),                // cache_policy
      [IntrConvergent, IntrArgMemOnly, ImmArg<ArgIndex<0>>,
       ReadOnly<ArgIndex<1>>, WriteOnly<ArgIndex<2>>,
       NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>],
      NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_NAME<dim>.intr>;

// Instantiate one intrinsic record per supported dimensionality (1D-5D).
foreach dim = [1, 2, 3, 4, 5] in {
  def NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_NAME<dim>.record :
      NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_INTR<dim>;
}
4884+
48314885
} // let TargetPrefix = "nvvm"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===--- NVVMIntrinsicFlags.h -----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This file contains the definitions of the enumerations and flags
/// associated with NVVM Intrinsics.
//
//===----------------------------------------------------------------------===//

// NOTE(review): guard renamed from LLVM_SUPPORT_* — this header lives under
// llvm/IR (it is included as "llvm/IR/NVVMIntrinsicFlags.h").
#ifndef LLVM_IR_NVVMINTRINSICFLAGS_H
#define LLVM_IR_NVVMINTRINSICFLAGS_H

#include <cstdint>

namespace llvm {
namespace nvvm {

/// Load mode selected by the LoadMode bit-field of the flags operand of
/// the cp.async.bulk.tensor.* intrinsics.
enum class CpAsyncBulkTensorLoadMode {
  TILE = 0,
  IM2COL = 1,
};

/// Decodes the immediate i32 flags operand of the cp.async.bulk.tensor.*
/// intrinsics. Write the raw value into \c V and read the individual
/// feature bits through \c U. Bit layout (LSB first):
///   bit 0     : CacheHint enabled
///   bit 1     : MultiCast enabled
///   bits 2-4  : LoadMode (a CpAsyncBulkTensorLoadMode value)
///   bits 5-31 : reserved, must be zero
typedef union {
  uint32_t V; // Raw flag word; uint32_t so bit manipulation is sign-free.
  struct {
    unsigned CacheHint : 1;
    unsigned MultiCast : 1;
    unsigned LoadMode : 3; // CpAsyncBulkTensorLoadMode
    unsigned reserved : 27;
  } U;
} CpAsyncBulkTensorFlags;

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICFLAGS_H

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/IR/GlobalValue.h"
1919
#include "llvm/IR/Instructions.h"
2020
#include "llvm/IR/IntrinsicsNVPTX.h"
21+
#include "llvm/IR/NVVMIntrinsicFlags.h"
2122
#include "llvm/Support/AtomicOrdering.h"
2223
#include "llvm/Support/CommandLine.h"
2324
#include "llvm/Support/Debug.h"
@@ -160,6 +161,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
160161
if (tryIntrinsicChain(N))
161162
return;
162163
break;
164+
case ISD::INTRINSIC_VOID:
165+
if (tryIntrinsicVoid(N))
166+
return;
167+
break;
163168
case NVPTXISD::Tex1DFloatS32:
164169
case NVPTXISD::Tex1DFloatFloat:
165170
case NVPTXISD::Tex1DFloatFloatLevel:
@@ -4091,3 +4096,246 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
40914096
}
40924097
}
40934098
}
4099+
4100+
static size_t GetCpAsyncBulkTensorDimFromIntrinsic(unsigned IID) {
4101+
switch (IID) {
4102+
case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_1d:
4103+
case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_1d:
4104+
return 1;
4105+
case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_2d:
4106+
case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_2d:
4107+
return 2;
4108+
case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_3d:
4109+
case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_3d:
4110+
return 3;
4111+
case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_4d:
4112+
case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_4d:
4113+
return 4;
4114+
case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_5d:
4115+
case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_5d:
4116+
return 5;
4117+
default:
4118+
llvm_unreachable(
4119+
"Invalid Tensor dim in nvvm_cp_async_bulk_tensor intrinsic");
4120+
}
4121+
}
4122+
4123+
// Returns the NVPTX opcode for one (direction, dim, mode, feature-suffix)
// combination, choosing the SHARED32 variant when shared-memory pointers
// are 32-bit. Expects a local 'IsShared32' in the expansion context.
#define CP_ASYNC_BULK_TENSOR_OPCODE(dir, dim, mode, suffix)                    \
  if (IsShared32) {                                                            \
    return NVPTX::                                                             \
        CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix;          \
  } else {                                                                     \
    return NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix;         \
  }

// S2G opcode selection: the only optional feature is the cache hint (_CH).
// Expects a local 'IsCacheHint' in the expansion context.
#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode)                         \
  do {                                                                         \
    if (IsCacheHint) {                                                         \
      CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, _CH);               \
    } else {                                                                   \
      CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, );                  \
    }                                                                          \
  } while (0)

// G2S opcode selection: multicast (_MC) and cache hint (_CH) may each be
// enabled independently, giving four suffix combinations. Expects locals
// 'IsMultiCast' and 'IsCacheHint' in the expansion context.
#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode)                         \
  do {                                                                         \
    if (IsMultiCast && IsCacheHint) {                                          \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC_CH);            \
    } else if (IsCacheHint) {                                                  \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _CH);               \
    } else if (IsMultiCast) {                                                  \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC);               \
    } else {                                                                   \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, );                  \
    }                                                                          \
  } while (0)
4152+
4153+
static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
4154+
bool IsCacheHint, bool IsIm2Col) {
4155+
if (IsIm2Col) {
4156+
switch (Dim) {
4157+
case 3:
4158+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL);
4159+
case 4:
4160+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL);
4161+
case 5:
4162+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL);
4163+
default:
4164+
llvm_unreachable("Invalid Dimension in im2col mode for "
4165+
"GetCpAsyncBulkTensorS2GOpcode.");
4166+
}
4167+
} else {
4168+
switch (Dim) {
4169+
case 1:
4170+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE);
4171+
case 2:
4172+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE);
4173+
case 3:
4174+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE);
4175+
case 4:
4176+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE);
4177+
case 5:
4178+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE);
4179+
default:
4180+
llvm_unreachable(
4181+
"Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
4182+
}
4183+
}
4184+
}
4185+
4186+
static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
4187+
bool IsMultiCast,
4188+
bool IsCacheHint, bool IsIm2Col) {
4189+
if (IsIm2Col) {
4190+
switch (Dim) {
4191+
case 3:
4192+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL);
4193+
case 4:
4194+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL);
4195+
case 5:
4196+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL);
4197+
default:
4198+
llvm_unreachable("Invalid Dimension in im2col mode for "
4199+
"GetCpAsyncBulkTensorG2SOpcode.");
4200+
}
4201+
} else {
4202+
switch (Dim) {
4203+
case 1:
4204+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE);
4205+
case 2:
4206+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE);
4207+
case 3:
4208+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE);
4209+
case 4:
4210+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE);
4211+
case 5:
4212+
GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE);
4213+
default:
4214+
llvm_unreachable(
4215+
"Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode.");
4216+
}
4217+
}
4218+
}
4219+
4220+
// Lowers a cp.async.bulk.tensor.smem.to.gmem intrinsic node to the matching
// NVPTX machine instruction, dropping the flags operand and forwarding only
// the operands the selected opcode actually consumes.
void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2G(SDNode *N) {
  // 32-bit vs 64-bit shared-memory addressing selects a different opcode set.
  const bool IsShared32 =
      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;

  const unsigned IID = N->getConstantOperandVal(1);
  const size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);

  // Operand 2 is the immediate flags word; decode the optional features.
  nvvm::CpAsyncBulkTensorFlags Flags;
  Flags.V = static_cast<unsigned>(N->getConstantOperandVal(2));
  const bool IsCacheHint = Flags.U.CacheHint == 1;
  const bool IsIm2Col = Flags.U.LoadMode == 1;

  SDLoc DL(N);
  SmallVector<SDValue, 4> Ops;
  Ops.push_back(N->getOperand(3)); // Src pointer in smem
  Ops.push_back(N->getOperand(4)); // Dst tensor_map pointer in gmem

  // Tensor dims [1-5] start at operand 5; the cache-hint follows them.
  const size_t DimsBegin = 5;
  for (size_t Idx = 0; Idx < NumDims; ++Idx)
    Ops.push_back(N->getOperand(DimsBegin + Idx));

  // Forward the cache-hint operand only when the feature is enabled.
  if (IsCacheHint)
    Ops.push_back(N->getOperand(DimsBegin + NumDims));

  // The chain operand goes last.
  Ops.push_back(N->getOperand(0));

  const unsigned Opcode =
      GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col);

  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}
4259+
4260+
// Lowers a cp.async.bulk.tensor.gmem.to.smem intrinsic node to the matching
// NVPTX machine instruction. The intrinsic always carries the full operand
// list (dims, im2col offsets for 3D+, cta_mask, cache_policy); only the
// operands enabled by the flags word are forwarded to the machine node.
void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2S(SDNode *N) {
  // 32-bit vs 64-bit shared-memory addressing selects a different opcode set.
  unsigned int SharedPointerSize =
      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
  bool IsShared32 = (SharedPointerSize == 32);

  // Operand 1 is the intrinsic ID; the tensor rank is encoded in its name.
  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
  size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);

  // Operand 2 is the immediate flags word; decode the optional features.
  ConstantSDNode *FlagsNode = cast<ConstantSDNode>(N->getOperand(2));
  nvvm::CpAsyncBulkTensorFlags Flags;
  Flags.V = static_cast<unsigned>(FlagsNode->getZExtValue());
  bool IsCacheHint = Flags.U.CacheHint == 1;
  bool IsMultiCast = Flags.U.MultiCast == 1;
  bool IsIm2Col = Flags.U.LoadMode == 1;

  // Im2col offsets only exist in the operand list for 3D-5D tensors.
  if (IsIm2Col && NumDims < 3)
    report_fatal_error("NumDims should be at least 3 for Im2Col mode");

  SDLoc DL(N);
  // List of operands that are common to both tile and im2col variants
  SmallVector<SDValue, 4> Ops{
      N->getOperand(3), // Dst pointer in smem
      N->getOperand(4), // Mbarrier pointer in smem
      N->getOperand(5), // Src pointer (i.e. tensor_map) in gmem
  };

  // Tensor Dims from [1-5]
  size_t TensorDimsStartIndex = 6;
  for (size_t i = 0; i < NumDims; i++)
    Ops.push_back(N->getOperand(TensorDimsStartIndex + i));

  // Im2Col co-ordinates:
  // These are always present in the input arguments for TensorDims{3,4,5}.
  // Number of values is (NumDims - 2).
  size_t Im2ColStartIndex = TensorDimsStartIndex + NumDims;
  size_t NumDimsIm2Col = (NumDims > 2) ? (NumDims - 2) : 0;
  size_t Im2ColEndIndex = Im2ColStartIndex + NumDimsIm2Col;
  // ...However, passed down to the actual NVPTX only when
  // this mode is enabled.
  if (IsIm2Col) {
    for (size_t i = 0; i < NumDimsIm2Col; i++)
      Ops.push_back(N->getOperand(Im2ColStartIndex + i));
  }

  // Push MultiCast operand, if available.
  // Note: cta_mask sits right after the im2col offsets in the intrinsic's
  // operand list, hence the Im2ColEndIndex base even in tile mode.
  if (IsMultiCast)
    Ops.push_back(N->getOperand(Im2ColEndIndex));

  // Push CacheHint operand, if available (always follows cta_mask).
  if (IsCacheHint)
    Ops.push_back(N->getOperand(Im2ColEndIndex + 1));

  // Finally, the chain operand
  Ops.push_back(N->getOperand(0));

  unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode(
      NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col);

  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}
4320+
4321+
// Custom-selects INTRINSIC_VOID nodes that need hand-written lowering.
// Returns false so the caller falls back to table-generated selection for
// any intrinsic not handled here.
bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
  switch (N->getConstantOperandVal(1)) {
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_5d:
    SelectCpAsyncBulkTensorG2S(N);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_5d:
    SelectCpAsyncBulkTensorS2G(N);
    return true;
  default:
    return false;
  }
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
5757
void Select(SDNode *N) override;
5858
bool tryIntrinsicNoChain(SDNode *N);
5959
bool tryIntrinsicChain(SDNode *N);
60+
bool tryIntrinsicVoid(SDNode *N);
6061
void SelectTexSurfHandle(SDNode *N);
6162
bool tryLoad(SDNode *N);
6263
bool tryLoadVector(SDNode *N);
@@ -76,6 +77,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7677
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
7778
void SelectV2I64toI128(SDNode *N);
7879
void SelectI128toV2I64(SDNode *N);
80+
void SelectCpAsyncBulkTensorS2G(SDNode *N);
81+
void SelectCpAsyncBulkTensorG2S(SDNode *N);
7982
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
8083
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
8184
}

0 commit comments

Comments
 (0)