@@ -987,6 +987,8 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
987
987
setTargetDAGCombine(ISD::SMAX);
988
988
setTargetDAGCombine(ISD::UMAX);
989
989
setTargetDAGCombine(ISD::FP_EXTEND);
990
+ setTargetDAGCombine(ISD::SELECT);
991
+ setTargetDAGCombine(ISD::SELECT_CC);
990
992
}
991
993
992
994
if (!Subtarget->hasFP64()) {
@@ -1740,6 +1742,10 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
1740
1742
case ARMISD::VMLALVAu: return "ARMISD::VMLALVAu";
1741
1743
case ARMISD::VMLALVAps: return "ARMISD::VMLALVAps";
1742
1744
case ARMISD::VMLALVApu: return "ARMISD::VMLALVApu";
1745
+ case ARMISD::VMINVu: return "ARMISD::VMINVu";
1746
+ case ARMISD::VMINVs: return "ARMISD::VMINVs";
1747
+ case ARMISD::VMAXVu: return "ARMISD::VMAXVu";
1748
+ case ARMISD::VMAXVs: return "ARMISD::VMAXVs";
1743
1749
case ARMISD::UMAAL: return "ARMISD::UMAAL";
1744
1750
case ARMISD::UMLAL: return "ARMISD::UMLAL";
1745
1751
case ARMISD::SMLAL: return "ARMISD::SMLAL";
@@ -12093,6 +12099,111 @@ static SDValue PerformAddeSubeCombine(SDNode *N,
12093
12099
return SDValue();
12094
12100
}
12095
12101
12102
+ static SDValue PerformSELECTCombine(SDNode *N,
12103
+ TargetLowering::DAGCombinerInfo &DCI,
12104
+ const ARMSubtarget *Subtarget) {
12105
+ if (!Subtarget->hasMVEIntegerOps())
12106
+ return SDValue();
12107
+
12108
+ SDLoc dl(N);
12109
+ SDValue SetCC;
12110
+ SDValue LHS;
12111
+ SDValue RHS;
12112
+ ISD::CondCode CC;
12113
+ SDValue TrueVal;
12114
+ SDValue FalseVal;
12115
+
12116
+ if (N->getOpcode() == ISD::SELECT &&
12117
+ N->getOperand(0)->getOpcode() == ISD::SETCC) {
12118
+ SetCC = N->getOperand(0);
12119
+ LHS = SetCC->getOperand(0);
12120
+ RHS = SetCC->getOperand(1);
12121
+ CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
12122
+ TrueVal = N->getOperand(1);
12123
+ FalseVal = N->getOperand(2);
12124
+ } else if (N->getOpcode() == ISD::SELECT_CC) {
12125
+ LHS = N->getOperand(0);
12126
+ RHS = N->getOperand(1);
12127
+ CC = cast<CondCodeSDNode>(N->getOperand(4))->get();
12128
+ TrueVal = N->getOperand(2);
12129
+ FalseVal = N->getOperand(3);
12130
+ } else {
12131
+ return SDValue();
12132
+ }
12133
+
12134
+ unsigned int Opcode = 0;
12135
+ if ((TrueVal->getOpcode() == ISD::VECREDUCE_UMIN ||
12136
+ FalseVal->getOpcode() == ISD::VECREDUCE_UMIN) &&
12137
+ (CC == ISD::SETULT || CC == ISD::SETUGT)) {
12138
+ Opcode = ARMISD::VMINVu;
12139
+ if (CC == ISD::SETUGT)
12140
+ std::swap(TrueVal, FalseVal);
12141
+ } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_SMIN ||
12142
+ FalseVal->getOpcode() == ISD::VECREDUCE_SMIN) &&
12143
+ (CC == ISD::SETLT || CC == ISD::SETGT)) {
12144
+ Opcode = ARMISD::VMINVs;
12145
+ if (CC == ISD::SETGT)
12146
+ std::swap(TrueVal, FalseVal);
12147
+ } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_UMAX ||
12148
+ FalseVal->getOpcode() == ISD::VECREDUCE_UMAX) &&
12149
+ (CC == ISD::SETUGT || CC == ISD::SETULT)) {
12150
+ Opcode = ARMISD::VMAXVu;
12151
+ if (CC == ISD::SETULT)
12152
+ std::swap(TrueVal, FalseVal);
12153
+ } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_SMAX ||
12154
+ FalseVal->getOpcode() == ISD::VECREDUCE_SMAX) &&
12155
+ (CC == ISD::SETGT || CC == ISD::SETLT)) {
12156
+ Opcode = ARMISD::VMAXVs;
12157
+ if (CC == ISD::SETLT)
12158
+ std::swap(TrueVal, FalseVal);
12159
+ } else
12160
+ return SDValue();
12161
+
12162
+ // Normalise to the right hand side being the vector reduction
12163
+ switch (TrueVal->getOpcode()) {
12164
+ case ISD::VECREDUCE_UMIN:
12165
+ case ISD::VECREDUCE_SMIN:
12166
+ case ISD::VECREDUCE_UMAX:
12167
+ case ISD::VECREDUCE_SMAX:
12168
+ std::swap(LHS, RHS);
12169
+ std::swap(TrueVal, FalseVal);
12170
+ break;
12171
+ }
12172
+
12173
+ EVT VectorType = FalseVal->getOperand(0).getValueType();
12174
+
12175
+ if (VectorType != MVT::v16i8 && VectorType != MVT::v8i16 &&
12176
+ VectorType != MVT::v4i32)
12177
+ return SDValue();
12178
+
12179
+ EVT VectorScalarType = VectorType.getVectorElementType();
12180
+
12181
+ // The values being selected must also be the ones being compared
12182
+ if (TrueVal != LHS || FalseVal != RHS)
12183
+ return SDValue();
12184
+
12185
+ EVT LeftType = LHS->getValueType(0);
12186
+ EVT RightType = RHS->getValueType(0);
12187
+
12188
+ // The types must match the reduced type too
12189
+ if (LeftType != VectorScalarType || RightType != VectorScalarType)
12190
+ return SDValue();
12191
+
12192
+ // Legalise the scalar to an i32
12193
+ if (VectorScalarType != MVT::i32)
12194
+ LHS = DCI.DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, LHS);
12195
+
12196
+ // Generate the reduction as an i32 for legalisation purposes
12197
+ auto Reduction =
12198
+ DCI.DAG.getNode(Opcode, dl, MVT::i32, LHS, RHS->getOperand(0));
12199
+
12200
+ // The result isn't actually an i32 so truncate it back to its original type
12201
+ if (VectorScalarType != MVT::i32)
12202
+ Reduction = DCI.DAG.getNode(ISD::TRUNCATE, dl, VectorScalarType, Reduction);
12203
+
12204
+ return Reduction;
12205
+ }
12206
+
12096
12207
static SDValue PerformVSELECTCombine(SDNode *N,
12097
12208
TargetLowering::DAGCombinerInfo &DCI,
12098
12209
const ARMSubtarget *Subtarget) {
@@ -16049,6 +16160,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
16049
16160
DAGCombinerInfo &DCI) const {
16050
16161
switch (N->getOpcode()) {
16051
16162
default: break;
16163
+ case ISD::SELECT_CC:
16164
+ case ISD::SELECT: return PerformSELECTCombine(N, DCI, Subtarget);
16052
16165
case ISD::VSELECT: return PerformVSELECTCombine(N, DCI, Subtarget);
16053
16166
case ISD::ABS: return PerformABSCombine(N, DCI, Subtarget);
16054
16167
case ARMISD::ADDE: return PerformADDECombine(N, DCI, Subtarget);
0 commit comments