64
64
#include < iterator>
65
65
#include < optional>
66
66
#include < string>
67
- #include < tuple>
68
67
#include < utility>
69
68
#include < vector>
70
69
@@ -167,23 +166,21 @@ static bool Is16bitsType(MVT VT) {
167
166
// things:
168
167
// 1. Determines Whether the vector is something we want to custom lower,
169
168
// std::nullopt is returned if we do not want to custom lower it.
170
- // 2. If we do want to handle it, returns three parameters:
169
+ // 2. If we do want to handle it, returns two parameters:
171
170
// - unsigned int NumElts - The number of elements in the final vector
172
171
// - EVT EltVT - The type of the elements in the final vector
173
- // - bool UpsizeElementTypes - Whether or not we are upsizing the elements of
174
- // the vector
175
- static std::optional<std::tuple<unsigned int , EVT, bool >>
176
- getVectorLoweringShape (EVT ValVT) {
177
- if (!ValVT.isVector () || !ValVT.isSimple ())
172
+ static std::optional<std::pair<unsigned int , EVT>>
173
+ getVectorLoweringShape (EVT VectorVT) {
174
+ if (!VectorVT.isVector () || !VectorVT.isSimple ())
178
175
return std::nullopt;
179
176
180
- EVT EltVT = ValVT .getVectorElementType ();
181
- unsigned NumElts = ValVT .getVectorNumElements ();
177
+ EVT EltVT = VectorVT .getVectorElementType ();
178
+ unsigned NumElts = VectorVT .getVectorNumElements ();
182
179
183
180
// We only handle "native" vector sizes for now, e.g. <4 x double> is not
184
181
// legal. We can (and should) split that into 2 stores of <2 x double> here
185
182
// but I'm leaving that as a TODO for now.
186
- switch (ValVT .getSimpleVT ().SimpleTy ) {
183
+ switch (VectorVT .getSimpleVT ().SimpleTy ) {
187
184
default :
188
185
return std::nullopt;
189
186
case MVT::v2i8:
@@ -201,7 +198,7 @@ getVectorLoweringShape(EVT ValVT) {
201
198
case MVT::v4bf16:
202
199
case MVT::v4f32:
203
200
// This is a "native" vector type
204
- return std::tuple (NumElts, EltVT, /* UpsizeElementTypes = */ false );
201
+ return std::pair (NumElts, EltVT);
205
202
case MVT::v8i8: // <2 x i8x4>
206
203
case MVT::v8f16: // <4 x f16x2>
207
204
case MVT::v8bf16: // <4 x bf16x2>
@@ -218,12 +215,12 @@ getVectorLoweringShape(EVT ValVT) {
218
215
219
216
// Number of elements to pack in one word.
220
217
unsigned NPerWord = 32 / EltVT.getSizeInBits ();
221
- return std::tuple (NumElts / NPerWord,
222
- MVT::getVectorVT (EltVT. getSimpleVT (), NPerWord) ,
223
- /* UpsizeElementTypes = */ true );
218
+
219
+ return std::pair (NumElts / NPerWord,
220
+ MVT::getVectorVT (EltVT. getSimpleVT (), NPerWord) );
224
221
}
225
222
226
- llvm_unreachable (" All cases should return." );
223
+ llvm_unreachable (" All cases in switch should return." );
227
224
};
228
225
229
226
// / ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -2873,10 +2870,10 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2873
2870
SDLoc DL (N);
2874
2871
EVT ValVT = Val.getValueType ();
2875
2872
2876
- auto VectorLoweringParams = getVectorLoweringShape (ValVT);
2877
- if (!VectorLoweringParams )
2873
+ auto NumEltsAndEltVT = getVectorLoweringShape (ValVT);
2874
+ if (!NumEltsAndEltVT )
2878
2875
return SDValue ();
2879
- auto [NumElts, EltVT, UpsizeElementTypes ] = VectorLoweringParams .value ();
2876
+ auto [NumElts, EltVT] = NumEltsAndEltVT .value ();
2880
2877
2881
2878
MemSDNode *MemSD = cast<MemSDNode>(N);
2882
2879
const DataLayout &TD = DAG.getDataLayout ();
@@ -2917,7 +2914,10 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2917
2914
Ops.push_back (N->getOperand (0 ));
2918
2915
2919
2916
// Then the split values
2920
- if (UpsizeElementTypes) {
2917
+ if (ValVT.getVectorNumElements () > NumElts) {
2918
+ // If the number of elements has changed, getVectorLoweringShape has upsized
2919
+ // the element types
2920
+ assert ((Isv2x16VT (EltVT) || EltVT == MVT::v4i8) && " Unexpected upsized type." );
2921
2921
// Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2922
2922
// stored as b32s
2923
2923
unsigned NumEltsPerSubVector = EltVT.getVectorNumElements ();
@@ -5229,10 +5229,10 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5229
5229
5230
5230
assert (ResVT.isVector () && " Vector load must have vector type" );
5231
5231
5232
- auto VectorLoweringParams = getVectorLoweringShape (ResVT);
5233
- if (!VectorLoweringParams )
5232
+ auto NumEltsAndEltVT = getVectorLoweringShape (ResVT);
5233
+ if (!NumEltsAndEltVT )
5234
5234
return ;
5235
- auto [NumElts, EltVT, UpsizeElementTypes ] = VectorLoweringParams .value ();
5235
+ auto [NumElts, EltVT] = NumEltsAndEltVT .value ();
5236
5236
5237
5237
LoadSDNode *LD = cast<LoadSDNode>(N);
5238
5238
@@ -5288,7 +5288,11 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5288
5288
LD->getMemOperand ());
5289
5289
5290
5290
SmallVector<SDValue> ScalarRes;
5291
- if (UpsizeElementTypes) {
5291
+ if (ResVT.getVectorNumElements () > NumElts) {
5292
+ // If the number of elements has changed, getVectorLoweringShape has upsized
5293
+ // the element types
5294
+ assert ((Isv2x16VT (EltVT) || EltVT == MVT::v4i8) &&
5295
+ " Unexpected upsized type." );
5292
5296
// Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5293
5297
// into individual elements.
5294
5298
for (unsigned i = 0 ; i < NumElts; ++i) {
0 commit comments