Skip to content

Commit d2f0b3d

Browse files
committed
[SelectionDAG] Properly handle legalization of vector types for [US]CMP nodes
1 parent 81dfa54 commit d2f0b3d

File tree

9 files changed

+2425
-133
lines changed

9 files changed

+2425
-133
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
784784
void ScalarizeVectorResult(SDNode *N, unsigned ResNo);
785785
SDValue ScalarizeVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo);
786786
SDValue ScalarizeVecRes_BinOp(SDNode *N);
787+
SDValue ScalarizeVecRes_CMP(SDNode *N);
787788
SDValue ScalarizeVecRes_TernaryOp(SDNode *N);
788789
SDValue ScalarizeVecRes_UnaryOp(SDNode *N);
789790
SDValue ScalarizeVecRes_StrictFPOp(SDNode *N);
@@ -827,6 +828,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
827828
SDValue ScalarizeVecOp_STRICT_FP_EXTEND(SDNode *N);
828829
SDValue ScalarizeVecOp_VECREDUCE(SDNode *N);
829830
SDValue ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N);
831+
SDValue ScalarizeVecOp_CMP(SDNode *N);
830832

831833
//===--------------------------------------------------------------------===//
832834
// Vector Splitting Support: LegalizeVectorTypes.cpp
@@ -857,6 +859,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
857859
void SplitVectorResult(SDNode *N, unsigned ResNo);
858860
void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi);
859861
void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
862+
void SplitVecRes_CMP(SDNode *N, SDValue &Lo, SDValue &Hi);
860863
void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
861864
void SplitVecRes_FFREXP(SDNode *N, unsigned ResNo, SDValue &Lo, SDValue &Hi);
862865
void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -920,6 +923,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
920923
SDValue SplitVecOp_VSETCC(SDNode *N);
921924
SDValue SplitVecOp_FP_ROUND(SDNode *N);
922925
SDValue SplitVecOp_FPOpDifferentTypes(SDNode *N);
926+
SDValue SplitVecOp_CMP(SDNode *N);
923927
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
924928
SDValue SplitVecOp_VP_CttzElements(SDNode *N);
925929

@@ -987,6 +991,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
987991

988992
SDValue WidenVecRes_Ternary(SDNode *N);
989993
SDValue WidenVecRes_Binary(SDNode *N);
994+
SDValue WidenVecRes_CMP(SDNode *N);
990995
SDValue WidenVecRes_BinaryCanTrap(SDNode *N);
991996
SDValue WidenVecRes_BinaryWithExtraScalarOp(SDNode *N);
992997
SDValue WidenVecRes_StrictFP(SDNode *N);
@@ -1006,6 +1011,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
10061011
SDValue WidenVecOp_BITCAST(SDNode *N);
10071012
SDValue WidenVecOp_CONCAT_VECTORS(SDNode *N);
10081013
SDValue WidenVecOp_EXTEND(SDNode *N);
1014+
SDValue WidenVecOp_CMP(SDNode *N);
10091015
SDValue WidenVecOp_EXTRACT_VECTOR_ELT(SDNode *N);
10101016
SDValue WidenVecOp_INSERT_SUBVECTOR(SDNode *N);
10111017
SDValue WidenVecOp_EXTRACT_SUBVECTOR(SDNode *N);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
442442
case ISD::FP_TO_SINT_SAT:
443443
case ISD::FP_TO_UINT_SAT:
444444
case ISD::MGATHER:
445+
case ISD::SCMP:
446+
case ISD::UCMP:
445447
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
446448
break;
447449
case ISD::SMULFIX:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,12 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
164164
case ISD::ROTR:
165165
R = ScalarizeVecRes_BinOp(N);
166166
break;
167+
168+
case ISD::SCMP:
169+
case ISD::UCMP:
170+
R = ScalarizeVecRes_CMP(N);
171+
break;
172+
167173
case ISD::FMA:
168174
case ISD::FSHL:
169175
case ISD::FSHR:
@@ -213,6 +219,27 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_BinOp(SDNode *N) {
213219
LHS.getValueType(), LHS, RHS, N->getFlags());
214220
}
215221

222+
SDValue DAGTypeLegalizer::ScalarizeVecRes_CMP(SDNode *N) {
223+
SDLoc DL(N);
224+
225+
SDValue LHS = N->getOperand(0);
226+
SDValue RHS = N->getOperand(1);
227+
if (getTypeAction(LHS.getValueType()) ==
228+
TargetLowering::TypeScalarizeVector) {
229+
LHS = GetScalarizedVector(LHS);
230+
RHS = GetScalarizedVector(RHS);
231+
} else {
232+
EVT VT = LHS.getValueType().getVectorElementType();
233+
LHS = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, LHS,
234+
DAG.getVectorIdxConstant(0, DL));
235+
RHS = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, RHS,
236+
DAG.getVectorIdxConstant(0, DL));
237+
}
238+
239+
return DAG.getNode(N->getOpcode(), SDLoc(N),
240+
N->getValueType(0).getVectorElementType(), LHS, RHS);
241+
}
242+
216243
SDValue DAGTypeLegalizer::ScalarizeVecRes_TernaryOp(SDNode *N) {
217244
SDValue Op0 = GetScalarizedVector(N->getOperand(0));
218245
SDValue Op1 = GetScalarizedVector(N->getOperand(1));
@@ -741,6 +768,10 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
741768
case ISD::VECREDUCE_SEQ_FMUL:
742769
Res = ScalarizeVecOp_VECREDUCE_SEQ(N);
743770
break;
771+
case ISD::SCMP:
772+
case ISD::UCMP:
773+
Res = ScalarizeVecOp_CMP(N);
774+
break;
744775
}
745776

746777
// If the result is null, the sub-method took care of registering results etc.
@@ -961,6 +992,12 @@ SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N) {
961992
AccOp, Op, N->getFlags());
962993
}
963994

995+
SDValue DAGTypeLegalizer::ScalarizeVecOp_CMP(SDNode *N) {
996+
SDValue LHS = GetScalarizedVector(N->getOperand(0));
997+
SDValue RHS = GetScalarizedVector(N->getOperand(1));
998+
return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS);
999+
}
1000+
9641001
//===----------------------------------------------------------------------===//
9651002
// Result Vector Splitting
9661003
//===----------------------------------------------------------------------===//
@@ -1184,6 +1221,10 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
11841221
SplitVecRes_TernaryOp(N, Lo, Hi);
11851222
break;
11861223

1224+
case ISD::SCMP: case ISD::UCMP:
1225+
SplitVecRes_CMP(N, Lo, Hi);
1226+
break;
1227+
11871228
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
11881229
case ISD::STRICT_##DAGN:
11891230
#include "llvm/IR/ConstrainedOps.def"
@@ -1327,6 +1368,27 @@ void DAGTypeLegalizer::SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo,
13271368
{Op0Hi, Op1Hi, Op2Hi, MaskHi, EVLHi}, Flags);
13281369
}
13291370

1371+
void DAGTypeLegalizer::SplitVecRes_CMP(SDNode *N, SDValue &Lo, SDValue &Hi) {
1372+
LLVMContext &Ctxt = *DAG.getContext();
1373+
SDLoc dl(N);
1374+
1375+
SDValue LHS = N->getOperand(0);
1376+
SDValue RHS = N->getOperand(1);
1377+
1378+
SDValue LHSLo, LHSHi, RHSLo, RHSHi;
1379+
if (getTypeAction(LHS.getValueType()) == TargetLowering::TypeSplitVector) {
1380+
GetSplitVector(LHS, LHSLo, LHSHi);
1381+
GetSplitVector(RHS, RHSLo, RHSHi);
1382+
} else {
1383+
std::tie(LHSLo, LHSHi) = DAG.SplitVector(LHS, dl);
1384+
std::tie(RHSLo, RHSHi) = DAG.SplitVector(RHS, dl);
1385+
}
1386+
1387+
EVT SplitResVT = N->getValueType(0).getHalfNumVectorElementsVT(Ctxt);
1388+
Lo = DAG.getNode(N->getOpcode(), dl, SplitResVT, LHSLo, RHSLo);
1389+
Hi = DAG.getNode(N->getOpcode(), dl, SplitResVT, LHSHi, RHSHi);
1390+
}
1391+
13301392
void DAGTypeLegalizer::SplitVecRes_FIX(SDNode *N, SDValue &Lo, SDValue &Hi) {
13311393
SDValue LHSLo, LHSHi;
13321394
GetSplitVector(N->getOperand(0), LHSLo, LHSHi);
@@ -3054,6 +3116,11 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
30543116
Res = SplitVecOp_FPOpDifferentTypes(N);
30553117
break;
30563118

3119+
case ISD::SCMP:
3120+
case ISD::UCMP:
3121+
Res = SplitVecOp_CMP(N);
3122+
break;
3123+
30573124
case ISD::ANY_EXTEND_VECTOR_INREG:
30583125
case ISD::SIGN_EXTEND_VECTOR_INREG:
30593126
case ISD::ZERO_EXTEND_VECTOR_INREG:
@@ -4043,6 +4110,25 @@ SDValue DAGTypeLegalizer::SplitVecOp_FPOpDifferentTypes(SDNode *N) {
40434110
return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Lo, Hi);
40444111
}
40454112

4113+
SDValue DAGTypeLegalizer::SplitVecOp_CMP(SDNode *N) {
4114+
LLVMContext &Ctxt = *DAG.getContext();
4115+
SDLoc dl(N);
4116+
4117+
SDValue LHSLo, LHSHi, RHSLo, RHSHi;
4118+
GetSplitVector(N->getOperand(0), LHSLo, LHSHi);
4119+
GetSplitVector(N->getOperand(1), RHSLo, RHSHi);
4120+
4121+
EVT ResVT = N->getValueType(0);
4122+
ElementCount SplitOpEC = LHSLo.getValueType().getVectorElementCount();
4123+
EVT NewResVT =
4124+
EVT::getVectorVT(Ctxt, ResVT.getVectorElementType(), SplitOpEC);
4125+
4126+
SDValue Lo = DAG.getNode(N->getOpcode(), dl, NewResVT, LHSLo, RHSLo);
4127+
SDValue Hi = DAG.getNode(N->getOpcode(), dl, NewResVT, LHSHi, RHSHi);
4128+
4129+
return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
4130+
}
4131+
40464132
SDValue DAGTypeLegalizer::SplitVecOp_FP_TO_XINT_SAT(SDNode *N) {
40474133
EVT ResVT = N->getValueType(0);
40484134
SDValue Lo, Hi;
@@ -4220,6 +4306,11 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
42204306
Res = WidenVecRes_Binary(N);
42214307
break;
42224308

4309+
case ISD::SCMP:
4310+
case ISD::UCMP:
4311+
Res = WidenVecRes_CMP(N);
4312+
break;
4313+
42234314
case ISD::FPOW:
42244315
case ISD::FREM:
42254316
if (unrollExpandedOp())
@@ -4426,6 +4517,53 @@ SDValue DAGTypeLegalizer::WidenVecRes_Binary(SDNode *N) {
44264517
{InOp1, InOp2, Mask, N->getOperand(3)}, N->getFlags());
44274518
}
44284519

4520+
SDValue DAGTypeLegalizer::WidenVecRes_CMP(SDNode *N) {
4521+
LLVMContext &Ctxt = *DAG.getContext();
4522+
SDLoc dl(N);
4523+
4524+
SDValue LHS = N->getOperand(0);
4525+
SDValue RHS = N->getOperand(1);
4526+
EVT OpVT = LHS.getValueType();
4527+
if (getTypeAction(OpVT) == TargetLowering::TypeWidenVector) {
4528+
LHS = GetWidenedVector(LHS);
4529+
RHS = GetWidenedVector(RHS);
4530+
}
4531+
4532+
EVT WidenResVT = TLI.getTypeToTransformTo(Ctxt, N->getValueType(0));
4533+
ElementCount WidenResEC = WidenResVT.getVectorElementCount();
4534+
EVT WidenResElementVT = WidenResVT.getVectorElementType();
4535+
4536+
// At this point we know that the type of LHS and RHS will not require
4537+
// widening any further, so we can use the current (updated) type of the
4538+
// operands as the return type of the CMP node, and then extend/truncate
4539+
// and resize it appropriately.
4540+
EVT CmpRetTy = LHS.getValueType();
4541+
SDValue CMP = DAG.getNode(N->getOpcode(), dl, CmpRetTy, LHS, RHS);
4542+
if (CmpRetTy.getVectorNumElements() < WidenResVT.getVectorNumElements()) {
4543+
EVT WideUndefVectorVT =
4544+
EVT::getVectorVT(Ctxt, CmpRetTy.getVectorElementType(), WidenResEC);
4545+
SDValue WideUndefValue = DAG.getUNDEF(WideUndefVectorVT);
4546+
CMP = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideUndefVectorVT,
4547+
WideUndefValue, CMP, DAG.getVectorIdxConstant(0, dl));
4548+
} else if (CmpRetTy.getVectorNumElements() >
4549+
WidenResVT.getVectorNumElements()) {
4550+
EVT NarrowedVecVT =
4551+
EVT::getVectorVT(Ctxt, CmpRetTy.getVectorElementType(), WidenResEC);
4552+
CMP = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NarrowedVecVT, CMP,
4553+
DAG.getVectorIdxConstant(0, dl));
4554+
}
4555+
4556+
ISD::NodeType ExtendCode;
4557+
if (CMP.getValueType().getVectorElementType().getSizeInBits() >
4558+
WidenResElementVT.getSizeInBits()) {
4559+
ExtendCode = ISD::TRUNCATE;
4560+
} else {
4561+
ExtendCode =
4562+
(N->getOpcode() == ISD::SCMP ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND);
4563+
}
4564+
return DAG.getNode(ExtendCode, dl, WidenResVT, CMP);
4565+
}
4566+
44294567
SDValue DAGTypeLegalizer::WidenVecRes_BinaryWithExtraScalarOp(SDNode *N) {
44304568
// Binary op widening, but with an extra operand that shouldn't be widened.
44314569
SDLoc dl(N);
@@ -6129,6 +6267,11 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
61296267
Res = WidenVecOp_EXTEND(N);
61306268
break;
61316269

6270+
case ISD::SCMP:
6271+
case ISD::UCMP:
6272+
Res = WidenVecOp_CMP(N);
6273+
break;
6274+
61326275
case ISD::FP_EXTEND:
61336276
case ISD::STRICT_FP_EXTEND:
61346277
case ISD::FP_ROUND:
@@ -6273,6 +6416,32 @@ SDValue DAGTypeLegalizer::WidenVecOp_EXTEND(SDNode *N) {
62736416
}
62746417
}
62756418

6419+
SDValue DAGTypeLegalizer::WidenVecOp_CMP(SDNode *N) {
6420+
SDLoc dl(N);
6421+
6422+
EVT OpVT = N->getOperand(0).getValueType();
6423+
EVT ResVT = N->getValueType(0);
6424+
SDValue LHS = GetWidenedVector(N->getOperand(0));
6425+
SDValue RHS = GetWidenedVector(N->getOperand(1));
6426+
6427+
// 1. EXTRACT_SUBVECTOR
6428+
// 2. SIGN_EXTEND/ZERO_EXTEND
6429+
// 3. CMP
6430+
LHS = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, LHS,
6431+
DAG.getVectorIdxConstant(0, dl));
6432+
RHS = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, RHS,
6433+
DAG.getVectorIdxConstant(0, dl));
6434+
6435+
// At this point the result type is guaranteed to be valid, so we can use it
6436+
// as the operand type by extending it appropriately
6437+
ISD::NodeType ExtendOpcode =
6438+
N->getOpcode() == ISD::SCMP ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
6439+
LHS = DAG.getNode(ExtendOpcode, dl, ResVT, LHS);
6440+
RHS = DAG.getNode(ExtendOpcode, dl, ResVT, RHS);
6441+
6442+
return DAG.getNode(N->getOpcode(), dl, ResVT, LHS, RHS);
6443+
}
6444+
62766445
SDValue DAGTypeLegalizer::WidenVecOp_UnrollVectorOp(SDNode *N) {
62776446
// The result (and first input) is legal, but the second input is illegal.
62786447
// We can't do much to fix that, so just unroll and let the extracts off of

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7146,16 +7146,14 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
71467146
case Intrinsic::scmp: {
71477147
SDValue Op1 = getValue(I.getArgOperand(0));
71487148
SDValue Op2 = getValue(I.getArgOperand(1));
7149-
EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
7150-
I.getType());
7149+
EVT DestVT = TLI.getValueType(DAG.getDataLayout(), I.getType());
71517150
setValue(&I, DAG.getNode(ISD::SCMP, sdl, DestVT, Op1, Op2));
71527151
break;
71537152
}
71547153
case Intrinsic::ucmp: {
71557154
SDValue Op1 = getValue(I.getArgOperand(0));
71567155
SDValue Op2 = getValue(I.getArgOperand(1));
7157-
EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
7158-
I.getType());
7156+
EVT DestVT = TLI.getValueType(DAG.getDataLayout(), I.getType());
71597157
setValue(&I, DAG.getNode(ISD::UCMP, sdl, DestVT, Op1, Op2));
71607158
break;
71617159
}

0 commit comments

Comments
 (0)