Skip to content

Commit edc9c7f

Browse files
committed
Remove upsized boolean and derive it in the caller
1 parent 2a48826 commit edc9c7f

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
#include <iterator>
6565
#include <optional>
6666
#include <string>
67-
#include <tuple>
6867
#include <utility>
6968
#include <vector>
7069

@@ -167,23 +166,21 @@ static bool Is16bitsType(MVT VT) {
167166
// things:
168167
// 1. Determines Whether the vector is something we want to custom lower,
169168
// std::nullopt is returned if we do not want to custom lower it.
170-
// 2. If we do want to handle it, returns three parameters:
169+
// 2. If we do want to handle it, returns two parameters:
171170
// - unsigned int NumElts - The number of elements in the final vector
172171
// - EVT EltVT - The type of the elements in the final vector
173-
// - bool UpsizeElementTypes - Whether or not we are upsizing the elements of
174-
// the vector
175-
static std::optional<std::tuple<unsigned int, EVT, bool>>
176-
getVectorLoweringShape(EVT ValVT) {
177-
if (!ValVT.isVector() || !ValVT.isSimple())
172+
static std::optional<std::pair<unsigned int, EVT>>
173+
getVectorLoweringShape(EVT VectorVT) {
174+
if (!VectorVT.isVector() || !VectorVT.isSimple())
178175
return std::nullopt;
179176

180-
EVT EltVT = ValVT.getVectorElementType();
181-
unsigned NumElts = ValVT.getVectorNumElements();
177+
EVT EltVT = VectorVT.getVectorElementType();
178+
unsigned NumElts = VectorVT.getVectorNumElements();
182179

183180
// We only handle "native" vector sizes for now, e.g. <4 x double> is not
184181
// legal. We can (and should) split that into 2 stores of <2 x double> here
185182
// but I'm leaving that as a TODO for now.
186-
switch (ValVT.getSimpleVT().SimpleTy) {
183+
switch (VectorVT.getSimpleVT().SimpleTy) {
187184
default:
188185
return std::nullopt;
189186
case MVT::v2i8:
@@ -201,7 +198,7 @@ getVectorLoweringShape(EVT ValVT) {
201198
case MVT::v4bf16:
202199
case MVT::v4f32:
203200
// This is a "native" vector type
204-
return std::tuple(NumElts, EltVT, /* UpsizeElementTypes = */ false);
201+
return std::pair(NumElts, EltVT);
205202
case MVT::v8i8: // <2 x i8x4>
206203
case MVT::v8f16: // <4 x f16x2>
207204
case MVT::v8bf16: // <4 x bf16x2>
@@ -218,12 +215,12 @@ getVectorLoweringShape(EVT ValVT) {
218215

219216
// Number of elements to pack in one word.
220217
unsigned NPerWord = 32 / EltVT.getSizeInBits();
221-
return std::tuple(NumElts / NPerWord,
222-
MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord),
223-
/* UpsizeElementTypes = */ true);
218+
219+
return std::pair(NumElts / NPerWord,
220+
MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord));
224221
}
225222

226-
llvm_unreachable("All cases should return.");
223+
llvm_unreachable("All cases in switch should return.");
227224
};
228225

229226
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -2873,10 +2870,10 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
28732870
SDLoc DL(N);
28742871
EVT ValVT = Val.getValueType();
28752872

2876-
auto VectorLoweringParams = getVectorLoweringShape(ValVT);
2877-
if (!VectorLoweringParams)
2873+
auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
2874+
if (!NumEltsAndEltVT)
28782875
return SDValue();
2879-
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value();
2876+
auto [NumElts, EltVT] = NumEltsAndEltVT.value();
28802877

28812878
MemSDNode *MemSD = cast<MemSDNode>(N);
28822879
const DataLayout &TD = DAG.getDataLayout();
@@ -2917,7 +2914,10 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
29172914
Ops.push_back(N->getOperand(0));
29182915

29192916
// Then the split values
2920-
if (UpsizeElementTypes) {
2917+
if (ValVT.getVectorNumElements() > NumElts) {
2918+
// If the number of elements has changed, getVectorLoweringShape has upsized
2919+
// the element types
2920+
assert((Isv2x16VT(EltVT) || EltVT == MVT::v4i8) && "Unexpected upsized type.");
29212921
// Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
29222922
// stored as b32s
29232923
unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
@@ -5229,10 +5229,10 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
52295229

52305230
assert(ResVT.isVector() && "Vector load must have vector type");
52315231

5232-
auto VectorLoweringParams = getVectorLoweringShape(ResVT);
5233-
if (!VectorLoweringParams)
5232+
auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
5233+
if (!NumEltsAndEltVT)
52345234
return;
5235-
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value();
5235+
auto [NumElts, EltVT] = NumEltsAndEltVT.value();
52365236

52375237
LoadSDNode *LD = cast<LoadSDNode>(N);
52385238

@@ -5288,7 +5288,11 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
52885288
LD->getMemOperand());
52895289

52905290
SmallVector<SDValue> ScalarRes;
5291-
if (UpsizeElementTypes) {
5291+
if (ResVT.getVectorNumElements() > NumElts) {
5292+
// If the number of elements has changed, getVectorLoweringShape has upsized
5293+
// the element types
5294+
assert((Isv2x16VT(EltVT) || EltVT == MVT::v4i8) &&
5295+
"Unexpected upsized type.");
52925296
// Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
52935297
// into individual elements.
52945298
for (unsigned i = 0; i < NumElts; ++i) {

0 commit comments

Comments
 (0)