64
64
#include < iterator>
65
65
#include < optional>
66
66
#include < string>
67
+ #include < tuple>
67
68
#include < utility>
68
69
#include < vector>
69
70
@@ -162,14 +163,70 @@ static bool Is16bitsType(MVT VT) {
162
163
VT.SimpleTy == MVT::i16 );
163
164
}
164
165
165
- static auto GetUpsizedNumEltsAndEltVT (unsigned OldNumElts, EVT OldEltVT) {
166
- // Number of elements to pack in one word.
167
- unsigned NPerWord = 32 / OldEltVT.getSizeInBits ();
168
- // Word-sized vector.
169
- EVT NewEltVT = MVT::getVectorVT (OldEltVT.getSimpleVT (), NPerWord);
170
- // Number of word-sized vectors.
171
- unsigned NewNumElts = OldNumElts / NPerWord;
172
- return std::pair (NewNumElts, NewEltVT);
166
+ // When legalizing vector loads/stores, this function is called, which does two things:
167
+ // 1. Determines Whether the vector is something we want to custom lower, std::nullopt is returned if we do not want to custom lower it.
168
+ // 2. If we do want to handle it, returns three parameters:
169
+ // - unsigned int NumElts - The number of elements in the final vector
170
+ // - EVT EltVT - The type of the elements in the final vector
171
+ // - bool UpsizeElementTypes - Whether or not we are upsizing the elements of the vectors
172
+ static std::optional<std::tuple<unsigned int , EVT, bool >> tryGetVectorLoweringParams (EVT ValVT) {
173
+ // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
174
+ // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
175
+ // vectorized loads/stores with the actual element type for i8/i16 as that
176
+ // would require v8/v16 variants that do not exist.
177
+ // In order to load/store such vectors efficiently, here in Type Legalization,
178
+ // we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
179
+ // lower to PTX as vectors of b32.
180
+ bool UpsizeElementTypes = false ;
181
+
182
+ if (!ValVT.isVector () || !ValVT.isSimple ())
183
+ return std::nullopt;
184
+
185
+ EVT EltVT = ValVT.getVectorElementType ();
186
+ unsigned NumElts = ValVT.getVectorNumElements ();
187
+
188
+ // We only handle "native" vector sizes for now, e.g. <4 x double> is not
189
+ // legal. We can (and should) split that into 2 stores of <2 x double> here
190
+ // but I'm leaving that as a TODO for now.
191
+ switch (ValVT.getSimpleVT ().SimpleTy ) {
192
+ default :
193
+ return std::nullopt;
194
+ case MVT::v2i8:
195
+ case MVT::v2i16:
196
+ case MVT::v2i32:
197
+ case MVT::v2i64:
198
+ case MVT::v2f16:
199
+ case MVT::v2bf16:
200
+ case MVT::v2f32:
201
+ case MVT::v2f64:
202
+ case MVT::v4i8:
203
+ case MVT::v4i16:
204
+ case MVT::v4i32:
205
+ case MVT::v4f16:
206
+ case MVT::v4bf16:
207
+ case MVT::v4f32:
208
+ // This is a "native" vector type
209
+ break ;
210
+ case MVT::v8i8: // <2 x i8x4>
211
+ case MVT::v8f16: // <4 x f16x2>
212
+ case MVT::v8bf16: // <4 x bf16x2>
213
+ case MVT::v8i16: // <4 x i16x2>
214
+ case MVT::v16i8: // <4 x i8x4>
215
+ // This can be upsized into a "native" vector type
216
+ UpsizeElementTypes = true ;
217
+ break ;
218
+ }
219
+
220
+ if (UpsizeElementTypes) {
221
+ // Number of elements to pack in one word.
222
+ unsigned NPerWord = 32 / EltVT.getSizeInBits ();
223
+ // Word-sized vector.
224
+ EltVT = MVT::getVectorVT (EltVT.getSimpleVT (), NPerWord);
225
+ // Number of word-sized vectors.
226
+ NumElts = NumElts / NPerWord;
227
+ }
228
+
229
+ return std::tuple (NumElts, EltVT, UpsizeElementTypes);
173
230
};
174
231
175
232
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -2819,130 +2876,81 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2819
2876
SDLoc DL (N);
2820
2877
EVT ValVT = Val.getValueType ();
2821
2878
2822
- // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
2823
- // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
2824
- // vectorized loads/stores with the actual element type for i8/i16 as that
2825
- // would require v8/v16 variants that do not exist.
2826
- // In order to load/store such vectors efficiently, here in Type Legalization,
2827
- // we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
2828
- // lower to PTX as vectors of b32.
2829
- bool UpsizeElementTypes = false ;
2830
-
2831
- if (ValVT.isVector ()) {
2832
- // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2833
- // legal. We can (and should) split that into 2 stores of <2 x double> here
2834
- // but I'm leaving that as a TODO for now.
2835
- if (!ValVT.isSimple ())
2836
- return SDValue ();
2837
- switch (ValVT.getSimpleVT ().SimpleTy ) {
2838
- default :
2839
- return SDValue ();
2840
- case MVT::v2i8:
2841
- case MVT::v2i16:
2842
- case MVT::v2i32:
2843
- case MVT::v2i64:
2844
- case MVT::v2f16:
2845
- case MVT::v2bf16:
2846
- case MVT::v2f32:
2847
- case MVT::v2f64:
2848
- case MVT::v4i8:
2849
- case MVT::v4i16:
2850
- case MVT::v4i32:
2851
- case MVT::v4f16:
2852
- case MVT::v4bf16:
2853
- case MVT::v4f32:
2854
- // This is a "native" vector type
2855
- break ;
2856
- case MVT::v8i8: // <2 x i8x4>
2857
- case MVT::v8f16: // <4 x f16x2>
2858
- case MVT::v8bf16: // <4 x bf16x2>
2859
- case MVT::v8i16: // <4 x i16x2>
2860
- case MVT::v16i8: // <4 x i8x4>
2861
- // This can be upsized into a "native" vector type
2862
- UpsizeElementTypes = true ;
2863
- break ;
2864
- }
2865
-
2866
- MemSDNode *MemSD = cast<MemSDNode>(N);
2867
- const DataLayout &TD = DAG.getDataLayout ();
2868
-
2869
- Align Alignment = MemSD->getAlign ();
2870
- Align PrefAlign =
2871
- TD.getPrefTypeAlign (ValVT.getTypeForEVT (*DAG.getContext ()));
2872
- if (Alignment < PrefAlign) {
2873
- // This store is not sufficiently aligned, so bail out and let this vector
2874
- // store be scalarized. Note that we may still be able to emit smaller
2875
- // vector stores. For example, if we are storing a <4 x float> with an
2876
- // alignment of 8, this check will fail but the legalizer will try again
2877
- // with 2 x <2 x float>, which will succeed with an alignment of 8.
2878
- return SDValue ();
2879
- }
2879
+ auto VectorLoweringParams = tryGetVectorLoweringParams (ValVT);
2880
+ if (!VectorLoweringParams)
2881
+ return SDValue ();
2882
+ auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value ();
2880
2883
2881
- unsigned Opcode = 0 ;
2882
- EVT EltVT = ValVT.getVectorElementType ();
2883
- unsigned NumElts = ValVT.getVectorNumElements ();
2884
+ MemSDNode *MemSD = cast<MemSDNode>(N);
2885
+ const DataLayout &TD = DAG.getDataLayout ();
2884
2886
2885
- if (UpsizeElementTypes) {
2886
- std::tie (NumElts, EltVT) = GetUpsizedNumEltsAndEltVT (NumElts, EltVT);
2887
- }
2887
+ Align Alignment = MemSD->getAlign ();
2888
+ Align PrefAlign =
2889
+ TD.getPrefTypeAlign (ValVT.getTypeForEVT (*DAG.getContext ()));
2890
+ if (Alignment < PrefAlign) {
2891
+ // This store is not sufficiently aligned, so bail out and let this vector
2892
+ // store be scalarized. Note that we may still be able to emit smaller
2893
+ // vector stores. For example, if we are storing a <4 x float> with an
2894
+ // alignment of 8, this check will fail but the legalizer will try again
2895
+ // with 2 x <2 x float>, which will succeed with an alignment of 8.
2896
+ return SDValue ();
2897
+ }
2888
2898
2889
- // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2890
- // Therefore, we must ensure the type is legal. For i1 and i8, we set the
2891
- // stored type to i16 and propagate the "real" type as the memory type.
2892
- bool NeedExt = false ;
2893
- if (EltVT.getSizeInBits () < 16 )
2894
- NeedExt = true ;
2899
+ // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2900
+ // Therefore, we must ensure the type is legal. For i1 and i8, we set the
2901
+ // stored type to i16 and propagate the "real" type as the memory type.
2902
+ bool NeedExt = false ;
2903
+ if (EltVT.getSizeInBits () < 16 )
2904
+ NeedExt = true ;
2895
2905
2896
- switch (NumElts) {
2897
- default :
2898
- return SDValue ();
2899
- case 2 :
2900
- Opcode = NVPTXISD::StoreV2;
2901
- break ;
2902
- case 4 :
2903
- Opcode = NVPTXISD::StoreV4;
2904
- break ;
2905
- }
2906
+ unsigned Opcode = 0 ;
2907
+ switch (NumElts) {
2908
+ default :
2909
+ return SDValue ();
2910
+ case 2 :
2911
+ Opcode = NVPTXISD::StoreV2;
2912
+ break ;
2913
+ case 4 :
2914
+ Opcode = NVPTXISD::StoreV4;
2915
+ break ;
2916
+ }
2906
2917
2907
- SmallVector<SDValue, 8 > Ops;
2918
+ SmallVector<SDValue, 8 > Ops;
2908
2919
2909
- // First is the chain
2910
- Ops.push_back (N->getOperand (0 ));
2920
+ // First is the chain
2921
+ Ops.push_back (N->getOperand (0 ));
2911
2922
2912
- if (UpsizeElementTypes) {
2913
- // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2914
- // stored as b32s
2915
- unsigned NumEltsPerSubVector = EltVT.getVectorNumElements ();
2916
- for (unsigned i = 0 ; i < NumElts; ++i) {
2917
- SmallVector<SDValue, 4 > SubVectorElts;
2918
- DAG.ExtractVectorElements (Val, SubVectorElts, i * NumEltsPerSubVector,
2919
- NumEltsPerSubVector);
2920
- SDValue SubVector = DAG.getBuildVector (EltVT, DL, SubVectorElts);
2921
- Ops.push_back (SubVector);
2922
- }
2923
- } else {
2924
- // Then the split values
2925
- for (unsigned i = 0 ; i < NumElts; ++i) {
2926
- SDValue ExtVal = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2927
- DAG.getIntPtrConstant (i, DL));
2928
- if (NeedExt)
2929
- ExtVal = DAG.getNode (ISD::ANY_EXTEND, DL, MVT::i16 , ExtVal);
2930
- Ops.push_back (ExtVal);
2931
- }
2923
+ // Then the split values
2924
+ if (UpsizeElementTypes) {
2925
+ // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2926
+ // stored as b32s
2927
+ unsigned NumEltsPerSubVector = EltVT.getVectorNumElements ();
2928
+ for (unsigned i = 0 ; i < NumElts; ++i) {
2929
+ SmallVector<SDValue, 4 > SubVectorElts;
2930
+ DAG.ExtractVectorElements (Val, SubVectorElts, i * NumEltsPerSubVector,
2931
+ NumEltsPerSubVector);
2932
+ SDValue SubVector = DAG.getBuildVector (EltVT, DL, SubVectorElts);
2933
+ Ops.push_back (SubVector);
2932
2934
}
2935
+ } else {
2936
+ for (unsigned i = 0 ; i < NumElts; ++i) {
2937
+ SDValue ExtVal = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2938
+ DAG.getIntPtrConstant (i, DL));
2939
+ if (NeedExt)
2940
+ ExtVal = DAG.getNode (ISD::ANY_EXTEND, DL, MVT::i16 , ExtVal);
2941
+ Ops.push_back (ExtVal);
2942
+ }
2943
+ }
2933
2944
2934
- // Then any remaining arguments
2935
- Ops.append (N->op_begin () + 2 , N->op_end ());
2936
-
2937
- SDValue NewSt =
2938
- DAG.getMemIntrinsicNode (Opcode, DL, DAG.getVTList (MVT::Other), Ops,
2939
- MemSD->getMemoryVT (), MemSD->getMemOperand ());
2945
+ // Then any remaining arguments
2946
+ Ops.append (N->op_begin () + 2 , N->op_end ());
2940
2947
2941
- // return DCI.CombineTo(N, NewSt, true);
2942
- return NewSt;
2943
- }
2948
+ SDValue NewSt =
2949
+ DAG. getMemIntrinsicNode (Opcode, DL, DAG. getVTList (MVT::Other), Ops,
2950
+ MemSD-> getMemoryVT (), MemSD-> getMemOperand ());
2944
2951
2945
- return SDValue ();
2952
+ // return DCI.CombineTo(N, NewSt, true);
2953
+ return NewSt;
2946
2954
}
2947
2955
2948
2956
// st i1 v, addr
@@ -5225,46 +5233,10 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5225
5233
5226
5234
assert (ResVT.isVector () && " Vector load must have vector type" );
5227
5235
5228
- // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
5229
- // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
5230
- // vectorized loads/stores with the actual element type for i8/i16 as that
5231
- // would require v8/v16 variants that do not exist.
5232
- // In order to load/store such vectors efficiently, here in Type Legalization,
5233
- // we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
5234
- // lower to PTX as vectors of b32.
5235
- bool UpsizeElementTypes = false ;
5236
-
5237
- // We only handle "native" vector sizes for now, e.g. <4 x double> is not
5238
- // legal. We can (and should) split that into 2 loads of <2 x double> here
5239
- // but I'm leaving that as a TODO for now.
5240
- assert (ResVT.isSimple () && " Can only handle simple types" );
5241
- switch (ResVT.getSimpleVT ().SimpleTy ) {
5242
- default :
5236
+ auto VectorLoweringParams = tryGetVectorLoweringParams (ResVT);
5237
+ if (!VectorLoweringParams)
5243
5238
return ;
5244
- case MVT::v2i8:
5245
- case MVT::v2i16:
5246
- case MVT::v2i32:
5247
- case MVT::v2i64:
5248
- case MVT::v2f16:
5249
- case MVT::v2f32:
5250
- case MVT::v2f64:
5251
- case MVT::v4i8:
5252
- case MVT::v4i16:
5253
- case MVT::v4i32:
5254
- case MVT::v4f16:
5255
- case MVT::v4bf16:
5256
- case MVT::v4f32:
5257
- // This is a "native" vector type
5258
- break ;
5259
- case MVT::v8i8: // <2 x i8x4>
5260
- case MVT::v8f16: // <4 x f16x2>
5261
- case MVT::v8bf16: // <4 x bf16x2>
5262
- case MVT::v8i16: // <4 x i16x2>
5263
- case MVT::v16i8: // <4 x i8x4>
5264
- // This can be upsized into a "native" vector type
5265
- UpsizeElementTypes = true ;
5266
- break ;
5267
- }
5239
+ auto [NumElts, EltVT, UpsizeElementTypes] = VectorLoweringParams.value ();
5268
5240
5269
5241
LoadSDNode *LD = cast<LoadSDNode>(N);
5270
5242
@@ -5281,13 +5253,6 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5281
5253
return ;
5282
5254
}
5283
5255
5284
- EVT EltVT = ResVT.getVectorElementType ();
5285
- unsigned NumElts = ResVT.getVectorNumElements ();
5286
-
5287
- if (UpsizeElementTypes) {
5288
- std::tie (NumElts, EltVT) = GetUpsizedNumEltsAndEltVT (NumElts, EltVT);
5289
- }
5290
-
5291
5256
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
5292
5257
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
5293
5258
// loaded type to i16 and propagate the "real" type as the memory type.
0 commit comments