Skip to content

Commit 2a48826

Browse files
committed
Refactor helper function
1 parent 9cfbdad commit 2a48826

File tree

1 file changed

+16
-23
lines changed

1 file changed

+16
-23
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,7 @@ static bool Is16bitsType(MVT VT) {
173173
// - bool UpsizeElementTypes - Whether or not we are upsizing the elements of
174174
// the vector
175175
static std::optional<std::tuple<unsigned int, EVT, bool>>
176-
tryGetVectorLoweringParams(EVT ValVT) {
177-
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
178-
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
179-
// vectorized loads/stores with the actual element type for i8/i16 as that
180-
// would require v8/v16 variants that do not exist.
181-
// In order to load/store such vectors efficiently, here in Type Legalization,
182-
// we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
183-
// lower to PTX as vectors of b32.
184-
bool UpsizeElementTypes = false;
185-
176+
getVectorLoweringShape(EVT ValVT) {
186177
if (!ValVT.isVector() || !ValVT.isSimple())
187178
return std::nullopt;
188179

@@ -210,27 +201,29 @@ tryGetVectorLoweringParams(EVT ValVT) {
210201
case MVT::v4bf16:
211202
case MVT::v4f32:
212203
// This is a "native" vector type
213-
break;
204+
return std::tuple(NumElts, EltVT, /* UpsizeElementTypes = */ false);
214205
case MVT::v8i8: // <2 x i8x4>
215206
case MVT::v8f16: // <4 x f16x2>
216207
case MVT::v8bf16: // <4 x bf16x2>
217208
case MVT::v8i16: // <4 x i16x2>
218209
case MVT::v16i8: // <4 x i8x4>
219-
// This can be upsized into a "native" vector type
220-
UpsizeElementTypes = true;
221-
break;
222-
}
210+
// This can be upsized into a "native" vector type.
211+
// Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
212+
// total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
213+
// vectorized loads/stores with the actual element type for i8/i16 as that
214+
// would require v8/v16 variants that do not exist.
215+
// In order to load/store such vectors efficiently, here in Type
216+
// Legalization, we split the vector into word-sized chunks (v2x16/v4i8).
217+
// Later, we will lower to PTX as vectors of b32.
223218

224-
if (UpsizeElementTypes) {
225219
// Number of elements to pack in one word.
226220
unsigned NPerWord = 32 / EltVT.getSizeInBits();
227-
// Word-sized vector.
228-
EltVT = MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord);
229-
// Number of word-sized vectors.
230-
NumElts = NumElts / NPerWord;
221+
return std::tuple(NumElts / NPerWord,
222+
MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord),
223+
/* UpsizeElementTypes = */ true);
231224
}
232225

233-
return std::tuple(NumElts, EltVT, UpsizeElementTypes);
226+
llvm_unreachable("All cases should return.");
234227
};
235228

236229
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -2880,7 +2873,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
28802873
SDLoc DL(N);
28812874
EVT ValVT = Val.getValueType();
28822875

2883-
auto VectorLoweringParams = tryGetVectorLoweringParams(ValVT);
2876+
auto VectorLoweringParams = getVectorLoweringShape(ValVT);
28842877
if (!VectorLoweringParams)
28852878
return SDValue();
28862879
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value();
@@ -5236,7 +5229,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
52365229

52375230
assert(ResVT.isVector() && "Vector load must have vector type");
52385231

5239-
auto VectorLoweringParams = tryGetVectorLoweringParams(ResVT);
5232+
auto VectorLoweringParams = getVectorLoweringShape(ResVT);
52405233
if (!VectorLoweringParams)
52415234
return;
52425235
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value();

0 commit comments

Comments
 (0)