@@ -173,16 +173,7 @@ static bool Is16bitsType(MVT VT) {
173
173
// - bool UpsizeElementTypes - Whether or not we are upsizing the elements of
174
174
// the vector
175
175
static std::optional<std::tuple<unsigned int , EVT, bool >>
176
- tryGetVectorLoweringParams (EVT ValVT) {
177
- // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
178
- // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
179
- // vectorized loads/stores with the actual element type for i8/i16 as that
180
- // would require v8/v16 variants that do not exist.
181
- // In order to load/store such vectors efficiently, here in Type Legalization,
182
- // we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
183
- // lower to PTX as vectors of b32.
184
- bool UpsizeElementTypes = false ;
185
-
176
+ getVectorLoweringShape (EVT ValVT) {
186
177
if (!ValVT.isVector () || !ValVT.isSimple ())
187
178
return std::nullopt;
188
179
@@ -210,27 +201,29 @@ tryGetVectorLoweringParams(EVT ValVT) {
210
201
case MVT::v4bf16:
211
202
case MVT::v4f32:
212
203
// This is a "native" vector type
213
- break ;
204
+ return std::tuple (NumElts, EltVT, /* UpsizeElementTypes = */ false ) ;
214
205
case MVT::v8i8: // <2 x i8x4>
215
206
case MVT::v8f16: // <4 x f16x2>
216
207
case MVT::v8bf16: // <4 x bf16x2>
217
208
case MVT::v8i16: // <4 x i16x2>
218
209
case MVT::v16i8: // <4 x i8x4>
219
- // This can be upsized into a "native" vector type
220
- UpsizeElementTypes = true ;
221
- break ;
222
- }
210
+ // This can be upsized into a "native" vector type.
211
+ // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
212
+ // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
213
+ // vectorized loads/stores with the actual element type for i8/i16 as that
214
+ // would require v8/v16 variants that do not exist.
215
+ // In order to load/store such vectors efficiently, here in Type
216
+ // Legalization, we split the vector into word-sized chunks (v2x16/v4i8).
217
+ // Later, we will lower to PTX as vectors of b32.
223
218
224
- if (UpsizeElementTypes) {
225
219
// Number of elements to pack in one word.
226
220
unsigned NPerWord = 32 / EltVT.getSizeInBits ();
227
- // Word-sized vector.
228
- EltVT = MVT::getVectorVT (EltVT.getSimpleVT (), NPerWord);
229
- // Number of word-sized vectors.
230
- NumElts = NumElts / NPerWord;
221
+ return std::tuple (NumElts / NPerWord,
222
+ MVT::getVectorVT (EltVT.getSimpleVT (), NPerWord),
223
+ /* UpsizeElementTypes = */ true );
231
224
}
232
225
233
- return std::tuple (NumElts, EltVT, UpsizeElementTypes );
226
+ llvm_unreachable ( " All cases should return. " );
234
227
};
235
228
236
229
// / ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -2880,7 +2873,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2880
2873
SDLoc DL (N);
2881
2874
EVT ValVT = Val.getValueType ();
2882
2875
2883
- auto VectorLoweringParams = tryGetVectorLoweringParams (ValVT);
2876
+ auto VectorLoweringParams = getVectorLoweringShape (ValVT);
2884
2877
if (!VectorLoweringParams)
2885
2878
return SDValue ();
2886
2879
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value ();
@@ -5236,7 +5229,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5236
5229
5237
5230
assert (ResVT.isVector () && " Vector load must have vector type" );
5238
5231
5239
- auto VectorLoweringParams = tryGetVectorLoweringParams (ResVT);
5232
+ auto VectorLoweringParams = getVectorLoweringShape (ResVT);
5240
5233
if (!VectorLoweringParams)
5241
5234
return ;
5242
5235
auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value ();
0 commit comments