Skip to content

Commit 68e002e

Browse files
committed
[ARM] Fold select_cc(vecreduce_[u|s][min|max], x) into VMINV or VMAXV
This folds a select_cc or select(set_cc) of a max or min vector reduction with a scalar value into a VMAXV or VMINV. Differential Revision: https://reviews.llvm.org/D87836
1 parent e2452f5 commit 68e002e

File tree

6 files changed

+859
-121
lines changed

6 files changed

+859
-121
lines changed

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,8 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
987987
setTargetDAGCombine(ISD::SMAX);
988988
setTargetDAGCombine(ISD::UMAX);
989989
setTargetDAGCombine(ISD::FP_EXTEND);
990+
setTargetDAGCombine(ISD::SELECT);
991+
setTargetDAGCombine(ISD::SELECT_CC);
990992
}
991993

992994
if (!Subtarget->hasFP64()) {
@@ -1740,6 +1742,10 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
17401742
case ARMISD::VMLALVAu: return "ARMISD::VMLALVAu";
17411743
case ARMISD::VMLALVAps: return "ARMISD::VMLALVAps";
17421744
case ARMISD::VMLALVApu: return "ARMISD::VMLALVApu";
1745+
case ARMISD::VMINVu: return "ARMISD::VMINVu";
1746+
case ARMISD::VMINVs: return "ARMISD::VMINVs";
1747+
case ARMISD::VMAXVu: return "ARMISD::VMAXVu";
1748+
case ARMISD::VMAXVs: return "ARMISD::VMAXVs";
17431749
case ARMISD::UMAAL: return "ARMISD::UMAAL";
17441750
case ARMISD::UMLAL: return "ARMISD::UMLAL";
17451751
case ARMISD::SMLAL: return "ARMISD::SMLAL";
@@ -12093,6 +12099,111 @@ static SDValue PerformAddeSubeCombine(SDNode *N,
1209312099
return SDValue();
1209412100
}
1209512101

12102+
/// Fold a scalar select between a value and an integer min/max vector
/// reduction into a single ARMISD::VMINV[u|s] / ARMISD::VMAXV[u|s] node.
///
/// Both of the following forms are recognised:
///   select(setcc(LHS, RHS, CC), TrueVal, FalseVal)
///   select_cc(LHS, RHS, TrueVal, FalseVal, CC)
/// One of the selected values must be a vecreduce_[u|s][min|max] of a
/// v16i8, v8i16 or v4i32 vector, the condition code must be the strict
/// less-than / greater-than comparison of matching signedness, and the values
/// being selected must be exactly the values being compared.
///
/// \param N         The ISD::SELECT or ISD::SELECT_CC node to combine.
/// \param DCI       Combiner state; its DAG is used to create the new nodes.
/// \param Subtarget Queried for MVE integer support.
/// \returns the replacement value, or an empty SDValue if no fold applies.
static SDValue PerformSELECTCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    const ARMSubtarget *Subtarget) {
  if (!Subtarget->hasMVEIntegerOps())
    return SDValue();

  SDLoc dl(N);
  SDValue LHS;
  SDValue RHS;
  ISD::CondCode CC;
  SDValue TrueVal;
  SDValue FalseVal;

  // Extract the comparison operands, condition code and selected values from
  // whichever of the two supported node shapes we were handed.
  if (N->getOpcode() == ISD::SELECT &&
      N->getOperand(0)->getOpcode() == ISD::SETCC) {
    SDValue SetCC = N->getOperand(0);
    LHS = SetCC->getOperand(0);
    RHS = SetCC->getOperand(1);
    CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
    TrueVal = N->getOperand(1);
    FalseVal = N->getOperand(2);
  } else if (N->getOpcode() == ISD::SELECT_CC) {
    LHS = N->getOperand(0);
    RHS = N->getOperand(1);
    CC = cast<CondCodeSDNode>(N->getOperand(4))->get();
    TrueVal = N->getOperand(2);
    FalseVal = N->getOperand(3);
  } else {
    return SDValue();
  }

  // Pick the target node from the kind of reduction being selected. When the
  // condition code is the "opposite" comparison for that reduction, swap the
  // selected values so that the checks below only have to deal with one
  // orientation. ReduceOpcode remembers which reduction justified the choice.
  unsigned int Opcode = 0;
  unsigned int ReduceOpcode = 0;
  if ((TrueVal->getOpcode() == ISD::VECREDUCE_UMIN ||
       FalseVal->getOpcode() == ISD::VECREDUCE_UMIN) &&
      (CC == ISD::SETULT || CC == ISD::SETUGT)) {
    Opcode = ARMISD::VMINVu;
    ReduceOpcode = ISD::VECREDUCE_UMIN;
    if (CC == ISD::SETUGT)
      std::swap(TrueVal, FalseVal);
  } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_SMIN ||
              FalseVal->getOpcode() == ISD::VECREDUCE_SMIN) &&
             (CC == ISD::SETLT || CC == ISD::SETGT)) {
    Opcode = ARMISD::VMINVs;
    ReduceOpcode = ISD::VECREDUCE_SMIN;
    if (CC == ISD::SETGT)
      std::swap(TrueVal, FalseVal);
  } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_UMAX ||
              FalseVal->getOpcode() == ISD::VECREDUCE_UMAX) &&
             (CC == ISD::SETUGT || CC == ISD::SETULT)) {
    Opcode = ARMISD::VMAXVu;
    ReduceOpcode = ISD::VECREDUCE_UMAX;
    if (CC == ISD::SETULT)
      std::swap(TrueVal, FalseVal);
  } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_SMAX ||
              FalseVal->getOpcode() == ISD::VECREDUCE_SMAX) &&
             (CC == ISD::SETGT || CC == ISD::SETLT)) {
    Opcode = ARMISD::VMAXVs;
    ReduceOpcode = ISD::VECREDUCE_SMAX;
    if (CC == ISD::SETLT)
      std::swap(TrueVal, FalseVal);
  } else {
    return SDValue();
  }

  // Normalise to the right hand side being the vector reduction. Only swap
  // when TrueVal is the reduction that was matched above: swapping on any
  // other kind of reduction would leave a mismatched reduction in FalseVal
  // and its vector operand would then be folded into the wrong min/max node.
  // One of the two values is known to have ReduceOpcode, so after this point
  // FalseVal always does.
  if (TrueVal->getOpcode() == ReduceOpcode) {
    std::swap(LHS, RHS);
    std::swap(TrueVal, FalseVal);
  }

  EVT VectorType = FalseVal->getOperand(0).getValueType();

  if (VectorType != MVT::v16i8 && VectorType != MVT::v8i16 &&
      VectorType != MVT::v4i32)
    return SDValue();

  EVT VectorScalarType = VectorType.getVectorElementType();

  // The values being selected must also be the ones being compared
  if (TrueVal != LHS || FalseVal != RHS)
    return SDValue();

  EVT LeftType = LHS->getValueType(0);
  EVT RightType = RHS->getValueType(0);

  // The types must match the reduced type too
  if (LeftType != VectorScalarType || RightType != VectorScalarType)
    return SDValue();

  // Legalise the scalar to an i32
  if (VectorScalarType != MVT::i32)
    LHS = DCI.DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, LHS);

  // Generate the reduction as an i32 for legalisation purposes
  auto Reduction =
      DCI.DAG.getNode(Opcode, dl, MVT::i32, LHS, RHS->getOperand(0));

  // The result isn't actually an i32 so truncate it back to its original type
  if (VectorScalarType != MVT::i32)
    Reduction = DCI.DAG.getNode(ISD::TRUNCATE, dl, VectorScalarType, Reduction);

  return Reduction;
}
12206+
1209612207
static SDValue PerformVSELECTCombine(SDNode *N,
1209712208
TargetLowering::DAGCombinerInfo &DCI,
1209812209
const ARMSubtarget *Subtarget) {
@@ -16049,6 +16160,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
1604916160
DAGCombinerInfo &DCI) const {
1605016161
switch (N->getOpcode()) {
1605116162
default: break;
16163+
case ISD::SELECT_CC:
16164+
case ISD::SELECT: return PerformSELECTCombine(N, DCI, Subtarget);
1605216165
case ISD::VSELECT: return PerformVSELECTCombine(N, DCI, Subtarget);
1605316166
case ISD::ABS: return PerformABSCombine(N, DCI, Subtarget);
1605416167
case ARMISD::ADDE: return PerformADDECombine(N, DCI, Subtarget);

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ class VectorType;
241241
VMLALVAu, // provided as low and high halves
242242
VMLALVAps, // Same as VMLALVA[su] with a v4i1 predicate mask
243243
VMLALVApu,
244+
VMINVu, // Find minimum unsigned value of a vector and register
245+
VMINVs, // Find minimum signed value of a vector and register
246+
VMAXVu, // Find maximum unsigned value of a vector and register
247+
VMAXVs, // Find maximum signed value of a vector and register
244248

245249
SMULWB, // Signed multiply word by half word, bottom
246250
SMULWT, // Signed multiply word by half word, top

llvm/lib/Target/ARM/ARMInstrMVE.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,14 @@ multiclass MVE_VMINMAXV_ty<string iname, bit isMin, string intrBaseName> {
944944
defm u32: MVE_VMINMAXV_p<iname, 1, isMin, MVE_v4u32, intrBaseName>;
945945
}
946946

947+
// Type profile shared by the scalar-and-vector min/max reduction nodes below:
// one result and two operands, where the result and the first operand are
// integers and the second operand is a vector.
def SDTVecReduceR : SDTypeProfile<1, 2, [ // Reduction of an integer and vector into an integer
948+
SDTCisInt<0>, SDTCisInt<1>, SDTCisVec<2>
949+
]>;
950+
// TableGen handles for the ARMISD::V{MIN,MAX}V{u,s} target nodes, all using
// the SDTVecReduceR profile; the u/s suffix mirrors the ARMISD node name.
def ARMVMINVu : SDNode<"ARMISD::VMINVu", SDTVecReduceR>;
951+
def ARMVMINVs : SDNode<"ARMISD::VMINVs", SDTVecReduceR>;
952+
def ARMVMAXVu : SDNode<"ARMISD::VMAXVu", SDTVecReduceR>;
953+
def ARMVMAXVs : SDNode<"ARMISD::VMAXVs", SDTVecReduceR>;
954+
947955
defm MVE_VMINV : MVE_VMINMAXV_ty<"vminv", 1, "int_arm_mve_minv">;
948956
defm MVE_VMAXV : MVE_VMINMAXV_ty<"vmaxv", 0, "int_arm_mve_maxv">;
949957

@@ -974,6 +982,32 @@ let Predicates = [HasMVEInt] in {
974982
def : Pat<(i32 (vecreduce_umin (v4i32 MQPR:$src))),
975983
(i32 (MVE_VMINVu32 (t2MOVi (i32 4294967295)), $src))>;
976984

985+
// Select the ARMVMINV[u|s] nodes onto the MVE_VMINV instruction matching the
// node's signedness and the vector's element width (v16i8 -> 8, v8i16 -> 16,
// v4i32 -> 32). The i32 scalar $x is passed as the instruction's first
// operand and the vector $src as its second.
def : Pat<(i32 (ARMVMINVu (i32 rGPR:$x), (v16i8 MQPR:$src))),
986+
(i32 (MVE_VMINVu8 $x, $src))>;
987+
def : Pat<(i32 (ARMVMINVu (i32 rGPR:$x), (v8i16 MQPR:$src))),
988+
(i32 (MVE_VMINVu16 $x, $src))>;
989+
def : Pat<(i32 (ARMVMINVu (i32 rGPR:$x), (v4i32 MQPR:$src))),
990+
(i32 (MVE_VMINVu32 $x, $src))>;
991+
def : Pat<(i32 (ARMVMINVs (i32 rGPR:$x), (v16i8 MQPR:$src))),
992+
(i32 (MVE_VMINVs8 $x, $src))>;
993+
def : Pat<(i32 (ARMVMINVs (i32 rGPR:$x), (v8i16 MQPR:$src))),
994+
(i32 (MVE_VMINVs16 $x, $src))>;
995+
def : Pat<(i32 (ARMVMINVs (i32 rGPR:$x), (v4i32 MQPR:$src))),
996+
(i32 (MVE_VMINVs32 $x, $src))>;
997+
998+
// Same mapping for the ARMVMAXV[u|s] nodes onto the MVE_VMAXV instructions.
def : Pat<(i32 (ARMVMAXVu (i32 rGPR:$x), (v16i8 MQPR:$src))),
999+
(i32 (MVE_VMAXVu8 $x, $src))>;
1000+
def : Pat<(i32 (ARMVMAXVu (i32 rGPR:$x), (v8i16 MQPR:$src))),
1001+
(i32 (MVE_VMAXVu16 $x, $src))>;
1002+
def : Pat<(i32 (ARMVMAXVu (i32 rGPR:$x), (v4i32 MQPR:$src))),
1003+
(i32 (MVE_VMAXVu32 $x, $src))>;
1004+
def : Pat<(i32 (ARMVMAXVs (i32 rGPR:$x), (v16i8 MQPR:$src))),
1005+
(i32 (MVE_VMAXVs8 $x, $src))>;
1006+
def : Pat<(i32 (ARMVMAXVs (i32 rGPR:$x), (v8i16 MQPR:$src))),
1007+
(i32 (MVE_VMAXVs16 $x, $src))>;
1008+
def : Pat<(i32 (ARMVMAXVs (i32 rGPR:$x), (v4i32 MQPR:$src))),
1009+
(i32 (MVE_VMAXVs32 $x, $src))>;
1010+
9771011
}
9781012

9791013
multiclass MVE_VMINMAXAV_ty<string iname, bit isMin, string intrBaseName> {

0 commit comments

Comments
 (0)