Skip to content

Commit f03782d

Browse files
authored
[NVPTX] Fixup v2i8 parameter and return lowering (#145585)
This change fixes v2i8 lowering for parameters and returned values. As part of this work, I move the lowering for return values to use generic ISD::STORE nodes, as these are more flexible and have existing legalization handling. Note that calling a function with v2i8 arguments or returns still does not work, but this is left for a subsequent change as this MR is already fairly large. Partially addresses #128853.
1 parent 7613c24 commit f03782d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1138
-1249
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 0 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -151,12 +151,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
151151
if (tryLoadParam(N))
152152
return;
153153
break;
154-
case NVPTXISD::StoreRetval:
155-
case NVPTXISD::StoreRetvalV2:
156-
case NVPTXISD::StoreRetvalV4:
157-
if (tryStoreRetval(N))
158-
return;
159-
break;
160154
case NVPTXISD::StoreParam:
161155
case NVPTXISD::StoreParamV2:
162156
case NVPTXISD::StoreParamV4:
@@ -1504,84 +1498,6 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
15041498
return true;
15051499
}
15061500

1507-
bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
1508-
SDLoc DL(N);
1509-
SDValue Chain = N->getOperand(0);
1510-
SDValue Offset = N->getOperand(1);
1511-
unsigned OffsetVal = Offset->getAsZExtVal();
1512-
MemSDNode *Mem = cast<MemSDNode>(N);
1513-
1514-
// How many elements do we have?
1515-
unsigned NumElts = 1;
1516-
switch (N->getOpcode()) {
1517-
default:
1518-
return false;
1519-
case NVPTXISD::StoreRetval:
1520-
NumElts = 1;
1521-
break;
1522-
case NVPTXISD::StoreRetvalV2:
1523-
NumElts = 2;
1524-
break;
1525-
case NVPTXISD::StoreRetvalV4:
1526-
NumElts = 4;
1527-
break;
1528-
}
1529-
1530-
// Build vector of operands
1531-
SmallVector<SDValue, 6> Ops;
1532-
for (unsigned i = 0; i < NumElts; ++i)
1533-
Ops.push_back(N->getOperand(i + 2));
1534-
Ops.append({CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain});
1535-
1536-
// Determine target opcode
1537-
// If we have an i1, use an 8-bit store. The lowering code in
1538-
// NVPTXISelLowering will have already emitted an upcast.
1539-
std::optional<unsigned> Opcode = 0;
1540-
switch (NumElts) {
1541-
default:
1542-
return false;
1543-
case 1:
1544-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1545-
NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16,
1546-
NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64);
1547-
if (Opcode == NVPTX::StoreRetvalI8) {
1548-
// Fine tune the opcode depending on the size of the operand.
1549-
// This helps to avoid creating redundant COPY instructions in
1550-
// InstrEmitter::AddRegisterOperand().
1551-
switch (Ops[0].getSimpleValueType().SimpleTy) {
1552-
default:
1553-
break;
1554-
case MVT::i32:
1555-
Opcode = NVPTX::StoreRetvalI8TruncI32;
1556-
break;
1557-
case MVT::i64:
1558-
Opcode = NVPTX::StoreRetvalI8TruncI64;
1559-
break;
1560-
}
1561-
}
1562-
break;
1563-
case 2:
1564-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1565-
NVPTX::StoreRetvalV2I8, NVPTX::StoreRetvalV2I16,
1566-
NVPTX::StoreRetvalV2I32, NVPTX::StoreRetvalV2I64);
1567-
break;
1568-
case 4:
1569-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1570-
NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
1571-
NVPTX::StoreRetvalV4I32, {/* no v4i64 */});
1572-
break;
1573-
}
1574-
if (!Opcode)
1575-
return false;
1576-
1577-
SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
1578-
MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1579-
CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
1580-
1581-
ReplaceNode(N, Ret);
1582-
return true;
1583-
}
1584-
15851501
// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
15861502
#define getOpcV2H(ty, opKind0, opKind1) \
15871503
NVPTX::StoreParamV2##ty##_##opKind0##opKind1

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7979
bool tryStore(SDNode *N);
8080
bool tryStoreVector(SDNode *N);
8181
bool tryLoadParam(SDNode *N);
82-
bool tryStoreRetval(SDNode *N);
8382
bool tryStoreParam(SDNode *N);
8483
bool tryFence(SDNode *N);
8584
void SelectAddrSpaceCast(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 57 additions & 120 deletions
Original file line number | Diff line number | Diff line change
@@ -370,7 +370,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
370370
} else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
371371
// v2i8 is promoted to v2i16
372372
NumElts = 1;
373-
EltVT = MVT::v2i16;
373+
EltVT = MVT::v2i8;
374374
}
375375
for (unsigned j = 0; j != NumElts; ++j) {
376376
ValueVTs.push_back(EltVT);
@@ -1065,9 +1065,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10651065
MAKE_CASE(NVPTXISD::StoreParamV2)
10661066
MAKE_CASE(NVPTXISD::StoreParamV4)
10671067
MAKE_CASE(NVPTXISD::MoveParam)
1068-
MAKE_CASE(NVPTXISD::StoreRetval)
1069-
MAKE_CASE(NVPTXISD::StoreRetvalV2)
1070-
MAKE_CASE(NVPTXISD::StoreRetvalV4)
10711068
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
10721069
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
10731070
MAKE_CASE(NVPTXISD::CallPrototype)
@@ -1438,7 +1435,11 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
14381435
}
14391436

14401437
static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1441-
return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1438+
if (Flags.isSExt())
1439+
return ISD::SIGN_EXTEND;
1440+
if (Flags.isZExt())
1441+
return ISD::ZERO_EXTEND;
1442+
return ISD::ANY_EXTEND;
14421443
}
14431444

14441445
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
@@ -3373,10 +3374,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33733374
}
33743375
InVals.push_back(P);
33753376
} else {
3376-
bool aggregateIsPacked = false;
3377-
if (StructType *STy = dyn_cast<StructType>(Ty))
3378-
aggregateIsPacked = STy->isPacked();
3379-
33803377
SmallVector<EVT, 16> VTs;
33813378
SmallVector<uint64_t, 16> Offsets;
33823379
ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
@@ -3389,9 +3386,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33893386
const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
33903387
unsigned I = 0;
33913388
for (const unsigned NumElts : VectorInfo) {
3392-
const EVT EltVT = VTs[I];
33933389
// i1 is loaded/stored as i8
3394-
const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
3390+
const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
33953391
// If the element is a packed type (ex. v2f16, v4i8, etc) holding
33963392
// multiple elements.
33973393
const unsigned PackingAmt =
@@ -3403,14 +3399,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34033399
SDValue VecAddr = DAG.getObjectPtrOffset(
34043400
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
34053401

3406-
const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3407-
if (aggregateIsPacked)
3408-
return Align(1);
3409-
if (NumElts != 1)
3410-
return std::nullopt;
3411-
Align PartAlign = DAG.getEVTAlign(EltVT);
3412-
return commonAlignment(PartAlign, Offsets[I]);
3413-
}();
3402+
const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
34143403
SDValue P =
34153404
DAG.getLoad(VecVT, dl, Root, VecAddr,
34163405
MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
@@ -3419,23 +3408,22 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34193408
if (P.getNode())
34203409
P.getNode()->setIROrder(Arg.getArgNo() + 1);
34213410
for (const unsigned J : llvm::seq(NumElts)) {
3422-
SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
3423-
: ISD::EXTRACT_VECTOR_ELT,
3424-
dl, LoadVT, P,
3425-
DAG.getIntPtrConstant(J * PackingAmt, dl));
3411+
SDValue Elt = DAG.getNode(
3412+
LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
3413+
: ISD::EXTRACT_VECTOR_ELT,
3414+
dl, LoadVT, P, DAG.getVectorIdxConstant(J * PackingAmt, dl));
34263415

34273416
// Extend or truncate the element if necessary (e.g. an i8 is loaded
34283417
// into an i16 register)
3429-
const EVT ExpactedVT = ArgIns[I + J].VT;
3430-
assert((Elt.getValueType().bitsEq(ExpactedVT) ||
3431-
(ExpactedVT.isScalarInteger() &&
3432-
Elt.getValueType().isScalarInteger())) &&
3418+
const EVT ExpectedVT = ArgIns[I + J].VT;
3419+
assert((Elt.getValueType() == ExpectedVT ||
3420+
(ExpectedVT.isInteger() && Elt.getValueType().isInteger())) &&
34333421
"Non-integer argument type size mismatch");
3434-
if (ExpactedVT.bitsGT(Elt.getValueType()))
3435-
Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
3422+
if (ExpectedVT.bitsGT(Elt.getValueType()))
3423+
Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpectedVT,
34363424
Elt);
3437-
else if (ExpactedVT.bitsLT(Elt.getValueType()))
3438-
Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
3425+
else if (ExpectedVT.bitsLT(Elt.getValueType()))
3426+
Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Elt);
34393427
InVals.push_back(Elt);
34403428
}
34413429
I += NumElts;
@@ -3449,33 +3437,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34493437
return Chain;
34503438
}
34513439

3452-
// Use byte-store when the param address of the return value is unaligned.
3453-
// This may happen when the return value is a field of a packed structure.
3454-
static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
3455-
uint64_t Offset, EVT ElementType,
3456-
SDValue RetVal, const SDLoc &dl) {
3457-
// Bit logic only works on integer types
3458-
if (adjustElementType(ElementType))
3459-
RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3460-
3461-
// Store each byte
3462-
for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3463-
// Shift the byte to the last byte position
3464-
SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3465-
DAG.getConstant(i * 8, dl, MVT::i32));
3466-
SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3467-
ShiftVal};
3468-
// Trunc store only the last byte by using
3469-
// st.param.b8
3470-
// The register type can be larger than b8.
3471-
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
3472-
DAG.getVTList(MVT::Other), StoreOperands,
3473-
MVT::i8, MachinePointerInfo(), std::nullopt,
3474-
MachineMemOperand::MOStore);
3475-
}
3476-
return Chain;
3477-
}
3478-
34793440
SDValue
34803441
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
34813442
bool isVarArg,
@@ -3497,10 +3458,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
34973458
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
34983459
assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
34993460

3500-
for (const unsigned I : llvm::seq(VTs.size()))
3501-
if (const auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
3502-
VTs[I] = *PromotedVT;
3503-
35043461
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
35053462
// 32-bits are sign extended or zero extended, depending on whether
35063463
// they are signed or unsigned types.
@@ -3512,12 +3469,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
35123469
assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
35133470
"OutVal type should always be legal");
35143471

3515-
if (ExtendIntegerRetVal) {
3516-
RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, MVT::i32, RetVal);
3517-
} else if (RetVal.getValueSizeInBits() < 16) {
3518-
// Use 16-bit registers for small load-stores as it's the
3519-
// smallest general purpose register size supported by NVPTX.
3520-
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3472+
EVT VTI = VTs[I];
3473+
if (const auto PromotedVT = PromoteScalarIntegerPTX(VTI))
3474+
VTI = *PromotedVT;
3475+
3476+
const EVT StoreVT =
3477+
ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
3478+
3479+
assert((RetVal.getValueType() == StoreVT ||
3480+
(StoreVT.isInteger() && RetVal.getValueType().isInteger())) &&
3481+
"Non-integer argument type size mismatch");
3482+
if (StoreVT.bitsGT(RetVal.getValueType())) {
3483+
RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, StoreVT, RetVal);
3484+
} else if (StoreVT.bitsLT(RetVal.getValueType())) {
3485+
RetVal = DAG.getNode(ISD::TRUNCATE, dl, StoreVT, RetVal);
35213486
}
35223487
return RetVal;
35233488
};
@@ -3526,45 +3491,34 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
35263491
const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
35273492
unsigned I = 0;
35283493
for (const unsigned NumElts : VectorInfo) {
3529-
const Align CurrentAlign = commonAlignment(RetAlign, Offsets[I]);
3530-
if (NumElts == 1 && RetTy->isAggregateType() &&
3531-
CurrentAlign < DAG.getEVTAlign(VTs[I])) {
3532-
Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], VTs[I],
3533-
GetRetVal(I), dl);
3534-
3535-
// The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3536-
// into the graph, so just move on to the next element.
3537-
I++;
3538-
continue;
3539-
}
3494+
const MaybeAlign CurrentAlign = ExtendIntegerRetVal
3495+
? MaybeAlign(std::nullopt)
3496+
: commonAlignment(RetAlign, Offsets[I]);
35403497

3541-
SmallVector<SDValue, 6> StoreOperands{
3542-
Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)};
3543-
3544-
for (const unsigned J : llvm::seq(NumElts))
3545-
StoreOperands.push_back(GetRetVal(I + J));
3498+
SDValue Val;
3499+
if (NumElts == 1) {
3500+
Val = GetRetVal(I);
3501+
} else {
3502+
SmallVector<SDValue, 6> StoreVals;
3503+
for (const unsigned J : llvm::seq(NumElts)) {
3504+
SDValue ValJ = GetRetVal(I + J);
3505+
if (ValJ.getValueType().isVector())
3506+
DAG.ExtractVectorElements(ValJ, StoreVals);
3507+
else
3508+
StoreVals.push_back(ValJ);
3509+
}
35463510

3547-
NVPTXISD::NodeType Op;
3548-
switch (NumElts) {
3549-
case 1:
3550-
Op = NVPTXISD::StoreRetval;
3551-
break;
3552-
case 2:
3553-
Op = NVPTXISD::StoreRetvalV2;
3554-
break;
3555-
case 4:
3556-
Op = NVPTXISD::StoreRetvalV4;
3557-
break;
3558-
default:
3559-
llvm_unreachable("Invalid vector info.");
3511+
EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
3512+
StoreVals.size());
3513+
Val = DAG.getBuildVector(VT, dl, StoreVals);
35603514
}
35613515

3562-
// Adjust type of load/store op if we've extended the scalar
3563-
// return value.
3564-
EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
3565-
Chain = DAG.getMemIntrinsicNode(
3566-
Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3567-
MachinePointerInfo(), CurrentAlign, MachineMemOperand::MOStore);
3516+
SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
3517+
SDValue Ptr =
3518+
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
3519+
3520+
Chain = DAG.getStore(Chain, dl, Val, Ptr,
3521+
MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
35683522

35693523
I += NumElts;
35703524
}
@@ -5120,19 +5074,12 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
51205074
case NVPTXISD::StoreParamV2:
51215075
Opcode = NVPTXISD::StoreParamV4;
51225076
break;
5123-
case NVPTXISD::StoreRetval:
5124-
Opcode = NVPTXISD::StoreRetvalV2;
5125-
break;
5126-
case NVPTXISD::StoreRetvalV2:
5127-
Opcode = NVPTXISD::StoreRetvalV4;
5128-
break;
51295077
case NVPTXISD::StoreV2:
51305078
MemVT = ST->getMemoryVT();
51315079
Opcode = NVPTXISD::StoreV4;
51325080
break;
51335081
case NVPTXISD::StoreV4:
51345082
case NVPTXISD::StoreParamV4:
5135-
case NVPTXISD::StoreRetvalV4:
51365083
case NVPTXISD::StoreV8:
51375084
// PTX doesn't support the next doubling of operands
51385085
return SDValue();
@@ -5201,12 +5148,6 @@ static SDValue PerformStoreParamCombine(SDNode *N,
52015148
return PerformStoreCombineHelper(N, DCI, 3, 1);
52025149
}
52035150

5204-
static SDValue PerformStoreRetvalCombine(SDNode *N,
5205-
TargetLowering::DAGCombinerInfo &DCI) {
5206-
// Operands from the 2nd to the last one are the values to be stored
5207-
return PerformStoreCombineHelper(N, DCI, 2, 0);
5208-
}
5209-
52105151
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
52115152
///
52125153
static SDValue PerformADDCombine(SDNode *N,
@@ -5840,10 +5781,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58405781
case NVPTXISD::LoadV2:
58415782
case NVPTXISD::LoadV4:
58425783
return combineUnpackingMovIntoLoad(N, DCI);
5843-
case NVPTXISD::StoreRetval:
5844-
case NVPTXISD::StoreRetvalV2:
5845-
case NVPTXISD::StoreRetvalV4:
5846-
return PerformStoreRetvalCombine(N, DCI);
58475784
case NVPTXISD::StoreParam:
58485785
case NVPTXISD::StoreParamV2:
58495786
case NVPTXISD::StoreParamV4:

0 commit comments

Comments
 (0)