@@ -16147,17 +16147,80 @@ static bool narrowIndex(SDValue &N, ISD::MemIndexType IndexType, SelectionDAG &D
16147
16147
return true;
16148
16148
}
16149
16149
16150
+ /// Try to map an integer comparison with size > XLEN to vector instructions
16151
+ /// before type legalization splits it up into chunks.
16152
+ static SDValue
16153
+ combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y, ISD::CondCode CC,
16154
+ const SDLoc &DL, SelectionDAG &DAG,
16155
+ const RISCVSubtarget &Subtarget) {
16156
+ assert(ISD::isIntEqualitySetCC(CC) && "Bad comparison predicate");
16157
+
16158
+ if (!Subtarget.hasVInstructions())
16159
+ return SDValue();
16160
+
16161
+ MVT XLenVT = Subtarget.getXLenVT();
16162
+ EVT OpVT = X.getValueType();
16163
+ // We're looking for an oversized integer equality comparison.
16164
+ if (!OpVT.isScalarInteger())
16165
+ return SDValue();
16166
+
16167
+ unsigned OpSize = OpVT.getSizeInBits();
16168
+ // TODO: Support non-power-of-2 types.
16169
+ if (!isPowerOf2_32(OpSize))
16170
+ return SDValue();
16171
+
16172
+ // The size should be larger than XLen and smaller than the maximum vector
16173
+ // size.
16174
+ if (OpSize <= Subtarget.getXLen() ||
16175
+ OpSize > Subtarget.getRealMinVLen() *
16176
+ Subtarget.getMaxLMULForFixedLengthVectors())
16177
+ return SDValue();
16178
+
16179
+ // Don't perform this combine if constructing the vector will be expensive.
16180
+ auto IsVectorBitCastCheap = [](SDValue X) {
16181
+ X = peekThroughBitcasts(X);
16182
+ return isa<ConstantSDNode>(X) || X.getValueType().isVector() ||
16183
+ X.getOpcode() == ISD::LOAD;
16184
+ };
16185
+ if (!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y))
16186
+ return SDValue();
16187
+
16188
+ if (DAG.getMachineFunction().getFunction().hasFnAttribute(
16189
+ Attribute::NoImplicitFloat))
16190
+ return SDValue();
16191
+
16192
+ unsigned VecSize = OpSize / 8;
16193
+ EVT VecVT = MVT::getVectorVT(MVT::i8, VecSize);
16194
+ EVT CmpVT = MVT::getVectorVT(MVT::i1, VecSize);
16195
+
16196
+ SDValue VecX = DAG.getBitcast(VecVT, X);
16197
+ SDValue VecY = DAG.getBitcast(VecVT, Y);
16198
+ SDValue Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
16199
+ return DAG.getSetCC(DL, VT, DAG.getNode(ISD::VECREDUCE_OR, DL, XLenVT, Cmp),
16200
+ DAG.getConstant(0, DL, XLenVT), CC);
16201
+ }
16202
+
16150
16203
// Replace (seteq (i64 (and X, 0xffffffff)), C1) with
16151
16204
// (seteq (i64 (sext_inreg (X, i32)), C1')) where C1' is C1 sign extended from
16152
16205
// bit 31. Same for setne. C1' may be cheaper to materialize and the sext_inreg
16153
16206
// can become a sext.w instead of a shift pair.
16154
16207
static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
16155
16208
const RISCVSubtarget &Subtarget) {
16209
+ SDLoc dl(N);
16156
16210
SDValue N0 = N->getOperand(0);
16157
16211
SDValue N1 = N->getOperand(1);
16158
16212
EVT VT = N->getValueType(0);
16159
16213
EVT OpVT = N0.getValueType();
16160
16214
16215
+ ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
16216
+ // Looking for an equality compare.
16217
+ if (!isIntEqualitySetCC(Cond))
16218
+ return SDValue();
16219
+
16220
+ if (SDValue V =
16221
+ combineVectorSizedSetCCEquality(VT, N0, N1, Cond, dl, DAG, Subtarget))
16222
+ return V;
16223
+
16161
16224
if (OpVT != MVT::i64 || !Subtarget.is64Bit())
16162
16225
return SDValue();
16163
16226
@@ -16172,11 +16235,6 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
16172
16235
N0.getConstantOperandVal(1) != UINT64_C(0xffffffff))
16173
16236
return SDValue();
16174
16237
16175
- // Looking for an equality compare.
16176
- ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
16177
- if (!isIntEqualitySetCC(Cond))
16178
- return SDValue();
16179
-
16180
16238
// Don't do this if the sign bit is provably zero, it will be turned back into
16181
16239
// an AND.
16182
16240
APInt SignMask = APInt::getOneBitSet(64, 31);
@@ -16185,7 +16243,6 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
16185
16243
16186
16244
const APInt &C1 = N1C->getAPIntValue();
16187
16245
16188
- SDLoc dl(N);
16189
16246
// If the constant is larger than 2^32 - 1 it is impossible for both sides
16190
16247
// to be equal.
16191
16248
if (C1.getActiveBits() > 32)
0 commit comments