Skip to content

Commit db9a09c

Browse files
committed
hoist more shared logic into the helper function
1 parent c04dcff commit db9a09c

File tree

1 file changed

+132
-167
lines changed

1 file changed

+132
-167
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 132 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
#include <iterator>
6565
#include <optional>
6666
#include <string>
67+
#include <tuple>
6768
#include <utility>
6869
#include <vector>
6970

@@ -162,14 +163,70 @@ static bool Is16bitsType(MVT VT) {
162163
VT.SimpleTy == MVT::i16);
163164
}
164165

165-
static auto GetUpsizedNumEltsAndEltVT(unsigned OldNumElts, EVT OldEltVT) {
166-
// Number of elements to pack in one word.
167-
unsigned NPerWord = 32 / OldEltVT.getSizeInBits();
168-
// Word-sized vector.
169-
EVT NewEltVT = MVT::getVectorVT(OldEltVT.getSimpleVT(), NPerWord);
170-
// Number of word-sized vectors.
171-
unsigned NewNumElts = OldNumElts / NPerWord;
172-
return std::pair(NewNumElts, NewEltVT);
166+
// When legalizing vector loads/stores, this function is called, which does two things:
167+
// 1. Determines Whether the vector is something we want to custom lower, std::nullopt is returned if we do not want to custom lower it.
168+
// 2. If we do want to handle it, returns three parameters:
169+
// - unsigned int NumElts - The number of elements in the final vector
170+
// - EVT EltVT - The type of the elements in the final vector
171+
// - bool UpsizeElementTypes - Whether or not we are upsizing the elements of the vectors
172+
static std::optional<std::tuple<unsigned int, EVT, bool>> tryGetVectorLoweringParams(EVT ValVT) {
173+
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
174+
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
175+
// vectorized loads/stores with the actual element type for i8/i16 as that
176+
// would require v8/v16 variants that do not exist.
177+
// In order to load/store such vectors efficiently, here in Type Legalization,
178+
// we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
179+
// lower to PTX as vectors of b32.
180+
bool UpsizeElementTypes = false;
181+
182+
if (!ValVT.isVector() || !ValVT.isSimple())
183+
return std::nullopt;
184+
185+
EVT EltVT = ValVT.getVectorElementType();
186+
unsigned NumElts = ValVT.getVectorNumElements();
187+
188+
// We only handle "native" vector sizes for now, e.g. <4 x double> is not
189+
// legal. We can (and should) split that into 2 stores of <2 x double> here
190+
// but I'm leaving that as a TODO for now.
191+
switch (ValVT.getSimpleVT().SimpleTy) {
192+
default:
193+
return std::nullopt;
194+
case MVT::v2i8:
195+
case MVT::v2i16:
196+
case MVT::v2i32:
197+
case MVT::v2i64:
198+
case MVT::v2f16:
199+
case MVT::v2bf16:
200+
case MVT::v2f32:
201+
case MVT::v2f64:
202+
case MVT::v4i8:
203+
case MVT::v4i16:
204+
case MVT::v4i32:
205+
case MVT::v4f16:
206+
case MVT::v4bf16:
207+
case MVT::v4f32:
208+
// This is a "native" vector type
209+
break;
210+
case MVT::v8i8: // <2 x i8x4>
211+
case MVT::v8f16: // <4 x f16x2>
212+
case MVT::v8bf16: // <4 x bf16x2>
213+
case MVT::v8i16: // <4 x i16x2>
214+
case MVT::v16i8: // <4 x i8x4>
215+
// This can be upsized into a "native" vector type
216+
UpsizeElementTypes = true;
217+
break;
218+
}
219+
220+
if (UpsizeElementTypes) {
221+
// Number of elements to pack in one word.
222+
unsigned NPerWord = 32 / EltVT.getSizeInBits();
223+
// Word-sized vector.
224+
EltVT = MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord);
225+
// Number of word-sized vectors.
226+
NumElts = NumElts / NPerWord;
227+
}
228+
229+
return std::tuple(NumElts, EltVT, UpsizeElementTypes);
173230
};
174231

175232
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -2819,130 +2876,81 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
28192876
SDLoc DL(N);
28202877
EVT ValVT = Val.getValueType();
28212878

2822-
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
2823-
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
2824-
// vectorized loads/stores with the actual element type for i8/i16 as that
2825-
// would require v8/v16 variants that do not exist.
2826-
// In order to load/store such vectors efficiently, here in Type Legalization,
2827-
// we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
2828-
// lower to PTX as vectors of b32.
2829-
bool UpsizeElementTypes = false;
2830-
2831-
if (ValVT.isVector()) {
2832-
// We only handle "native" vector sizes for now, e.g. <4 x double> is not
2833-
// legal. We can (and should) split that into 2 stores of <2 x double> here
2834-
// but I'm leaving that as a TODO for now.
2835-
if (!ValVT.isSimple())
2836-
return SDValue();
2837-
switch (ValVT.getSimpleVT().SimpleTy) {
2838-
default:
2839-
return SDValue();
2840-
case MVT::v2i8:
2841-
case MVT::v2i16:
2842-
case MVT::v2i32:
2843-
case MVT::v2i64:
2844-
case MVT::v2f16:
2845-
case MVT::v2bf16:
2846-
case MVT::v2f32:
2847-
case MVT::v2f64:
2848-
case MVT::v4i8:
2849-
case MVT::v4i16:
2850-
case MVT::v4i32:
2851-
case MVT::v4f16:
2852-
case MVT::v4bf16:
2853-
case MVT::v4f32:
2854-
// This is a "native" vector type
2855-
break;
2856-
case MVT::v8i8: // <2 x i8x4>
2857-
case MVT::v8f16: // <4 x f16x2>
2858-
case MVT::v8bf16: // <4 x bf16x2>
2859-
case MVT::v8i16: // <4 x i16x2>
2860-
case MVT::v16i8: // <4 x i8x4>
2861-
// This can be upsized into a "native" vector type
2862-
UpsizeElementTypes = true;
2863-
break;
2864-
}
2865-
2866-
MemSDNode *MemSD = cast<MemSDNode>(N);
2867-
const DataLayout &TD = DAG.getDataLayout();
2868-
2869-
Align Alignment = MemSD->getAlign();
2870-
Align PrefAlign =
2871-
TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
2872-
if (Alignment < PrefAlign) {
2873-
// This store is not sufficiently aligned, so bail out and let this vector
2874-
// store be scalarized. Note that we may still be able to emit smaller
2875-
// vector stores. For example, if we are storing a <4 x float> with an
2876-
// alignment of 8, this check will fail but the legalizer will try again
2877-
// with 2 x <2 x float>, which will succeed with an alignment of 8.
2878-
return SDValue();
2879-
}
2879+
auto VectorLoweringParams = tryGetVectorLoweringParams(ValVT);
2880+
if (!VectorLoweringParams)
2881+
return SDValue();
2882+
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value();
28802883

2881-
unsigned Opcode = 0;
2882-
EVT EltVT = ValVT.getVectorElementType();
2883-
unsigned NumElts = ValVT.getVectorNumElements();
2884+
MemSDNode *MemSD = cast<MemSDNode>(N);
2885+
const DataLayout &TD = DAG.getDataLayout();
28842886

2885-
if (UpsizeElementTypes) {
2886-
std::tie(NumElts, EltVT) = GetUpsizedNumEltsAndEltVT(NumElts, EltVT);
2887-
}
2887+
Align Alignment = MemSD->getAlign();
2888+
Align PrefAlign =
2889+
TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
2890+
if (Alignment < PrefAlign) {
2891+
// This store is not sufficiently aligned, so bail out and let this vector
2892+
// store be scalarized. Note that we may still be able to emit smaller
2893+
// vector stores. For example, if we are storing a <4 x float> with an
2894+
// alignment of 8, this check will fail but the legalizer will try again
2895+
// with 2 x <2 x float>, which will succeed with an alignment of 8.
2896+
return SDValue();
2897+
}
28882898

2889-
// Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2890-
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
2891-
// stored type to i16 and propagate the "real" type as the memory type.
2892-
bool NeedExt = false;
2893-
if (EltVT.getSizeInBits() < 16)
2894-
NeedExt = true;
2899+
// Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2900+
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
2901+
// stored type to i16 and propagate the "real" type as the memory type.
2902+
bool NeedExt = false;
2903+
if (EltVT.getSizeInBits() < 16)
2904+
NeedExt = true;
28952905

2896-
switch (NumElts) {
2897-
default:
2898-
return SDValue();
2899-
case 2:
2900-
Opcode = NVPTXISD::StoreV2;
2901-
break;
2902-
case 4:
2903-
Opcode = NVPTXISD::StoreV4;
2904-
break;
2905-
}
2906+
unsigned Opcode = 0;
2907+
switch (NumElts) {
2908+
default:
2909+
return SDValue();
2910+
case 2:
2911+
Opcode = NVPTXISD::StoreV2;
2912+
break;
2913+
case 4:
2914+
Opcode = NVPTXISD::StoreV4;
2915+
break;
2916+
}
29062917

2907-
SmallVector<SDValue, 8> Ops;
2918+
SmallVector<SDValue, 8> Ops;
29082919

2909-
// First is the chain
2910-
Ops.push_back(N->getOperand(0));
2920+
// First is the chain
2921+
Ops.push_back(N->getOperand(0));
29112922

2912-
if (UpsizeElementTypes) {
2913-
// Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2914-
// stored as b32s
2915-
unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
2916-
for (unsigned i = 0; i < NumElts; ++i) {
2917-
SmallVector<SDValue, 4> SubVectorElts;
2918-
DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector,
2919-
NumEltsPerSubVector);
2920-
SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
2921-
Ops.push_back(SubVector);
2922-
}
2923-
} else {
2924-
// Then the split values
2925-
for (unsigned i = 0; i < NumElts; ++i) {
2926-
SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2927-
DAG.getIntPtrConstant(i, DL));
2928-
if (NeedExt)
2929-
ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
2930-
Ops.push_back(ExtVal);
2931-
}
2923+
// Then the split values
2924+
if (UpsizeElementTypes) {
2925+
// Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2926+
// stored as b32s
2927+
unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
2928+
for (unsigned i = 0; i < NumElts; ++i) {
2929+
SmallVector<SDValue, 4> SubVectorElts;
2930+
DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector,
2931+
NumEltsPerSubVector);
2932+
SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
2933+
Ops.push_back(SubVector);
29322934
}
2935+
} else {
2936+
for (unsigned i = 0; i < NumElts; ++i) {
2937+
SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2938+
DAG.getIntPtrConstant(i, DL));
2939+
if (NeedExt)
2940+
ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
2941+
Ops.push_back(ExtVal);
2942+
}
2943+
}
29332944

2934-
// Then any remaining arguments
2935-
Ops.append(N->op_begin() + 2, N->op_end());
2936-
2937-
SDValue NewSt =
2938-
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
2939-
MemSD->getMemoryVT(), MemSD->getMemOperand());
2945+
// Then any remaining arguments
2946+
Ops.append(N->op_begin() + 2, N->op_end());
29402947

2941-
// return DCI.CombineTo(N, NewSt, true);
2942-
return NewSt;
2943-
}
2948+
SDValue NewSt =
2949+
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
2950+
MemSD->getMemoryVT(), MemSD->getMemOperand());
29442951

2945-
return SDValue();
2952+
// return DCI.CombineTo(N, NewSt, true);
2953+
return NewSt;
29462954
}
29472955

29482956
// st i1 v, addr
@@ -5225,46 +5233,10 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
52255233

52265234
assert(ResVT.isVector() && "Vector load must have vector type");
52275235

5228-
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
5229-
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
5230-
// vectorized loads/stores with the actual element type for i8/i16 as that
5231-
// would require v8/v16 variants that do not exist.
5232-
// In order to load/store such vectors efficiently, here in Type Legalization,
5233-
// we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
5234-
// lower to PTX as vectors of b32.
5235-
bool UpsizeElementTypes = false;
5236-
5237-
// We only handle "native" vector sizes for now, e.g. <4 x double> is not
5238-
// legal. We can (and should) split that into 2 loads of <2 x double> here
5239-
// but I'm leaving that as a TODO for now.
5240-
assert(ResVT.isSimple() && "Can only handle simple types");
5241-
switch (ResVT.getSimpleVT().SimpleTy) {
5242-
default:
5236+
auto VectorLoweringParams = tryGetVectorLoweringParams(ResVT);
5237+
if (!VectorLoweringParams)
52435238
return;
5244-
case MVT::v2i8:
5245-
case MVT::v2i16:
5246-
case MVT::v2i32:
5247-
case MVT::v2i64:
5248-
case MVT::v2f16:
5249-
case MVT::v2f32:
5250-
case MVT::v2f64:
5251-
case MVT::v4i8:
5252-
case MVT::v4i16:
5253-
case MVT::v4i32:
5254-
case MVT::v4f16:
5255-
case MVT::v4bf16:
5256-
case MVT::v4f32:
5257-
// This is a "native" vector type
5258-
break;
5259-
case MVT::v8i8: // <2 x i8x4>
5260-
case MVT::v8f16: // <4 x f16x2>
5261-
case MVT::v8bf16: // <4 x bf16x2>
5262-
case MVT::v8i16: // <4 x i16x2>
5263-
case MVT::v16i8: // <4 x i8x4>
5264-
// This can be upsized into a "native" vector type
5265-
UpsizeElementTypes = true;
5266-
break;
5267-
}
5239+
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value();
52685240

52695241
LoadSDNode *LD = cast<LoadSDNode>(N);
52705242

@@ -5281,13 +5253,6 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
52815253
return;
52825254
}
52835255

5284-
EVT EltVT = ResVT.getVectorElementType();
5285-
unsigned NumElts = ResVT.getVectorNumElements();
5286-
5287-
if (UpsizeElementTypes) {
5288-
std::tie(NumElts, EltVT) = GetUpsizedNumEltsAndEltVT(NumElts, EltVT);
5289-
}
5290-
52915256
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
52925257
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
52935258
// loaded type to i16 and propagate the "real" type as the memory type.

0 commit comments

Comments
 (0)