@@ -370,7 +370,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
    } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
      // v2i8 is promoted to v2i16
      NumElts = 1;
-     EltVT = MVT::v2i16;
+     EltVT = MVT::v2i8;
    }
    for (unsigned j = 0; j != NumElts; ++j) {
      ValueVTs.push_back(EltVT);
@@ -1065,9 +1065,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
    MAKE_CASE(NVPTXISD::StoreParamV2)
    MAKE_CASE(NVPTXISD::StoreParamV4)
    MAKE_CASE(NVPTXISD::MoveParam)
-   MAKE_CASE(NVPTXISD::StoreRetval)
-   MAKE_CASE(NVPTXISD::StoreRetvalV2)
-   MAKE_CASE(NVPTXISD::StoreRetvalV4)
    MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
    MAKE_CASE(NVPTXISD::BUILD_VECTOR)
    MAKE_CASE(NVPTXISD::CallPrototype)
@@ -1438,7 +1435,11 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
}

static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
-  return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+  if (Flags.isSExt())
+    return ISD::SIGN_EXTEND;
+  if (Flags.isZExt())
+    return ISD::ZERO_EXTEND;
+  return ISD::ANY_EXTEND;
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
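Note on the getExtOpcode change above: the old helper mapped every argument without a signext flag to ISD::ZERO_EXTEND, even when the front end set neither signext nor zeroext. The new version zero-extends only when the zext flag is actually present and otherwise falls back to ISD::ANY_EXTEND, which leaves the upper bits unspecified. A minimal standalone sketch of the decision table (mock flag struct and opcode enum for illustration, not the real LLVM classes):

#include <cassert>

// Mock stand-ins for ISD::NodeType and ISD::ArgFlagsTy, illustration only.
enum class ExtOp { SignExtend, ZeroExtend, AnyExtend };

struct MockArgFlags {
  bool SExt = false;
  bool ZExt = false;
};

// Mirrors the new getExtOpcode: any-extend unless the IR carries an
// explicit signext/zeroext attribute.
static ExtOp getExtOpcode(const MockArgFlags &Flags) {
  if (Flags.SExt)
    return ExtOp::SignExtend;
  if (Flags.ZExt)
    return ExtOp::ZeroExtend;
  return ExtOp::AnyExtend;
}

int main() {
  assert(getExtOpcode({/*SExt=*/true, /*ZExt=*/false}) == ExtOp::SignExtend);
  assert(getExtOpcode({/*SExt=*/false, /*ZExt=*/true}) == ExtOp::ZeroExtend);
  // Previously this case would have been zero-extended; now the extension
  // kind is left open, since no attribute constrains the upper bits.
  assert(getExtOpcode({}) == ExtOp::AnyExtend);
}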
@@ -3373,10 +3374,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
      }
      InVals.push_back(P);
    } else {
-     bool aggregateIsPacked = false;
-     if (StructType *STy = dyn_cast<StructType>(Ty))
-       aggregateIsPacked = STy->isPacked();
-
      SmallVector<EVT, 16> VTs;
      SmallVector<uint64_t, 16> Offsets;
      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
@@ -3389,9 +3386,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
      const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
      unsigned I = 0;
      for (const unsigned NumElts : VectorInfo) {
-       const EVT EltVT = VTs[I];
        // i1 is loaded/stored as i8
-       const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+       const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
        // If the element is a packed type (ex. v2f16, v4i8, etc) holding
        // multiple elements.
        const unsigned PackingAmt =
@@ -3403,14 +3399,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
        SDValue VecAddr = DAG.getObjectPtrOffset(
            dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));

-       const MaybeAlign PartAlign = [&]() -> MaybeAlign {
-         if (aggregateIsPacked)
-           return Align(1);
-         if (NumElts != 1)
-           return std::nullopt;
-         Align PartAlign = DAG.getEVTAlign(EltVT);
-         return commonAlignment(PartAlign, Offsets[I]);
-       }();
+       const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
        SDValue P =
            DAG.getLoad(VecVT, dl, Root, VecAddr,
                        MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
@@ -3419,23 +3408,22 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
        if (P.getNode())
          P.getNode()->setIROrder(Arg.getArgNo() + 1);
        for (const unsigned J : llvm::seq(NumElts)) {
-         SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
-                                                     : ISD::EXTRACT_VECTOR_ELT,
-                                   dl, LoadVT, P,
-                                   DAG.getIntPtrConstant(J * PackingAmt, dl));
+         SDValue Elt = DAG.getNode(
+             LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+                               : ISD::EXTRACT_VECTOR_ELT,
+             dl, LoadVT, P, DAG.getVectorIdxConstant(J * PackingAmt, dl));

          // Extend or truncate the element if necessary (e.g. an i8 is loaded
          // into an i16 register)
-         const EVT ExpactedVT = ArgIns[I + J].VT;
-         assert((Elt.getValueType().bitsEq(ExpactedVT) ||
-                 (ExpactedVT.isScalarInteger() &&
-                  Elt.getValueType().isScalarInteger())) &&
+         const EVT ExpectedVT = ArgIns[I + J].VT;
+         assert((Elt.getValueType() == ExpectedVT ||
+                 (ExpectedVT.isInteger() && Elt.getValueType().isInteger())) &&
                 "Non-integer argument type size mismatch");
-         if (ExpactedVT.bitsGT(Elt.getValueType()))
-           Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
+         if (ExpectedVT.bitsGT(Elt.getValueType()))
+           Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpectedVT,
                              Elt);
-         else if (ExpactedVT.bitsLT(Elt.getValueType()))
-           Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
+         else if (ExpectedVT.bitsLT(Elt.getValueType()))
+           Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Elt);
          InVals.push_back(Elt);
        }
        I += NumElts;
@@ -3449,33 +3437,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
    return Chain;
  }

- // Use byte-store when the param adress of the return value is unaligned.
- // This may happen when the return value is a field of a packed structure.
- static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
-                                       uint64_t Offset, EVT ElementType,
-                                       SDValue RetVal, const SDLoc &dl) {
-   // Bit logic only works on integer types
-   if (adjustElementType(ElementType))
-     RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
-
-   // Store each byte
-   for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-     // Shift the byte to the last byte position
-     SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
-                                    DAG.getConstant(i * 8, dl, MVT::i32));
-     SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
-                                ShiftVal};
-     // Trunc store only the last byte by using
-     //     st.param.b8
-     // The register type can be larger than b8.
-     Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
-                                     DAG.getVTList(MVT::Other), StoreOperands,
-                                     MVT::i8, MachinePointerInfo(), std::nullopt,
-                                     MachineMemOperand::MOStore);
-   }
-   return Chain;
- }
-
  SDValue
  NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                   bool isVarArg,
@@ -3497,10 +3458,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");

- for (const unsigned I : llvm::seq(VTs.size()))
-   if (const auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
-     VTs[I] = *PromotedVT;
-
  // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
  // 32-bits are sign extended or zero extended, depending on whether
  // they are signed or unsigned types.
@@ -3512,12 +3469,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
    assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
           "OutVal type should always be legal");

-   if (ExtendIntegerRetVal) {
-     RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, MVT::i32, RetVal);
-   } else if (RetVal.getValueSizeInBits() < 16) {
-     // Use 16-bit registers for small load-stores as it's the
-     // smallest general purpose register size supported by NVPTX.
-     RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
+   EVT VTI = VTs[I];
+   if (const auto PromotedVT = PromoteScalarIntegerPTX(VTI))
+     VTI = *PromotedVT;
+
+   const EVT StoreVT =
+       ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+   assert((RetVal.getValueType() == StoreVT ||
+           (StoreVT.isInteger() && RetVal.getValueType().isInteger())) &&
+          "Non-integer argument type size mismatch");
+   if (StoreVT.bitsGT(RetVal.getValueType())) {
+     RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, StoreVT, RetVal);
+   } else if (StoreVT.bitsLT(RetVal.getValueType())) {
+     RetVal = DAG.getNode(ISD::TRUNCATE, dl, StoreVT, RetVal);
    }
    return RetVal;
  };
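The rewritten GetRetVal lambda above folds the old ad-hoc extension rules into one store-type computation: promote the declared VT if needed, widen to i32 when the ABI requires integer extension, use i8 for i1, then extend or truncate the value to match (the argument-lowering hunk earlier applies the same extend/truncate shape). A toy sketch of that decision, using plain bit widths instead of EVTs (illustrative only; PromoteScalarIntegerPTX's exact promotion rules are not reproduced here):

#include <cassert>

// Toy model: a "type" is just an integer bit width.
// Mirrors: StoreVT = ExtendIntegerRetVal ? i32 : (VT == i1 ? i8 : VT)
static unsigned returnStoreBits(unsigned DeclaredBits, bool ExtendIntegerRetVal) {
  if (ExtendIntegerRetVal)
    return 32; // PTX ABI: integer returns shorter than 32 bits are extended.
  if (DeclaredBits == 1)
    return 8;  // i1 is stored as i8.
  return DeclaredBits;
}

enum class Fixup { None, Extend, Truncate };

// Mirrors the bitsGT/bitsLT checks on RetVal vs. StoreVT.
static Fixup fixupFor(unsigned ValueBits, unsigned StoreBits) {
  if (StoreBits > ValueBits)
    return Fixup::Extend;   // via getExtOpcode(Outs[I].Flags)
  if (StoreBits < ValueBits)
    return Fixup::Truncate; // via ISD::TRUNCATE
  return Fixup::None;
}

int main() {
  // An i8 return under the ABI extension rule: store as i32, extend.
  assert(returnStoreBits(8, true) == 32);
  assert(fixupFor(8, 32) == Fixup::Extend);
  // An i1 return without extension: store as i8.
  assert(returnStoreBits(1, false) == 8);
  // A value already held wider than the store type: truncate.
  assert(fixupFor(32, 8) == Fixup::Truncate);
}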
@@ -3526,45 +3491,34 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
  const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
  unsigned I = 0;
  for (const unsigned NumElts : VectorInfo) {
-   const Align CurrentAlign = commonAlignment(RetAlign, Offsets[I]);
-   if (NumElts == 1 && RetTy->isAggregateType() &&
-       CurrentAlign < DAG.getEVTAlign(VTs[I])) {
-     Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], VTs[I],
-                                    GetRetVal(I), dl);
-
-     // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
-     // into the graph, so just move on to the next element.
-     I++;
-     continue;
-   }
+   const MaybeAlign CurrentAlign = ExtendIntegerRetVal
+                                       ? MaybeAlign(std::nullopt)
+                                       : commonAlignment(RetAlign, Offsets[I]);

-   SmallVector<SDValue, 6> StoreOperands{
-       Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)};
-
-   for (const unsigned J : llvm::seq(NumElts))
-     StoreOperands.push_back(GetRetVal(I + J));
+   SDValue Val;
+   if (NumElts == 1) {
+     Val = GetRetVal(I);
+   } else {
+     SmallVector<SDValue, 6> StoreVals;
+     for (const unsigned J : llvm::seq(NumElts)) {
+       SDValue ValJ = GetRetVal(I + J);
+       if (ValJ.getValueType().isVector())
+         DAG.ExtractVectorElements(ValJ, StoreVals);
+       else
+         StoreVals.push_back(ValJ);
+     }

-   NVPTXISD::NodeType Op;
-   switch (NumElts) {
-   case 1:
-     Op = NVPTXISD::StoreRetval;
-     break;
-   case 2:
-     Op = NVPTXISD::StoreRetvalV2;
-     break;
-   case 4:
-     Op = NVPTXISD::StoreRetvalV4;
-     break;
-   default:
-     llvm_unreachable("Invalid vector info.");
+     EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
+                               StoreVals.size());
+     Val = DAG.getBuildVector(VT, dl, StoreVals);
    }

-   // Adjust type of load/store op if we've extended the scalar
-   // return value.
-   EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
-   Chain = DAG.getMemIntrinsicNode(
-       Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
-       MachinePointerInfo(), CurrentAlign, MachineMemOperand::MOStore);
+   SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
+   SDValue Ptr =
+       DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
+
+   Chain = DAG.getStore(Chain, dl, Val, Ptr,
+                        MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);

    I += NumElts;
  }
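The hunk above is the core of the change: instead of emitting NVPTXISD::StoreRetval{,V2,V4} mem-intrinsic nodes keyed to func_retval0 offsets, the return value is now written with ordinary store nodes against the func_retval0 symbol in the param address space, letting generic DAG machinery handle alignment and later combines. Before the store, vector-typed elements are flattened into scalar lanes and regrouped into a single build_vector. A toy sketch of that flattening with plain containers (hypothetical helpers, not LLVM API):

#include <cassert>
#include <vector>

// Toy element: a list of scalar lanes, standing in for an SDValue whose
// type may itself be a packed vector (v2f16, v4i8, ...); one lane => scalar.
struct ToyVal {
  std::vector<int> Lanes;
};

// Mirrors the StoreVals loop: scalars are appended as-is, vector-typed
// values are split into their lanes (DAG.ExtractVectorElements).
static std::vector<int> flattenForStore(const std::vector<ToyVal> &Vals) {
  std::vector<int> StoreVals;
  for (const ToyVal &V : Vals)
    StoreVals.insert(StoreVals.end(), V.Lanes.begin(), V.Lanes.end());
  return StoreVals; // would feed DAG.getBuildVector(...) for a single store
}

int main() {
  // Two 2-lane elements and one scalar become a 5-lane build_vector.
  std::vector<ToyVal> Vals = {{{1, 2}}, {{3, 4}}, {{5}}};
  assert(flattenForStore(Vals).size() == 5);
}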
@@ -5120,19 +5074,12 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
  case NVPTXISD::StoreParamV2:
    Opcode = NVPTXISD::StoreParamV4;
    break;
- case NVPTXISD::StoreRetval:
-   Opcode = NVPTXISD::StoreRetvalV2;
-   break;
- case NVPTXISD::StoreRetvalV2:
-   Opcode = NVPTXISD::StoreRetvalV4;
-   break;
  case NVPTXISD::StoreV2:
    MemVT = ST->getMemoryVT();
    Opcode = NVPTXISD::StoreV4;
    break;
  case NVPTXISD::StoreV4:
  case NVPTXISD::StoreParamV4:
- case NVPTXISD::StoreRetvalV4:
  case NVPTXISD::StoreV8:
    // PTX doesn't support the next doubling of operands
    return SDValue();
@@ -5201,12 +5148,6 @@ static SDValue PerformStoreParamCombine(SDNode *N,
  return PerformStoreCombineHelper(N, DCI, 3, 1);
}

- static SDValue PerformStoreRetvalCombine(SDNode *N,
-                                          TargetLowering::DAGCombinerInfo &DCI) {
-   // Operands from the 2nd to the last one are the values to be stored
-   return PerformStoreCombineHelper(N, DCI, 2, 0);
- }
-
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
@@ -5840,10 +5781,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4:
    return combineUnpackingMovIntoLoad(N, DCI);
- case NVPTXISD::StoreRetval:
- case NVPTXISD::StoreRetvalV2:
- case NVPTXISD::StoreRetvalV4:
-   return PerformStoreRetvalCombine(N, DCI);
  case NVPTXISD::StoreParam:
  case NVPTXISD::StoreParamV2:
  case NVPTXISD::StoreParamV4: