@@ -409,6 +409,13 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
409
409
return VectorInfo;
410
410
}
411
411
412
+ static SDValue MaybeBitcast (SelectionDAG &DAG, SDLoc DL, EVT VT,
413
+ SDValue Value) {
414
+ if (Value->getValueType (0 ) == VT)
415
+ return Value;
416
+ return DAG.getNode (ISD::BITCAST, DL, VT, Value);
417
+ }
418
+
412
419
// NVPTXTargetLowering Constructor.
413
420
NVPTXTargetLowering::NVPTXTargetLowering (const NVPTXTargetMachine &TM,
414
421
const NVPTXSubtarget &STI)
@@ -551,6 +558,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
551
558
setOperationAction (ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
552
559
setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
553
560
setOperationAction (ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
561
+
562
+ // Custom conversions to/from v2i8.
563
+ setOperationAction (ISD::BITCAST, MVT::v2i8, Custom);
564
+
554
565
// Only logical ops can be done on v4i8 directly, others must be done
555
566
// elementwise.
556
567
setOperationAction (
@@ -2309,6 +2320,30 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2309
2320
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2310
2321
}
2311
2322
2323
+ SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
2324
+ // Handle bitcasting from v2i8 without hitting the default promotion
2325
+ // strategy which goes through stack memory.
2326
+ EVT FromVT = Op->getOperand (0 )->getValueType (0 );
2327
+ if (FromVT != MVT::v2i8) {
2328
+ return Op;
2329
+ }
2330
+
2331
+ // Pack vector elements into i16 and bitcast to final type
2332
+ SDLoc DL (Op);
2333
+ SDValue Vec0 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2334
+ Op->getOperand (0 ), DAG.getIntPtrConstant (0 , DL));
2335
+ SDValue Vec1 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2336
+ Op->getOperand (0 ), DAG.getIntPtrConstant (1 , DL));
2337
+ SDValue Extend0 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec0);
2338
+ SDValue Extend1 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec1);
2339
+ SDValue Const8 = DAG.getConstant (8 , DL, MVT::i16 );
2340
+ SDValue AsInt = DAG.getNode (
2341
+ ISD::OR, DL, MVT::i16 ,
2342
+ {Extend0, DAG.getNode (ISD::SHL, DL, MVT::i16 , {Extend1, Const8})});
2343
+ EVT ToVT = Op->getValueType (0 );
2344
+ return MaybeBitcast (DAG, DL, ToVT, AsInt);
2345
+ }
2346
+
2312
2347
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
2313
2348
// would get lowered as two constant loads and vector-packing move.
2314
2349
// Instead we want just a constant move:
@@ -2817,6 +2852,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2817
2852
return Op;
2818
2853
case ISD::BUILD_VECTOR:
2819
2854
return LowerBUILD_VECTOR (Op, DAG);
2855
+ case ISD::BITCAST:
2856
+ return LowerBITCAST (Op, DAG);
2820
2857
case ISD::EXTRACT_SUBVECTOR:
2821
2858
return Op;
2822
2859
case ISD::EXTRACT_VECTOR_ELT:
@@ -6127,6 +6164,28 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
6127
6164
return SDValue ();
6128
6165
}
6129
6166
6167
+ static void ReplaceBITCAST (SDNode *Node, SelectionDAG &DAG,
6168
+ SmallVectorImpl<SDValue> &Results) {
6169
+ // Handle bitcasting to v2i8 without hitting the default promotion
6170
+ // strategy which goes through stack memory.
6171
+ SDValue Op (Node, 0 );
6172
+ EVT ToVT = Op->getValueType (0 );
6173
+ if (ToVT != MVT::v2i8) {
6174
+ return ;
6175
+ }
6176
+
6177
+ // Bitcast to i16 and unpack elements into a vector
6178
+ SDLoc DL (Node);
6179
+ SDValue AsInt = MaybeBitcast (DAG, DL, MVT::i16 , Op->getOperand (0 ));
6180
+ SDValue Vec0 = DAG.getNode (ISD::TRUNCATE, DL, MVT::i8 , AsInt);
6181
+ SDValue Const8 = DAG.getConstant (8 , DL, MVT::i16 );
6182
+ SDValue Vec1 =
6183
+ DAG.getNode (ISD::TRUNCATE, DL, MVT::i8 ,
6184
+ DAG.getNode (ISD::SRL, DL, MVT::i16 , {AsInt, Const8}));
6185
+ Results.push_back (
6186
+ DAG.getNode (ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
6187
+ }
6188
+
6130
6189
// / ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
6131
6190
static void ReplaceLoadVector (SDNode *N, SelectionDAG &DAG,
6132
6191
SmallVectorImpl<SDValue> &Results) {
@@ -6412,6 +6471,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
6412
6471
switch (N->getOpcode ()) {
6413
6472
default :
6414
6473
report_fatal_error (" Unhandled custom legalization" );
6474
+ case ISD::BITCAST:
6475
+ ReplaceBITCAST (N, DAG, Results);
6476
+ return ;
6415
6477
case ISD::LOAD:
6416
6478
ReplaceLoadVector (N, DAG, Results);
6417
6479
return ;
0 commit comments